mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-04 13:22:41 +02:00
Some checks are pending
Build and Push Docker Images / tag_release (push) Waiting to run
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (backend, surfsense-backend) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (web, surfsense-web) (push) Blocked by required conditions
843 lines
30 KiB
Python
843 lines
30 KiB
Python
"""
|
|
Image Generation routes:
|
|
- CRUD for ImageGenerationConfig (user-created image model configs)
|
|
- Global image gen configs endpoint (from YAML)
|
|
- Image generation execution (calls litellm.aimage_generation())
|
|
- CRUD for ImageGeneration records (results)
|
|
- Image serving endpoint (serves b64_json images from DB, protected by signed tokens)
|
|
"""
|
|
|
|
import base64
|
|
import logging
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
|
from fastapi.responses import Response
|
|
from litellm import aimage_generation
|
|
from sqlalchemy import select
|
|
from sqlalchemy.exc import SQLAlchemyError
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.config import config
|
|
from app.db import (
|
|
ImageGeneration,
|
|
ImageGenerationConfig,
|
|
Permission,
|
|
SearchSpace,
|
|
SearchSpaceMembership,
|
|
User,
|
|
get_async_session,
|
|
)
|
|
from app.schemas import (
|
|
GlobalImageGenConfigRead,
|
|
ImageGenerationConfigCreate,
|
|
ImageGenerationConfigRead,
|
|
ImageGenerationConfigUpdate,
|
|
ImageGenerationCreate,
|
|
ImageGenerationListRead,
|
|
ImageGenerationRead,
|
|
)
|
|
from app.services.billable_calls import (
|
|
DEFAULT_IMAGE_RESERVE_MICROS,
|
|
QuotaInsufficientError,
|
|
billable_call,
|
|
)
|
|
from app.services.image_gen_router_service import (
|
|
IMAGE_GEN_AUTO_MODE_ID,
|
|
ImageGenRouterService,
|
|
is_image_gen_auto_mode,
|
|
)
|
|
from app.services.provider_api_base import resolve_api_base
|
|
from app.users import current_active_user
|
|
from app.utils.rbac import check_permission
|
|
from app.utils.signed_image_urls import verify_image_token
|
|
|
|
# Router for the image-generation endpoints declared below with @router.*.
router = APIRouter()
# Module logger; the handlers below call logger.exception(...) before turning
# unexpected failures into HTTP 500 responses.
logger = logging.getLogger(__name__)
|
|
|
|
# Provider mapping for building litellm model strings.
|
|
# Only includes providers that support image generation.
|
|
# See: https://docs.litellm.ai/docs/image_generation#supported-providers
|
|
_PROVIDER_MAP = {
|
|
"OPENAI": "openai",
|
|
"AZURE_OPENAI": "azure",
|
|
"GOOGLE": "gemini", # Google AI Studio
|
|
"VERTEX_AI": "vertex_ai",
|
|
"BEDROCK": "bedrock", # AWS Bedrock
|
|
"RECRAFT": "recraft",
|
|
"OPENROUTER": "openrouter",
|
|
"XINFERENCE": "xinference",
|
|
"NSCALE": "nscale",
|
|
}
|
|
|
|
|
|
def _get_global_image_gen_config(config_id: int) -> dict | None:
    """Look up a global (YAML-defined) image generation config by its ID.

    Global configs carry negative IDs. The Auto-mode ID yields a synthetic
    descriptor rather than a YAML entry, and positive IDs (user-owned DB
    configs) short-circuit to ``None`` without scanning the YAML list.

    Returns:
        The matching config dict, or ``None`` if no global config has that ID.
    """
    if config_id == IMAGE_GEN_AUTO_MODE_ID:
        return {
            "id": IMAGE_GEN_AUTO_MODE_ID,
            "name": "Auto (Fastest)",
            "provider": "AUTO",
            "model_name": "auto",
            "is_auto_mode": True,
        }

    # Positive IDs live in the database, never in the YAML list.
    if config_id > 0:
        return None

    return next(
        (
            entry
            for entry in config.GLOBAL_IMAGE_GEN_CONFIGS
            if entry.get("id") == config_id
        ),
        None,
    )
|
|
|
|
|
|
def _resolve_provider_prefix(provider: str, custom_provider: str | None) -> str:
    """Return the LiteLLM provider prefix for a ``"<prefix>/<model>"`` string.

    A non-empty ``custom_provider`` takes precedence. Otherwise ``provider`` is
    upper-cased and looked up in ``_PROVIDER_MAP``; unmapped providers fall
    back to their lower-cased name.
    """
    return custom_provider or _PROVIDER_MAP.get(provider.upper(), provider.lower())
|
|
|
|
|
|
def _build_model_string(
    provider: str, model_name: str, custom_provider: str | None
) -> str:
    """Compose the ``"<provider prefix>/<model_name>"`` string LiteLLM routes on."""
    prefix = _resolve_provider_prefix(provider, custom_provider)
    return f"{prefix}/{model_name}"
|
|
|
|
|
|
async def _resolve_billing_for_image_gen(
    session: AsyncSession,
    config_id: int | None,
    search_space: SearchSpace,
) -> tuple[str, str, int]:
    """Work out ``(billing_tier, base_model, reserve_micros)`` for a request.

    This follows the same config-resolution order as
    ``_execute_image_generation`` (explicit ID, then the search space's
    preference, then Auto) but only extracts what billing needs. It runs
    *before* ``billable_call`` so the reservation is sized for the config that
    will actually run, and so no ``ImageGeneration`` row is opened for a
    request that is about to be rejected with 402.

    Tiering rules:
      * Auto mode is treated as free for now. The router may dispatch to either
        premium or free YAML deployments, and the chosen deployment is not yet
        threaded back from ``ImageGenRouterService``.
      * Global (negative ID) configs use their YAML ``billing_tier`` (default
        ``"free"``) and ``quota_reserve_micros`` (default
        ``DEFAULT_IMAGE_RESERVE_MICROS``).
      * User-owned BYOK configs (positive ID) are always free; they cost
        nothing on our side.

    ``session`` is accepted for interface symmetry and is not used here.
    """
    effective_id = (
        config_id
        if config_id is not None
        else (search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID)
    )

    if is_image_gen_auto_mode(effective_id):
        return ("free", "auto", DEFAULT_IMAGE_RESERVE_MICROS)

    # Anything non-negative at this point is a user-owned BYOK config.
    if effective_id >= 0:
        return ("free", "user_byok", DEFAULT_IMAGE_RESERVE_MICROS)

    # Unknown global IDs degrade to an empty dict (-> free tier); execution
    # reports the missing config separately.
    global_cfg = _get_global_image_gen_config(effective_id) or {}
    tier = str(global_cfg.get("billing_tier", "free")).lower()
    model_label = _build_model_string(
        global_cfg.get("provider", ""),
        global_cfg.get("model_name", ""),
        global_cfg.get("custom_provider"),
    )
    # A falsy (missing / 0) reserve falls back to the default reservation.
    reserve = int(
        global_cfg.get("quota_reserve_micros") or DEFAULT_IMAGE_RESERVE_MICROS
    )
    return (tier, model_label, reserve)
|
|
|
|
|
|
async def _execute_image_generation(
    session: AsyncSession,
    image_gen: ImageGeneration,
    search_space: SearchSpace,
) -> None:
    """
    Call litellm.aimage_generation() with the appropriate config.

    Resolution order:
    1. Explicit image_generation_config_id on the request
    2. Search space's image_generation_config_id preference
    3. Falls back to Auto mode if available

    When the request carried no config ID, the resolved ID is written back to
    ``image_gen.image_generation_config_id``. The provider response is stored
    on ``image_gen.response_data``. If the request did not pin ``model``, the
    model name LiteLLM reports in ``_hidden_params`` is recorded on
    ``image_gen.model``. Nothing is committed here; the caller owns the
    transaction.

    User-owned (positive ID) configs are only honoured when they belong to
    ``search_space``. Configs are created and permission-checked per search
    space, so an unscoped lookup would let a caller spend another search
    space's stored API key just by knowing the config ID.

    Raises:
        ValueError: Auto mode was requested but the router is not initialized,
            or the referenced global / search-space config was not found.
        Exception: Errors from LiteLLM or the provider propagate unchanged.
    """
    config_id = image_gen.image_generation_config_id
    if config_id is None:
        config_id = search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID
        image_gen.image_generation_config_id = config_id

    # Forward only the generation options the caller actually set, so provider
    # defaults apply to everything else.
    gen_kwargs = {}
    if image_gen.n is not None:
        gen_kwargs["n"] = image_gen.n
    if image_gen.quality is not None:
        gen_kwargs["quality"] = image_gen.quality
    if image_gen.size is not None:
        gen_kwargs["size"] = image_gen.size
    if image_gen.style is not None:
        gen_kwargs["style"] = image_gen.style
    if image_gen.response_format is not None:
        gen_kwargs["response_format"] = image_gen.response_format

    if is_image_gen_auto_mode(config_id):
        if not ImageGenRouterService.is_initialized():
            raise ValueError(
                "Auto mode requested but Image Generation Router not initialized. "
                "Ensure global_llm_config.yaml has global_image_generation_configs."
            )
        response = await ImageGenRouterService.aimage_generation(
            prompt=image_gen.prompt, model="auto", **gen_kwargs
        )
    elif config_id < 0:
        # Global config from YAML
        cfg = _get_global_image_gen_config(config_id)
        if not cfg:
            raise ValueError(f"Global image generation config {config_id} not found")

        provider_prefix = _resolve_provider_prefix(
            cfg.get("provider", ""), cfg.get("custom_provider")
        )
        model_string = f"{provider_prefix}/{cfg['model_name']}"
        gen_kwargs["api_key"] = cfg.get("api_key")
        api_base = resolve_api_base(
            provider=cfg.get("provider"),
            provider_prefix=provider_prefix,
            config_api_base=cfg.get("api_base"),
        )
        if api_base:
            gen_kwargs["api_base"] = api_base
        if cfg.get("api_version"):
            gen_kwargs["api_version"] = cfg["api_version"]
        # Applied last, so litellm_params can override any key set above.
        if cfg.get("litellm_params"):
            gen_kwargs.update(cfg["litellm_params"])

        # User model override.
        # NOTE(review): the override is sent with this *global* config's
        # api_key, while billing is resolved from the config rather than the
        # overridden model. A request against a free global config can
        # therefore name a different (possibly pricier) model on the
        # platform's key. Consider rejecting or allow-listing overrides here.
        if image_gen.model:
            model_string = image_gen.model

        response = await aimage_generation(
            prompt=image_gen.prompt, model=model_string, **gen_kwargs
        )
    else:
        # Positive ID = DB ImageGenerationConfig, scoped to this search space
        # so another search space's credentials can't be used by ID alone.
        # The "not found" message is deliberately the same whether the config
        # is missing or belongs elsewhere, so it doesn't leak existence.
        result = await session.execute(
            select(ImageGenerationConfig).filter(
                ImageGenerationConfig.id == config_id,
                ImageGenerationConfig.search_space_id == search_space.id,
            )
        )
        db_cfg = result.scalars().first()
        if not db_cfg:
            raise ValueError(f"Image generation config {config_id} not found")

        provider_prefix = _resolve_provider_prefix(
            db_cfg.provider.value, db_cfg.custom_provider
        )
        model_string = f"{provider_prefix}/{db_cfg.model_name}"
        gen_kwargs["api_key"] = db_cfg.api_key
        api_base = resolve_api_base(
            provider=db_cfg.provider.value,
            provider_prefix=provider_prefix,
            config_api_base=db_cfg.api_base,
        )
        if api_base:
            gen_kwargs["api_base"] = api_base
        if db_cfg.api_version:
            gen_kwargs["api_version"] = db_cfg.api_version
        # Applied last, so litellm_params can override any key set above.
        if db_cfg.litellm_params:
            gen_kwargs.update(db_cfg.litellm_params)

        # User model override
        if image_gen.model:
            model_string = image_gen.model

        response = await aimage_generation(
            prompt=image_gen.prompt, model=model_string, **gen_kwargs
        )

    # Store response
    image_gen.response_data = (
        response.model_dump() if hasattr(response, "model_dump") else dict(response)
    )
    # Record the model that actually served the request when none was pinned.
    if not image_gen.model and hasattr(response, "_hidden_params"):
        hidden = response._hidden_params
        if isinstance(hidden, dict) and hidden.get("model"):
            image_gen.model = hidden["model"]
|
|
|
|
|
|
# =============================================================================
|
|
# Global Image Generation Configs (from YAML)
|
|
# =============================================================================
|
|
|
|
|
|
@router.get(
    "/global-image-generation-configs",
    response_model=list[GlobalImageGenConfigRead],
)
async def get_global_image_gen_configs(
    user: User = Depends(current_active_user),
):
    """Get all global image generation configs. API keys are hidden.

    When at least one global config exists, a synthetic Auto entry (ID
    ``IMAGE_GEN_AUTO_MODE_ID``) is listed first, followed by each YAML
    config. API keys are never copied into the response.

    Raises:
        HTTPException: 500 if building the list fails.
    """
    try:
        # A missing/None YAML section is treated like an empty one; iterating
        # None would otherwise turn "no configs" into a 500.
        global_configs = config.GLOBAL_IMAGE_GEN_CONFIGS or []
        safe_configs = []

        # Auto mode only makes sense when there is something to route to.
        if global_configs:
            safe_configs.append(
                {
                    # Same constant the executor checks, so selecting this
                    # entry is always recognised as Auto mode.
                    "id": IMAGE_GEN_AUTO_MODE_ID,
                    "name": "Auto (Fastest)",
                    "description": "Automatically routes across available image generation providers.",
                    "provider": "AUTO",
                    "custom_provider": None,
                    "model_name": "auto",
                    "api_base": None,
                    "api_version": None,
                    "litellm_params": {},
                    "is_global": True,
                    "is_auto_mode": True,
                    # Auto mode currently treated as free until per-deployment
                    # billing-tier surfacing lands (see _resolve_billing_for_image_gen).
                    "billing_tier": "free",
                    "is_premium": False,
                }
            )

        for cfg in global_configs:
            billing_tier = str(cfg.get("billing_tier", "free")).lower()
            safe_configs.append(
                {
                    "id": cfg.get("id"),
                    "name": cfg.get("name"),
                    "description": cfg.get("description"),
                    "provider": cfg.get("provider"),
                    "custom_provider": cfg.get("custom_provider"),
                    "model_name": cfg.get("model_name"),
                    "api_base": cfg.get("api_base") or None,
                    "api_version": cfg.get("api_version") or None,
                    "litellm_params": cfg.get("litellm_params", {}),
                    "is_global": True,
                    "billing_tier": billing_tier,
                    # Mirror chat (``new_llm_config_routes``) so the new-chat
                    # selector's premium badge logic keys off the same
                    # field across chat / image / vision tabs.
                    "is_premium": billing_tier == "premium",
                    "quota_reserve_micros": cfg.get("quota_reserve_micros"),
                }
            )

        return safe_configs
    except Exception as e:
        logger.exception("Failed to fetch global image generation configs")
        raise HTTPException(
            status_code=500, detail=f"Failed to fetch configs: {e!s}"
        ) from e
|
|
|
|
|
|
# =============================================================================
|
|
# ImageGenerationConfig CRUD
|
|
# =============================================================================
|
|
|
|
|
|
@router.post("/image-generation-configs", response_model=ImageGenerationConfigRead)
async def create_image_gen_config(
    config_data: ImageGenerationConfigCreate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Persist a new user-owned image generation config in a search space.

    Requires IMAGE_GENERATIONS_CREATE on ``config_data.search_space_id``. The
    requesting user is stored as the config's ``user_id``.

    Raises:
        HTTPException: any HTTPException raised during the permission check is
            re-raised unchanged; other failures roll back and become a 500.
    """
    try:
        await check_permission(
            session,
            user,
            config_data.search_space_id,
            Permission.IMAGE_GENERATIONS_CREATE.value,
            "You don't have permission to create image generation configs in this search space",
        )

        fields = config_data.model_dump()
        new_config = ImageGenerationConfig(**fields, user_id=user.id)
        session.add(new_config)
        await session.commit()
        await session.refresh(new_config)
    except HTTPException:
        raise
    except Exception as e:
        await session.rollback()
        logger.exception("Failed to create ImageGenerationConfig")
        raise HTTPException(
            status_code=500, detail=f"Failed to create config: {e!s}"
        ) from e
    return new_config
|
|
|
|
|
|
@router.get("/image-generation-configs", response_model=list[ImageGenerationConfigRead])
async def list_image_gen_configs(
    search_space_id: int,
    skip: int = 0,
    limit: int = 100,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """List image generation configs for a search space, newest first.

    Requires IMAGE_GENERATIONS_READ on ``search_space_id``.

    Args:
        search_space_id: Search space whose configs are listed.
        skip: Number of rows to skip; must be >= 0.
        limit: Maximum number of rows to return; must be >= 0.

    Raises:
        HTTPException: 400 for negative pagination values; HTTPExceptions from
            the permission check are re-raised unchanged; 500 for any other
            failure.
    """
    # Negative values are not valid OFFSET/LIMIT arguments. Reject them as a
    # client error (as list_image_generations does) instead of letting the
    # database error surface as a 500.
    if skip < 0 or limit < 0:
        raise HTTPException(status_code=400, detail="Invalid pagination parameters")

    try:
        await check_permission(
            session,
            user,
            search_space_id,
            Permission.IMAGE_GENERATIONS_READ.value,
            "You don't have permission to view image generation configs in this search space",
        )

        result = await session.execute(
            select(ImageGenerationConfig)
            .filter(ImageGenerationConfig.search_space_id == search_space_id)
            .order_by(ImageGenerationConfig.created_at.desc())
            .offset(skip)
            .limit(limit)
        )
        return result.scalars().all()
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Failed to list ImageGenerationConfigs")
        raise HTTPException(
            status_code=500, detail=f"Failed to fetch configs: {e!s}"
        ) from e
|
|
|
|
|
|
@router.get(
    "/image-generation-configs/{config_id}", response_model=ImageGenerationConfigRead
)
async def get_image_gen_config(
    config_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Fetch a single image generation config by ID.

    The row is loaded first (404 if absent). Read permission is then checked
    against the search space the config belongs to.

    Raises:
        HTTPException: 404 when the config does not exist; HTTPExceptions from
            the permission check are re-raised unchanged; 500 on other errors.
    """
    try:
        lookup = await session.execute(
            select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id)
        )
        found = lookup.scalars().first()
        if not found:
            raise HTTPException(status_code=404, detail="Config not found")

        await check_permission(
            session,
            user,
            found.search_space_id,
            Permission.IMAGE_GENERATIONS_READ.value,
            "You don't have permission to view image generation configs in this search space",
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Failed to get ImageGenerationConfig")
        raise HTTPException(
            status_code=500, detail=f"Failed to fetch config: {e!s}"
        ) from e
    return found
|
|
|
|
|
|
@router.put(
    "/image-generation-configs/{config_id}", response_model=ImageGenerationConfigRead
)
async def update_image_gen_config(
    config_id: int,
    update_data: ImageGenerationConfigUpdate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Apply a partial update to an image generation config.

    Only fields explicitly present in the request body (``exclude_unset``) are
    written. This route checks IMAGE_GENERATIONS_CREATE on the config's own
    search space before modifying it.

    Raises:
        HTTPException: 404 when the config does not exist; HTTPExceptions from
            the permission check are re-raised unchanged; any other failure
            rolls back and becomes a 500.
    """
    try:
        lookup = await session.execute(
            select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id)
        )
        target = lookup.scalars().first()
        if not target:
            raise HTTPException(status_code=404, detail="Config not found")

        await check_permission(
            session,
            user,
            target.search_space_id,
            Permission.IMAGE_GENERATIONS_CREATE.value,
            "You don't have permission to update image generation configs in this search space",
        )

        changes = update_data.model_dump(exclude_unset=True)
        for field_name in changes:
            setattr(target, field_name, changes[field_name])

        await session.commit()
        await session.refresh(target)
    except HTTPException:
        raise
    except Exception as e:
        await session.rollback()
        logger.exception("Failed to update ImageGenerationConfig")
        raise HTTPException(
            status_code=500, detail=f"Failed to update config: {e!s}"
        ) from e
    return target
|
|
|
|
|
|
@router.delete("/image-generation-configs/{config_id}", response_model=dict)
async def delete_image_gen_config(
    config_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Delete an image generation config after an IMAGE_GENERATIONS_DELETE check.

    Returns a confirmation payload echoing the deleted ID.

    Raises:
        HTTPException: 404 when the config does not exist; HTTPExceptions from
            the permission check are re-raised unchanged; any other failure
            rolls back and becomes a 500.
    """
    try:
        lookup = await session.execute(
            select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id)
        )
        target = lookup.scalars().first()
        if not target:
            raise HTTPException(status_code=404, detail="Config not found")

        await check_permission(
            session,
            user,
            target.search_space_id,
            Permission.IMAGE_GENERATIONS_DELETE.value,
            "You don't have permission to delete image generation configs in this search space",
        )

        await session.delete(target)
        await session.commit()
    except HTTPException:
        raise
    except Exception as e:
        await session.rollback()
        logger.exception("Failed to delete ImageGenerationConfig")
        raise HTTPException(
            status_code=500, detail=f"Failed to delete config: {e!s}"
        ) from e
    return {
        "message": "Image generation config deleted successfully",
        "id": config_id,
    }
|
|
|
|
|
|
# =============================================================================
|
|
# Image Generation Execution + Results CRUD
|
|
# =============================================================================
|
|
|
|
|
|
@router.post("/image-generations", response_model=ImageGenerationRead)
async def create_image_generation(
    data: ImageGenerationCreate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Create and execute an image generation request.

    Premium configs are gated by the user's shared premium credit pool.
    The flow is:

    1. Permission check + load the search space (cheap, no provider call).
    2. Resolve which config will run so we know its billing tier and the
       worst-case reservation size *before* opening any DB rows.
    3. Wrap the entire ImageGeneration row insert + provider call in
       ``billable_call``. If quota is denied, ``billable_call`` raises
       ``QuotaInsufficientError`` *before* we flush a row, which we
       translate to HTTP 402 (no orphaned rows on the user's account,
       no inserted error rows for "you ran out of credit").
    4. On success, the actual ``response_cost`` flows through the
       LiteLLM callback into the accumulator, and ``billable_call``
       finalizes the debit at exit. Inner ``try/except`` still catches
       provider errors and stores them on ``error_message`` (HTTP 200
       with ``error_message`` set is preserved for failed-but-not-quota
       scenarios — clients already know how to surface those).

    Raises:
        HTTPException: 404 if the search space does not exist; 402 with a
            structured ``premium_quota_exhausted`` detail when the credit pool
            is empty; 500 on database or other unexpected errors (the session
            is rolled back first). HTTPExceptions from the permission check
            are re-raised unchanged.
    """
    try:
        await check_permission(
            session,
            user,
            data.search_space_id,
            Permission.IMAGE_GENERATIONS_CREATE.value,
            "You don't have permission to create image generations in this search space",
        )

        result = await session.execute(
            select(SearchSpace).filter(SearchSpace.id == data.search_space_id)
        )
        search_space = result.scalars().first()
        if not search_space:
            raise HTTPException(status_code=404, detail="Search space not found")

        billing_tier, base_model, reserve_micros = await _resolve_billing_for_image_gen(
            session, data.image_generation_config_id, search_space
        )

        # billable_call runs OUTSIDE the inner try/except so QuotaInsufficientError
        # propagates to the outer ``except QuotaInsufficientError`` handler
        # below as HTTP 402 — it is intentionally NOT swallowed into
        # ``error_message`` because that would (1) imply a successful row
        # exists when none does, and (2) return HTTP 200 to a client
        # whose request was actively *denied* (issue K).
        # Billing is charged to ``search_space.user_id`` (presumably the
        # search space owner), not to the requesting ``user``.
        async with billable_call(
            user_id=search_space.user_id,
            search_space_id=data.search_space_id,
            billing_tier=billing_tier,
            base_model=base_model,
            quota_reserve_micros_override=reserve_micros,
            usage_type="image_generation",
            call_details={"model": base_model, "prompt": data.prompt[:100]},
        ):
            db_image_gen = ImageGeneration(
                prompt=data.prompt,
                model=data.model,
                n=data.n,
                quality=data.quality,
                size=data.size,
                style=data.style,
                response_format=data.response_format,
                image_generation_config_id=data.image_generation_config_id,
                search_space_id=data.search_space_id,
                created_by_id=user.id,
            )
            session.add(db_image_gen)
            # Flush so the row has its primary key before the provider call.
            await session.flush()

            # Provider/config failures are recorded on the row rather than
            # raised, so the client still receives a 200 with error_message.
            # NOTE(review): a DB error raised inside _execute_image_generation
            # is also swallowed here; the commit below would then likely fail
            # and surface via the outer SQLAlchemyError handler — confirm.
            try:
                await _execute_image_generation(session, db_image_gen, search_space)
            except Exception as e:
                logger.exception("Image generation call failed")
                db_image_gen.error_message = str(e)

            await session.commit()
            await session.refresh(db_image_gen)
            return db_image_gen

    except HTTPException:
        raise
    except QuotaInsufficientError as exc:
        # The user's premium credit pool is empty. No DB row is created
        # because ``billable_call`` denies before yielding (issue K).
        await session.rollback()
        raise HTTPException(
            status_code=402,
            detail={
                "error_code": "premium_quota_exhausted",
                "usage_type": exc.usage_type,
                "used_micros": exc.used_micros,
                "limit_micros": exc.limit_micros,
                "remaining_micros": exc.remaining_micros,
                "message": (
                    "Out of premium credits for image generation. "
                    "Purchase additional credits or switch to a free model."
                ),
            },
        ) from exc
    except SQLAlchemyError:
        await session.rollback()
        raise HTTPException(
            status_code=500, detail="Database error during image generation"
        ) from None
    except Exception as e:
        await session.rollback()
        logger.exception("Failed to create image generation")
        raise HTTPException(
            status_code=500, detail=f"Image generation failed: {e!s}"
        ) from e
|
|
|
|
|
|
@router.get("/image-generations", response_model=list[ImageGenerationListRead])
async def list_image_generations(
    search_space_id: int | None = None,
    skip: int = 0,
    limit: int = 50,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """List image generations, newest first.

    With ``search_space_id``, lists that search space's generations after an
    IMAGE_GENERATIONS_READ check. Without it, lists generations from every
    search space in which the current user has a membership.

    A negative ``skip`` or a ``limit`` below 1 is rejected with 400, and
    ``limit`` is capped at 100.
    """
    if skip < 0 or limit < 1:
        raise HTTPException(status_code=400, detail="Invalid pagination parameters")
    limit = min(limit, 100)

    try:
        stmt = select(ImageGeneration)
        if search_space_id is None:
            # No explicit scope: only search spaces the user is a member of.
            stmt = (
                stmt.join(SearchSpace)
                .join(SearchSpaceMembership)
                .filter(SearchSpaceMembership.user_id == user.id)
            )
        else:
            await check_permission(
                session,
                user,
                search_space_id,
                Permission.IMAGE_GENERATIONS_READ.value,
                "You don't have permission to read image generations in this search space",
            )
            stmt = stmt.filter(ImageGeneration.search_space_id == search_space_id)

        stmt = (
            stmt.order_by(ImageGeneration.created_at.desc())
            .offset(skip)
            .limit(limit)
        )
        rows = (await session.execute(stmt)).scalars().all()
        return [ImageGenerationListRead.from_orm_with_count(row) for row in rows]
    except HTTPException:
        raise
    except SQLAlchemyError:
        raise HTTPException(
            status_code=500, detail="Database error fetching image generations"
        ) from None
|
|
|
|
|
|
@router.get("/image-generations/{image_gen_id}", response_model=ImageGenerationRead)
async def get_image_generation(
    image_gen_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Fetch one image generation record, enforcing read access on its search space.

    Raises:
        HTTPException: 404 when the record does not exist; HTTPExceptions from
            the permission check are re-raised unchanged; 500 on database errors.
    """
    try:
        query = select(ImageGeneration).filter(ImageGeneration.id == image_gen_id)
        image_gen = (await session.execute(query)).scalars().first()
        if not image_gen:
            raise HTTPException(status_code=404, detail="Image generation not found")

        await check_permission(
            session,
            user,
            image_gen.search_space_id,
            Permission.IMAGE_GENERATIONS_READ.value,
            "You don't have permission to read image generations in this search space",
        )
    except HTTPException:
        raise
    except SQLAlchemyError:
        raise HTTPException(
            status_code=500, detail="Database error fetching image generation"
        ) from None
    return image_gen
|
|
|
|
|
|
@router.delete("/image-generations/{image_gen_id}", response_model=dict)
async def delete_image_generation(
    image_gen_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Delete an image generation record after an IMAGE_GENERATIONS_DELETE check.

    Raises:
        HTTPException: 404 when the record does not exist; HTTPExceptions from
            the permission check are re-raised unchanged; database errors roll
            back and become a 500.
    """
    try:
        query = select(ImageGeneration).filter(ImageGeneration.id == image_gen_id)
        doomed = (await session.execute(query)).scalars().first()
        if not doomed:
            raise HTTPException(status_code=404, detail="Image generation not found")

        await check_permission(
            session,
            user,
            doomed.search_space_id,
            Permission.IMAGE_GENERATIONS_DELETE.value,
            "You don't have permission to delete image generations in this search space",
        )

        await session.delete(doomed)
        await session.commit()
    except HTTPException:
        raise
    except SQLAlchemyError:
        await session.rollback()
        raise HTTPException(
            status_code=500, detail="Database error deleting image generation"
        ) from None
    return {"message": "Image generation deleted successfully"}
|
|
|
|
|
|
# =============================================================================
|
|
# Image Serving (serves generated images from DB, protected by signed tokens)
|
|
# =============================================================================
|
|
|
|
|
|
def _sniff_image_type(image_bytes: bytes) -> tuple[str, str]:
    """Return ``(media_type, file_extension)`` for decoded image bytes.

    Recognises JPEG, GIF and WEBP by their magic bytes. PNG, and anything
    unrecognised, keeps the historical ``image/png`` default.
    """
    if image_bytes.startswith(b"\xff\xd8\xff"):
        return "image/jpeg", "jpg"
    if image_bytes.startswith((b"GIF87a", b"GIF89a")):
        return "image/gif", "gif"
    if image_bytes[:4] == b"RIFF" and image_bytes[8:12] == b"WEBP":
        return "image/webp", "webp"
    return "image/png", "png"


@router.get("/image-generations/{image_gen_id}/image")
async def serve_generated_image(
    image_gen_id: int,
    token: str = Query(..., description="Signed access token"),
    index: int = 0,
    session: AsyncSession = Depends(get_async_session),
):
    """
    Serve a generated image by ID, protected by a signed token.

    The ``token`` query parameter is checked with ``verify_image_token``
    against the access token stored on the record. This lets ``<img>`` tags,
    which cannot send auth headers, load the image without a session.

    Entries holding a ``url`` are answered with a redirect. Entries holding
    ``b64_json`` are decoded and served inline, with the media type sniffed
    from the image bytes (PNG when unrecognised).

    Args:
        image_gen_id: The image generation record ID
        token: HMAC-signed access token (included as query parameter)
        index: Which image to serve if multiple were generated (default: 0)

    Raises:
        HTTPException: 403 for an invalid token; 404 when the record, its image
            data, or the requested index is missing; 500 on other failures.
    """
    try:
        result = await session.execute(
            select(ImageGeneration).filter(ImageGeneration.id == image_gen_id)
        )
        image_gen = result.scalars().first()
        if not image_gen:
            raise HTTPException(status_code=404, detail="Image generation not found")

        # Verify the access token against the one stored on the record
        if not verify_image_token(image_gen.access_token, token):
            raise HTTPException(status_code=403, detail="Invalid image access token")

        if not image_gen.response_data:
            raise HTTPException(status_code=404, detail="No image data available")

        images = image_gen.response_data.get("data", [])
        # Bounds-check both ends: a negative index would otherwise wrap around
        # (serving the wrong image) or raise IndexError (surfacing as a 500).
        if not images or not 0 <= index < len(images):
            raise HTTPException(
                status_code=404, detail="Image not found at the specified index"
            )

        image_entry = images[index]

        # If there's a URL, redirect to it
        if image_entry.get("url"):
            from fastapi.responses import RedirectResponse

            return RedirectResponse(url=image_entry["url"])

        # If there's b64_json data, decode and serve it
        if image_entry.get("b64_json"):
            image_bytes = base64.b64decode(image_entry["b64_json"])
            media_type, extension = _sniff_image_type(image_bytes)
            return Response(
                content=image_bytes,
                media_type=media_type,
                headers={
                    "Cache-Control": "public, max-age=86400",
                    "Content-Disposition": f'inline; filename="generated-{image_gen_id}-{index}.{extension}"',
                },
            )

        raise HTTPException(status_code=404, detail="No displayable image data")

    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Failed to serve generated image")
        raise HTTPException(
            status_code=500, detail=f"Failed to serve image: {e!s}"
        ) from e
|