mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-12 20:45:20 +02:00
refactor(images): use model connections for image generation
This commit is contained in:
parent
62ff97c830
commit
077016d6e4
3 changed files with 52 additions and 152 deletions
|
|
@ -10,13 +10,14 @@ from langgraph.types import Command
|
|||
from litellm import aimage_generation
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.agents.chat.multi_agent_chat.shared.receipts.command import with_receipt
|
||||
from app.agents.chat.multi_agent_chat.shared.receipts.receipt import make_receipt
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
ImageGeneration,
|
||||
ImageGenerationConfig,
|
||||
Model,
|
||||
SearchSpace,
|
||||
shielded_async_session,
|
||||
)
|
||||
|
|
@ -25,37 +26,11 @@ from app.services.image_gen_router_service import (
|
|||
ImageGenRouterService,
|
||||
is_image_gen_auto_mode,
|
||||
)
|
||||
from app.services.provider_api_base import resolve_api_base
|
||||
from app.services.model_resolver import native_connection_from_config, to_litellm
|
||||
from app.utils.signed_image_urls import generate_image_token
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Provider mapping (same as routes)
|
||||
_PROVIDER_MAP = {
|
||||
"OPENAI": "openai",
|
||||
"AZURE_OPENAI": "azure",
|
||||
"GOOGLE": "gemini",
|
||||
"VERTEX_AI": "vertex_ai",
|
||||
"BEDROCK": "bedrock",
|
||||
"RECRAFT": "recraft",
|
||||
"OPENROUTER": "openrouter",
|
||||
"XINFERENCE": "xinference",
|
||||
"NSCALE": "nscale",
|
||||
}
|
||||
|
||||
|
||||
def _resolve_provider_prefix(provider: str, custom_provider: str | None) -> str:
|
||||
if custom_provider:
|
||||
return custom_provider
|
||||
return _PROVIDER_MAP.get(provider.upper(), provider.lower())
|
||||
|
||||
|
||||
def _build_model_string(
|
||||
provider: str, model_name: str, custom_provider: str | None
|
||||
) -> str:
|
||||
return f"{_resolve_provider_prefix(provider, custom_provider)}/{model_name}"
|
||||
|
||||
|
||||
def _get_global_image_gen_config(config_id: int) -> dict | None:
|
||||
"""Get a global image gen config by negative ID."""
|
||||
for cfg in config.GLOBAL_IMAGE_GEN_CONFIGS:
|
||||
|
|
@ -67,13 +42,13 @@ def _get_global_image_gen_config(config_id: int) -> dict | None:
|
|||
def create_generate_image_tool(
|
||||
search_space_id: int,
|
||||
db_session: AsyncSession,
|
||||
image_generation_config_id_override: int | None = None,
|
||||
image_gen_model_id_override: int | None = None,
|
||||
):
|
||||
"""Create ``generate_image`` with bound search space; DB work uses a per-call session.
|
||||
|
||||
``image_generation_config_id_override``: when set (automations running on a
|
||||
captured model), use this config id instead of reading the search space's
|
||||
live ``image_generation_config_id``.
|
||||
``image_gen_model_id_override``: when set (automations running on a
|
||||
captured model), use this model id instead of reading the search space's
|
||||
live ``image_gen_model_id``.
|
||||
"""
|
||||
del db_session # tool uses a fresh per-call session instead
|
||||
|
||||
|
|
@ -118,11 +93,11 @@ def create_generate_image_tool(
|
|||
# task's session is shared across every tool; without isolation,
|
||||
# autoflushes from a concurrent writer poison this tool too.
|
||||
async with shielded_async_session() as session:
|
||||
if image_generation_config_id_override is not None:
|
||||
if image_gen_model_id_override is not None:
|
||||
# Automation run: use the captured image model, insulated from
|
||||
# later search-space changes. No search-space read needed.
|
||||
config_id = (
|
||||
image_generation_config_id_override or IMAGE_GEN_AUTO_MODE_ID
|
||||
image_gen_model_id_override or IMAGE_GEN_AUTO_MODE_ID
|
||||
)
|
||||
else:
|
||||
result = await session.execute(
|
||||
|
|
@ -136,7 +111,7 @@ def create_generate_image_tool(
|
|||
)
|
||||
|
||||
config_id = (
|
||||
search_space.image_generation_config_id
|
||||
search_space.image_gen_model_id
|
||||
or IMAGE_GEN_AUTO_MODE_ID
|
||||
)
|
||||
|
||||
|
|
@ -162,58 +137,35 @@ def create_generate_image_tool(
|
|||
err = f"Image generation config {config_id} not found"
|
||||
return _failed({"error": err}, error=err)
|
||||
|
||||
provider_prefix = _resolve_provider_prefix(
|
||||
cfg.get("provider", ""), cfg.get("custom_provider")
|
||||
model_string, resolved_kwargs = to_litellm(
|
||||
native_connection_from_config(cfg),
|
||||
cfg["model_name"],
|
||||
)
|
||||
model_string = f"{provider_prefix}/{cfg['model_name']}"
|
||||
gen_kwargs["api_key"] = cfg.get("api_key")
|
||||
# Defense-in-depth: an empty ``api_base`` must not fall
|
||||
# through to LiteLLM's global ``api_base`` (e.g. Azure).
|
||||
api_base = resolve_api_base(
|
||||
provider=cfg.get("provider"),
|
||||
provider_prefix=provider_prefix,
|
||||
config_api_base=cfg.get("api_base"),
|
||||
)
|
||||
if api_base:
|
||||
gen_kwargs["api_base"] = api_base
|
||||
if cfg.get("api_version"):
|
||||
gen_kwargs["api_version"] = cfg["api_version"]
|
||||
if cfg.get("litellm_params"):
|
||||
gen_kwargs.update(cfg["litellm_params"])
|
||||
gen_kwargs.update(resolved_kwargs)
|
||||
|
||||
response = await aimage_generation(
|
||||
prompt=prompt, model=model_string, **gen_kwargs
|
||||
)
|
||||
else:
|
||||
# Positive ID = user-created ImageGenerationConfig
|
||||
# Positive ID = Model + Connection
|
||||
cfg_result = await session.execute(
|
||||
select(ImageGenerationConfig).filter(
|
||||
ImageGenerationConfig.id == config_id
|
||||
)
|
||||
select(Model)
|
||||
.options(selectinload(Model.connection))
|
||||
.filter(Model.id == config_id, Model.enabled.is_(True))
|
||||
)
|
||||
db_cfg = cfg_result.scalars().first()
|
||||
if not db_cfg:
|
||||
err = f"Image generation config {config_id} not found"
|
||||
db_model = cfg_result.scalars().first()
|
||||
if not db_model or not db_model.connection or not db_model.connection.enabled:
|
||||
err = f"Image generation model {config_id} not found"
|
||||
return _failed({"error": err}, error=err)
|
||||
if not (db_model.capabilities or {}).get("image_gen"):
|
||||
err = f"Model {config_id} is not image-generation capable"
|
||||
return _failed({"error": err}, error=err)
|
||||
|
||||
provider_prefix = _resolve_provider_prefix(
|
||||
db_cfg.provider.value, db_cfg.custom_provider
|
||||
model_string, resolved_kwargs = to_litellm(
|
||||
db_model.connection,
|
||||
db_model.model_id,
|
||||
)
|
||||
model_string = f"{provider_prefix}/{db_cfg.model_name}"
|
||||
gen_kwargs["api_key"] = db_cfg.api_key
|
||||
# Defense-in-depth: an empty ``api_base`` must not fall
|
||||
# through to LiteLLM's global ``api_base`` (e.g. Azure).
|
||||
api_base = resolve_api_base(
|
||||
provider=db_cfg.provider.value,
|
||||
provider_prefix=provider_prefix,
|
||||
config_api_base=db_cfg.api_base,
|
||||
)
|
||||
if api_base:
|
||||
gen_kwargs["api_base"] = api_base
|
||||
if db_cfg.api_version:
|
||||
gen_kwargs["api_version"] = db_cfg.api_version
|
||||
if db_cfg.litellm_params:
|
||||
gen_kwargs.update(db_cfg.litellm_params)
|
||||
gen_kwargs.update(resolved_kwargs)
|
||||
|
||||
response = await aimage_generation(
|
||||
prompt=prompt, model=model_string, **gen_kwargs
|
||||
|
|
|
|||
|
|
@ -51,8 +51,6 @@ def load_tools(
|
|||
create_generate_image_tool(
|
||||
search_space_id=d["search_space_id"],
|
||||
db_session=d["db_session"],
|
||||
image_generation_config_id_override=d.get(
|
||||
"image_generation_config_id_override"
|
||||
),
|
||||
image_gen_model_id_override=d.get("image_gen_model_id_override"),
|
||||
),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -16,11 +16,13 @@ from litellm import aimage_generation
|
|||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
ImageGeneration,
|
||||
ImageGenerationConfig,
|
||||
Model,
|
||||
Permission,
|
||||
SearchSpace,
|
||||
SearchSpaceMembership,
|
||||
|
|
@ -46,7 +48,7 @@ from app.services.image_gen_router_service import (
|
|||
ImageGenRouterService,
|
||||
is_image_gen_auto_mode,
|
||||
)
|
||||
from app.services.provider_api_base import resolve_api_base
|
||||
from app.services.model_resolver import native_connection_from_config, to_litellm
|
||||
from app.users import current_active_user
|
||||
from app.utils.rbac import check_permission
|
||||
from app.utils.signed_image_urls import verify_image_token
|
||||
|
|
@ -54,22 +56,6 @@ from app.utils.signed_image_urls import verify_image_token
|
|||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Provider mapping for building litellm model strings.
|
||||
# Only includes providers that support image generation.
|
||||
# See: https://docs.litellm.ai/docs/image_generation#supported-providers
|
||||
_PROVIDER_MAP = {
|
||||
"OPENAI": "openai",
|
||||
"AZURE_OPENAI": "azure",
|
||||
"GOOGLE": "gemini", # Google AI Studio
|
||||
"VERTEX_AI": "vertex_ai",
|
||||
"BEDROCK": "bedrock", # AWS Bedrock
|
||||
"RECRAFT": "recraft",
|
||||
"OPENROUTER": "openrouter",
|
||||
"XINFERENCE": "xinference",
|
||||
"NSCALE": "nscale",
|
||||
}
|
||||
|
||||
|
||||
def _get_global_image_gen_config(config_id: int) -> dict | None:
|
||||
"""Get a global image generation configuration by ID (negative IDs)."""
|
||||
if config_id == IMAGE_GEN_AUTO_MODE_ID:
|
||||
|
|
@ -88,20 +74,6 @@ def _get_global_image_gen_config(config_id: int) -> dict | None:
|
|||
return None
|
||||
|
||||
|
||||
def _resolve_provider_prefix(provider: str, custom_provider: str | None) -> str:
|
||||
"""Resolve the LiteLLM provider prefix used in model strings."""
|
||||
if custom_provider:
|
||||
return custom_provider
|
||||
return _PROVIDER_MAP.get(provider.upper(), provider.lower())
|
||||
|
||||
|
||||
def _build_model_string(
|
||||
provider: str, model_name: str, custom_provider: str | None
|
||||
) -> str:
|
||||
"""Build a litellm model string from provider + model_name."""
|
||||
return f"{_resolve_provider_prefix(provider, custom_provider)}/{model_name}"
|
||||
|
||||
|
||||
async def _resolve_billing_for_image_gen(
|
||||
session: AsyncSession,
|
||||
config_id: int | None,
|
||||
|
|
@ -124,7 +96,7 @@ async def _resolve_billing_for_image_gen(
|
|||
"""
|
||||
resolved_id = config_id
|
||||
if resolved_id is None:
|
||||
resolved_id = search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID
|
||||
resolved_id = search_space.image_gen_model_id or IMAGE_GEN_AUTO_MODE_ID
|
||||
|
||||
if is_image_gen_auto_mode(resolved_id):
|
||||
return ("free", "auto", DEFAULT_IMAGE_RESERVE_MICROS)
|
||||
|
|
@ -132,11 +104,7 @@ async def _resolve_billing_for_image_gen(
|
|||
if resolved_id < 0:
|
||||
cfg = _get_global_image_gen_config(resolved_id) or {}
|
||||
billing_tier = str(cfg.get("billing_tier", "free")).lower()
|
||||
base_model = _build_model_string(
|
||||
cfg.get("provider", ""),
|
||||
cfg.get("model_name", ""),
|
||||
cfg.get("custom_provider"),
|
||||
)
|
||||
base_model, _ = to_litellm(native_connection_from_config(cfg), cfg.get("model_name", ""))
|
||||
reserve_micros = int(
|
||||
cfg.get("quota_reserve_micros") or DEFAULT_IMAGE_RESERVE_MICROS
|
||||
)
|
||||
|
|
@ -161,7 +129,7 @@ async def _execute_image_generation(
|
|||
"""
|
||||
config_id = image_gen.image_generation_config_id
|
||||
if config_id is None:
|
||||
config_id = search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID
|
||||
config_id = search_space.image_gen_model_id or IMAGE_GEN_AUTO_MODE_ID
|
||||
image_gen.image_generation_config_id = config_id
|
||||
|
||||
# Build kwargs
|
||||
|
|
@ -192,22 +160,11 @@ async def _execute_image_generation(
|
|||
if not cfg:
|
||||
raise ValueError(f"Global image generation config {config_id} not found")
|
||||
|
||||
provider_prefix = _resolve_provider_prefix(
|
||||
cfg.get("provider", ""), cfg.get("custom_provider")
|
||||
model_string, resolved_kwargs = to_litellm(
|
||||
native_connection_from_config(cfg),
|
||||
cfg["model_name"],
|
||||
)
|
||||
model_string = f"{provider_prefix}/{cfg['model_name']}"
|
||||
gen_kwargs["api_key"] = cfg.get("api_key")
|
||||
api_base = resolve_api_base(
|
||||
provider=cfg.get("provider"),
|
||||
provider_prefix=provider_prefix,
|
||||
config_api_base=cfg.get("api_base"),
|
||||
)
|
||||
if api_base:
|
||||
gen_kwargs["api_base"] = api_base
|
||||
if cfg.get("api_version"):
|
||||
gen_kwargs["api_version"] = cfg["api_version"]
|
||||
if cfg.get("litellm_params"):
|
||||
gen_kwargs.update(cfg["litellm_params"])
|
||||
gen_kwargs.update(resolved_kwargs)
|
||||
|
||||
# User model override
|
||||
if image_gen.model:
|
||||
|
|
@ -217,30 +174,23 @@ async def _execute_image_generation(
|
|||
prompt=image_gen.prompt, model=model_string, **gen_kwargs
|
||||
)
|
||||
else:
|
||||
# Positive ID = DB ImageGenerationConfig
|
||||
# Positive ID = Model + Connection
|
||||
result = await session.execute(
|
||||
select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id)
|
||||
select(Model)
|
||||
.options(selectinload(Model.connection))
|
||||
.filter(Model.id == config_id, Model.enabled.is_(True))
|
||||
)
|
||||
db_cfg = result.scalars().first()
|
||||
if not db_cfg:
|
||||
raise ValueError(f"Image generation config {config_id} not found")
|
||||
db_model = result.scalars().first()
|
||||
if not db_model or not db_model.connection or not db_model.connection.enabled:
|
||||
raise ValueError(f"Image generation model {config_id} not found")
|
||||
if not (db_model.capabilities or {}).get("image_gen"):
|
||||
raise ValueError(f"Model {config_id} is not image-generation capable")
|
||||
|
||||
provider_prefix = _resolve_provider_prefix(
|
||||
db_cfg.provider.value, db_cfg.custom_provider
|
||||
model_string, resolved_kwargs = to_litellm(
|
||||
db_model.connection,
|
||||
db_model.model_id,
|
||||
)
|
||||
model_string = f"{provider_prefix}/{db_cfg.model_name}"
|
||||
gen_kwargs["api_key"] = db_cfg.api_key
|
||||
api_base = resolve_api_base(
|
||||
provider=db_cfg.provider.value,
|
||||
provider_prefix=provider_prefix,
|
||||
config_api_base=db_cfg.api_base,
|
||||
)
|
||||
if api_base:
|
||||
gen_kwargs["api_base"] = api_base
|
||||
if db_cfg.api_version:
|
||||
gen_kwargs["api_version"] = db_cfg.api_version
|
||||
if db_cfg.litellm_params:
|
||||
gen_kwargs.update(db_cfg.litellm_params)
|
||||
gen_kwargs.update(resolved_kwargs)
|
||||
|
||||
# User model override
|
||||
if image_gen.model:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue