SurfSense/surfsense_backend/app/schemas/image_generation.py

233 lines
7.4 KiB
Python

"""
Pydantic schemas for Image Generation configs and generation requests.
ImageGenerationConfig: CRUD schemas for user-created image gen model configs.
ImageGeneration: Schemas for the actual image generation requests/results.
GlobalImageGenConfigRead: Schema for admin-configured YAML configs.
"""
import uuid
from datetime import datetime
from typing import Any
from pydantic import BaseModel, ConfigDict, Field
from app.db import ImageGenProvider
# =============================================================================
# ImageGenerationConfig CRUD Schemas
# =============================================================================
class ImageGenerationConfigBase(BaseModel):
    """Base schema with fields for ImageGenerationConfig.

    Holds the fields shared by the create and read variants. ``api_key`` is
    declared here, so every subclass of this base also carries it.
    """

    # Required display name.
    name: str = Field(
        max_length=100,
        description="User-friendly name for the config",
    )
    description: str | None = Field(
        default=None,
        max_length=500,
        description="Optional description",
    )
    # Provider type imported from app.db.
    provider: ImageGenProvider = Field(
        description=(
            "Image generation provider (OpenAI, Azure, Google AI Studio, "
            "Vertex AI, Bedrock, Recraft, OpenRouter, Xinference, Nscale)"
        ),
    )
    custom_provider: str | None = Field(
        default=None,
        max_length=100,
        description="Custom provider name",
    )
    model_name: str = Field(
        max_length=100,
        description="Model name (e.g., dall-e-3, gpt-image-1)",
    )
    # Required credential; no max_length constraint is applied to it.
    api_key: str = Field(description="API key for the provider")
    api_base: str | None = Field(
        default=None,
        max_length=500,
        description="Optional API base URL",
    )
    api_version: str | None = Field(
        default=None,
        max_length=50,
        description="Azure-specific API version (e.g., '2024-02-15-preview')",
    )
    # Free-form extra parameters; their contents are not validated here.
    litellm_params: dict[str, Any] | None = Field(
        default=None,
        description="Additional LiteLLM parameters",
    )
class ImageGenerationConfigCreate(ImageGenerationConfigBase):
    """Schema for creating a new ImageGenerationConfig.

    Adds the owning search space to the shared base fields.
    """

    # Required: the search space the new config belongs to.
    search_space_id: int = Field(
        description="Search space ID to associate the config with",
    )
class ImageGenerationConfigUpdate(BaseModel):
    """Schema for updating an existing ImageGenerationConfig. All fields optional.

    Mirrors the base fields (same length limits) but every field defaults to
    None.
    """

    name: str | None = Field(default=None, max_length=100)
    description: str | None = Field(default=None, max_length=500)
    provider: ImageGenProvider | None = Field(default=None)
    custom_provider: str | None = Field(default=None, max_length=100)
    model_name: str | None = Field(default=None, max_length=100)
    api_key: str | None = Field(default=None)
    api_base: str | None = Field(default=None, max_length=500)
    api_version: str | None = Field(default=None, max_length=50)
    litellm_params: dict[str, Any] | None = Field(default=None)
class ImageGenerationConfigRead(ImageGenerationConfigBase):
    """Schema for reading an ImageGenerationConfig (includes id and timestamps).

    Inherits every base field, including ``api_key``. Can be populated from
    attribute-bearing objects because ``from_attributes`` is enabled.
    """

    model_config = ConfigDict(from_attributes=True)

    id: int
    created_at: datetime
    search_space_id: int
    # Owner of the config.
    user_id: uuid.UUID
class ImageGenerationConfigPublic(BaseModel):
    """Public schema that hides the API key (for list views).

    Declared standalone rather than subclassing the base so that
    ``api_key`` is not part of this schema.
    """

    model_config = ConfigDict(from_attributes=True)

    # Identity
    id: int
    name: str
    description: str | None = None
    # Provider / model selection
    provider: ImageGenProvider
    custom_provider: str | None = None
    model_name: str
    # Connection settings (no credential)
    api_base: str | None = None
    api_version: str | None = None
    litellm_params: dict[str, Any] | None = None
    # Ownership / bookkeeping
    created_at: datetime
    search_space_id: int
    user_id: uuid.UUID
# =============================================================================
# ImageGeneration (request/result) Schemas
# =============================================================================
class ImageGenerationCreate(BaseModel):
    """Schema for creating an image generation request.

    Only ``prompt`` and ``search_space_id`` are required. The remaining
    generation options default to None.
    """

    # Non-empty prompt, capped at 4000 characters.
    prompt: str = Field(
        min_length=1,
        max_length=4000,
        description="A text description of the desired image(s)",
    )
    model: str | None = Field(
        default=None,
        max_length=200,
        description=(
            "The model to use (e.g., 'dall-e-3', 'gpt-image-1'). "
            "Overrides the config model."
        ),
    )
    # Bounded to the inclusive range 1..10.
    n: int | None = Field(
        default=None,
        ge=1,
        le=10,
        description="Number of images to generate (1-10).",
    )
    quality: str | None = Field(default=None, max_length=50)
    size: str | None = Field(default=None, max_length=50)
    style: str | None = Field(default=None, max_length=50)
    response_format: str | None = Field(default=None, max_length=50)
    search_space_id: int = Field(
        description="Search space ID to associate the generation with",
    )
    # The sign of the ID selects the config source (see description).
    image_generation_config_id: int | None = Field(
        default=None,
        description=(
            "Image generation config ID. "
            "0 = Auto mode (router), negative = global YAML config, "
            "positive = DB config. "
            "If not provided, uses the search space's "
            "image_generation_config_id preference."
        ),
    )
class ImageGenerationRead(BaseModel):
    """Schema for reading an image generation record.

    Echoes the request options and includes the raw ``response_data`` and
    any ``error_message``. ``from_attributes`` is enabled.
    """

    model_config = ConfigDict(from_attributes=True)

    id: int
    prompt: str
    # Request options as submitted (all optional).
    model: str | None = None
    n: int | None = None
    quality: str | None = None
    size: str | None = None
    style: str | None = None
    response_format: str | None = None
    image_generation_config_id: int | None = None
    # Outcome of the generation.
    response_data: dict[str, Any] | None = None
    error_message: str | None = None
    search_space_id: int
    created_at: datetime
class ImageGenerationListRead(BaseModel):
    """Lightweight schema for listing image generations (without full response_data).

    ``is_success`` and ``image_count`` are computed fields. Build instances
    with ``from_orm_with_count`` rather than direct attribute validation.
    """

    model_config = ConfigDict(from_attributes=True)

    id: int
    prompt: str
    model: str | None = None
    n: int | None = None
    quality: str | None = None
    size: str | None = None
    search_space_id: int
    created_at: datetime
    # Derived from response_data by from_orm_with_count.
    is_success: bool
    image_count: int | None = None

    @classmethod
    def from_orm_with_count(cls, obj: Any) -> "ImageGenerationListRead":
        """Build a list entry from an image generation record.

        Args:
            obj: Any object exposing the listed attributes plus
                ``response_data``.

        Returns:
            A new instance in which ``is_success`` is True exactly when
            ``obj.response_data`` is not None. ``image_count`` is the length
            of ``response_data["data"]`` when ``response_data`` is a
            non-empty dict whose ``"data"`` entry is a list, and None
            otherwise.
        """
        payload = obj.response_data

        count: int | None = None
        if payload and isinstance(payload, dict):
            images = payload.get("data")
            if isinstance(images, list):
                count = len(images)

        return cls(
            id=obj.id,
            prompt=obj.prompt,
            model=obj.model,
            n=obj.n,
            quality=obj.quality,
            size=obj.size,
            search_space_id=obj.search_space_id,
            created_at=obj.created_at,
            is_success=payload is not None,
            image_count=count,
        )
# =============================================================================
# Global Image Gen Config (from YAML)
# =============================================================================
class GlobalImageGenConfigRead(BaseModel):
    """
    Schema for reading global image generation configs from YAML.
    Global configs have negative IDs. API key is hidden.
    ID 0 is reserved for Auto mode (LiteLLM Router load balancing).

    No ``api_key`` field is declared on this schema.
    """

    id: int = Field(
        description="Config ID: 0 for Auto mode, negative for global configs",
    )
    name: str
    description: str | None = None
    # Plain string here, unlike the DB-backed schemas that use ImageGenProvider.
    provider: str
    custom_provider: str | None = None
    model_name: str
    api_base: str | None = None
    api_version: str | None = None
    litellm_params: dict[str, Any] | None = None
    # Flags with fixed defaults; is_auto_mode defaults to False.
    is_global: bool = True
    is_auto_mode: bool = False