refactor(images): use model connections for image generation

This commit is contained in:
Anish Sarkar 2026-06-10 21:48:37 +05:30
parent 62ff97c830
commit 077016d6e4
3 changed files with 52 additions and 152 deletions

View file

@ -10,13 +10,14 @@ from langgraph.types import Command
from litellm import aimage_generation
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.agents.chat.multi_agent_chat.shared.receipts.command import with_receipt
from app.agents.chat.multi_agent_chat.shared.receipts.receipt import make_receipt
from app.config import config
from app.db import (
ImageGeneration,
ImageGenerationConfig,
Model,
SearchSpace,
shielded_async_session,
)
@ -25,37 +26,11 @@ from app.services.image_gen_router_service import (
ImageGenRouterService,
is_image_gen_auto_mode,
)
from app.services.provider_api_base import resolve_api_base
from app.services.model_resolver import native_connection_from_config, to_litellm
from app.utils.signed_image_urls import generate_image_token
logger = logging.getLogger(__name__)
# Provider mapping (same as routes)
_PROVIDER_MAP = {
"OPENAI": "openai",
"AZURE_OPENAI": "azure",
"GOOGLE": "gemini",
"VERTEX_AI": "vertex_ai",
"BEDROCK": "bedrock",
"RECRAFT": "recraft",
"OPENROUTER": "openrouter",
"XINFERENCE": "xinference",
"NSCALE": "nscale",
}
def _resolve_provider_prefix(provider: str, custom_provider: str | None) -> str:
if custom_provider:
return custom_provider
return _PROVIDER_MAP.get(provider.upper(), provider.lower())
def _build_model_string(
provider: str, model_name: str, custom_provider: str | None
) -> str:
return f"{_resolve_provider_prefix(provider, custom_provider)}/{model_name}"
def _get_global_image_gen_config(config_id: int) -> dict | None:
"""Get a global image gen config by negative ID."""
for cfg in config.GLOBAL_IMAGE_GEN_CONFIGS:
@ -67,13 +42,13 @@ def _get_global_image_gen_config(config_id: int) -> dict | None:
def create_generate_image_tool(
search_space_id: int,
db_session: AsyncSession,
image_generation_config_id_override: int | None = None,
image_gen_model_id_override: int | None = None,
):
"""Create ``generate_image`` with bound search space; DB work uses a per-call session.
``image_generation_config_id_override``: when set (automations running on a
captured model), use this config id instead of reading the search space's
live ``image_generation_config_id``.
``image_gen_model_id_override``: when set (automations running on a
captured model), use this model id instead of reading the search space's
live ``image_gen_model_id``.
"""
del db_session # tool uses a fresh per-call session instead
@ -118,11 +93,11 @@ def create_generate_image_tool(
# task's session is shared across every tool; without isolation,
# autoflushes from a concurrent writer poison this tool too.
async with shielded_async_session() as session:
if image_generation_config_id_override is not None:
if image_gen_model_id_override is not None:
# Automation run: use the captured image model, insulated from
# later search-space changes. No search-space read needed.
config_id = (
image_generation_config_id_override or IMAGE_GEN_AUTO_MODE_ID
image_gen_model_id_override or IMAGE_GEN_AUTO_MODE_ID
)
else:
result = await session.execute(
@ -136,7 +111,7 @@ def create_generate_image_tool(
)
config_id = (
search_space.image_generation_config_id
search_space.image_gen_model_id
or IMAGE_GEN_AUTO_MODE_ID
)
@ -162,58 +137,35 @@ def create_generate_image_tool(
err = f"Image generation config {config_id} not found"
return _failed({"error": err}, error=err)
provider_prefix = _resolve_provider_prefix(
cfg.get("provider", ""), cfg.get("custom_provider")
model_string, resolved_kwargs = to_litellm(
native_connection_from_config(cfg),
cfg["model_name"],
)
model_string = f"{provider_prefix}/{cfg['model_name']}"
gen_kwargs["api_key"] = cfg.get("api_key")
# Defense-in-depth: an empty ``api_base`` must not fall
# through to LiteLLM's global ``api_base`` (e.g. Azure).
api_base = resolve_api_base(
provider=cfg.get("provider"),
provider_prefix=provider_prefix,
config_api_base=cfg.get("api_base"),
)
if api_base:
gen_kwargs["api_base"] = api_base
if cfg.get("api_version"):
gen_kwargs["api_version"] = cfg["api_version"]
if cfg.get("litellm_params"):
gen_kwargs.update(cfg["litellm_params"])
gen_kwargs.update(resolved_kwargs)
response = await aimage_generation(
prompt=prompt, model=model_string, **gen_kwargs
)
else:
# Positive ID = user-created ImageGenerationConfig
# Positive ID = Model + Connection
cfg_result = await session.execute(
select(ImageGenerationConfig).filter(
ImageGenerationConfig.id == config_id
)
select(Model)
.options(selectinload(Model.connection))
.filter(Model.id == config_id, Model.enabled.is_(True))
)
db_cfg = cfg_result.scalars().first()
if not db_cfg:
err = f"Image generation config {config_id} not found"
db_model = cfg_result.scalars().first()
if not db_model or not db_model.connection or not db_model.connection.enabled:
err = f"Image generation model {config_id} not found"
return _failed({"error": err}, error=err)
if not (db_model.capabilities or {}).get("image_gen"):
err = f"Model {config_id} is not image-generation capable"
return _failed({"error": err}, error=err)
provider_prefix = _resolve_provider_prefix(
db_cfg.provider.value, db_cfg.custom_provider
model_string, resolved_kwargs = to_litellm(
db_model.connection,
db_model.model_id,
)
model_string = f"{provider_prefix}/{db_cfg.model_name}"
gen_kwargs["api_key"] = db_cfg.api_key
# Defense-in-depth: an empty ``api_base`` must not fall
# through to LiteLLM's global ``api_base`` (e.g. Azure).
api_base = resolve_api_base(
provider=db_cfg.provider.value,
provider_prefix=provider_prefix,
config_api_base=db_cfg.api_base,
)
if api_base:
gen_kwargs["api_base"] = api_base
if db_cfg.api_version:
gen_kwargs["api_version"] = db_cfg.api_version
if db_cfg.litellm_params:
gen_kwargs.update(db_cfg.litellm_params)
gen_kwargs.update(resolved_kwargs)
response = await aimage_generation(
prompt=prompt, model=model_string, **gen_kwargs

View file

@ -51,8 +51,6 @@ def load_tools(
create_generate_image_tool(
search_space_id=d["search_space_id"],
db_session=d["db_session"],
image_generation_config_id_override=d.get(
"image_generation_config_id_override"
),
image_gen_model_id_override=d.get("image_gen_model_id_override"),
),
]

View file

@ -16,11 +16,13 @@ from litellm import aimage_generation
from sqlalchemy import select
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.config import config
from app.db import (
ImageGeneration,
ImageGenerationConfig,
Model,
Permission,
SearchSpace,
SearchSpaceMembership,
@ -46,7 +48,7 @@ from app.services.image_gen_router_service import (
ImageGenRouterService,
is_image_gen_auto_mode,
)
from app.services.provider_api_base import resolve_api_base
from app.services.model_resolver import native_connection_from_config, to_litellm
from app.users import current_active_user
from app.utils.rbac import check_permission
from app.utils.signed_image_urls import verify_image_token
@ -54,22 +56,6 @@ from app.utils.signed_image_urls import verify_image_token
router = APIRouter()
logger = logging.getLogger(__name__)
# Provider mapping for building litellm model strings.
# Only includes providers that support image generation.
# See: https://docs.litellm.ai/docs/image_generation#supported-providers
_PROVIDER_MAP = {
"OPENAI": "openai",
"AZURE_OPENAI": "azure",
"GOOGLE": "gemini", # Google AI Studio
"VERTEX_AI": "vertex_ai",
"BEDROCK": "bedrock", # AWS Bedrock
"RECRAFT": "recraft",
"OPENROUTER": "openrouter",
"XINFERENCE": "xinference",
"NSCALE": "nscale",
}
def _get_global_image_gen_config(config_id: int) -> dict | None:
"""Get a global image generation configuration by ID (negative IDs)."""
if config_id == IMAGE_GEN_AUTO_MODE_ID:
@ -88,20 +74,6 @@ def _get_global_image_gen_config(config_id: int) -> dict | None:
return None
def _resolve_provider_prefix(provider: str, custom_provider: str | None) -> str:
"""Resolve the LiteLLM provider prefix used in model strings."""
if custom_provider:
return custom_provider
return _PROVIDER_MAP.get(provider.upper(), provider.lower())
def _build_model_string(
provider: str, model_name: str, custom_provider: str | None
) -> str:
"""Build a litellm model string from provider + model_name."""
return f"{_resolve_provider_prefix(provider, custom_provider)}/{model_name}"
async def _resolve_billing_for_image_gen(
session: AsyncSession,
config_id: int | None,
@ -124,7 +96,7 @@ async def _resolve_billing_for_image_gen(
"""
resolved_id = config_id
if resolved_id is None:
resolved_id = search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID
resolved_id = search_space.image_gen_model_id or IMAGE_GEN_AUTO_MODE_ID
if is_image_gen_auto_mode(resolved_id):
return ("free", "auto", DEFAULT_IMAGE_RESERVE_MICROS)
@ -132,11 +104,7 @@ async def _resolve_billing_for_image_gen(
if resolved_id < 0:
cfg = _get_global_image_gen_config(resolved_id) or {}
billing_tier = str(cfg.get("billing_tier", "free")).lower()
base_model = _build_model_string(
cfg.get("provider", ""),
cfg.get("model_name", ""),
cfg.get("custom_provider"),
)
base_model, _ = to_litellm(native_connection_from_config(cfg), cfg.get("model_name", ""))
reserve_micros = int(
cfg.get("quota_reserve_micros") or DEFAULT_IMAGE_RESERVE_MICROS
)
@ -161,7 +129,7 @@ async def _execute_image_generation(
"""
config_id = image_gen.image_generation_config_id
if config_id is None:
config_id = search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID
config_id = search_space.image_gen_model_id or IMAGE_GEN_AUTO_MODE_ID
image_gen.image_generation_config_id = config_id
# Build kwargs
@ -192,22 +160,11 @@ async def _execute_image_generation(
if not cfg:
raise ValueError(f"Global image generation config {config_id} not found")
provider_prefix = _resolve_provider_prefix(
cfg.get("provider", ""), cfg.get("custom_provider")
model_string, resolved_kwargs = to_litellm(
native_connection_from_config(cfg),
cfg["model_name"],
)
model_string = f"{provider_prefix}/{cfg['model_name']}"
gen_kwargs["api_key"] = cfg.get("api_key")
api_base = resolve_api_base(
provider=cfg.get("provider"),
provider_prefix=provider_prefix,
config_api_base=cfg.get("api_base"),
)
if api_base:
gen_kwargs["api_base"] = api_base
if cfg.get("api_version"):
gen_kwargs["api_version"] = cfg["api_version"]
if cfg.get("litellm_params"):
gen_kwargs.update(cfg["litellm_params"])
gen_kwargs.update(resolved_kwargs)
# User model override
if image_gen.model:
@ -217,30 +174,23 @@ async def _execute_image_generation(
prompt=image_gen.prompt, model=model_string, **gen_kwargs
)
else:
# Positive ID = DB ImageGenerationConfig
# Positive ID = Model + Connection
result = await session.execute(
select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id)
select(Model)
.options(selectinload(Model.connection))
.filter(Model.id == config_id, Model.enabled.is_(True))
)
db_cfg = result.scalars().first()
if not db_cfg:
raise ValueError(f"Image generation config {config_id} not found")
db_model = result.scalars().first()
if not db_model or not db_model.connection or not db_model.connection.enabled:
raise ValueError(f"Image generation model {config_id} not found")
if not (db_model.capabilities or {}).get("image_gen"):
raise ValueError(f"Model {config_id} is not image-generation capable")
provider_prefix = _resolve_provider_prefix(
db_cfg.provider.value, db_cfg.custom_provider
model_string, resolved_kwargs = to_litellm(
db_model.connection,
db_model.model_id,
)
model_string = f"{provider_prefix}/{db_cfg.model_name}"
gen_kwargs["api_key"] = db_cfg.api_key
api_base = resolve_api_base(
provider=db_cfg.provider.value,
provider_prefix=provider_prefix,
config_api_base=db_cfg.api_base,
)
if api_base:
gen_kwargs["api_base"] = api_base
if db_cfg.api_version:
gen_kwargs["api_version"] = db_cfg.api_version
if db_cfg.litellm_params:
gen_kwargs.update(db_cfg.litellm_params)
gen_kwargs.update(resolved_kwargs)
# User model override
if image_gen.model: