diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py index 7bb4a7c24..dd980c51c 100644 --- a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py @@ -10,13 +10,14 @@ from langgraph.types import Command from litellm import aimage_generation from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload from app.agents.chat.multi_agent_chat.shared.receipts.command import with_receipt from app.agents.chat.multi_agent_chat.shared.receipts.receipt import make_receipt from app.config import config from app.db import ( ImageGeneration, - ImageGenerationConfig, + Model, SearchSpace, shielded_async_session, ) @@ -25,37 +26,11 @@ from app.services.image_gen_router_service import ( ImageGenRouterService, is_image_gen_auto_mode, ) -from app.services.provider_api_base import resolve_api_base +from app.services.model_resolver import native_connection_from_config, to_litellm from app.utils.signed_image_urls import generate_image_token logger = logging.getLogger(__name__) -# Provider mapping (same as routes) -_PROVIDER_MAP = { - "OPENAI": "openai", - "AZURE_OPENAI": "azure", - "GOOGLE": "gemini", - "VERTEX_AI": "vertex_ai", - "BEDROCK": "bedrock", - "RECRAFT": "recraft", - "OPENROUTER": "openrouter", - "XINFERENCE": "xinference", - "NSCALE": "nscale", -} - - -def _resolve_provider_prefix(provider: str, custom_provider: str | None) -> str: - if custom_provider: - return custom_provider - return _PROVIDER_MAP.get(provider.upper(), provider.lower()) - - -def _build_model_string( - provider: str, model_name: str, custom_provider: str | None -) -> str: - return f"{_resolve_provider_prefix(provider, custom_provider)}/{model_name}" - - def _get_global_image_gen_config(config_id: int) -> dict | None: """Get a global image gen config by negative ID.""" for cfg in config.GLOBAL_IMAGE_GEN_CONFIGS: @@ -67,13 +42,13 @@ def _get_global_image_gen_config(config_id: int) -> dict | None: def create_generate_image_tool( search_space_id: int, db_session: AsyncSession, - image_generation_config_id_override: int | None = None, + image_gen_model_id_override: int | None = None, ): """Create ``generate_image`` with bound search space; DB work uses a per-call session. - ``image_generation_config_id_override``: when set (automations running on a - captured model), use this config id instead of reading the search space's - live ``image_generation_config_id``. + ``image_gen_model_id_override``: when set (automations running on a + captured model), use this model id instead of reading the search space's + live ``image_gen_model_id``. """ del db_session # tool uses a fresh per-call session instead @@ -118,11 +93,11 @@ def create_generate_image_tool( # task's session is shared across every tool; without isolation, # autoflushes from a concurrent writer poison this tool too. async with shielded_async_session() as session: - if image_generation_config_id_override is not None: + if image_gen_model_id_override is not None: # Automation run: use the captured image model, insulated from # later search-space changes. No search-space read needed. config_id = ( - image_generation_config_id_override or IMAGE_GEN_AUTO_MODE_ID + image_gen_model_id_override or IMAGE_GEN_AUTO_MODE_ID ) else: result = await session.execute( @@ -136,7 +111,7 @@ def create_generate_image_tool( ) config_id = ( - search_space.image_generation_config_id + search_space.image_gen_model_id or IMAGE_GEN_AUTO_MODE_ID ) @@ -162,58 +137,35 @@ def create_generate_image_tool( err = f"Image generation config {config_id} not found" return _failed({"error": err}, error=err) - provider_prefix = _resolve_provider_prefix( - cfg.get("provider", ""), cfg.get("custom_provider") + model_string, resolved_kwargs = to_litellm( + native_connection_from_config(cfg), + cfg["model_name"], ) - model_string = f"{provider_prefix}/{cfg['model_name']}" - gen_kwargs["api_key"] = cfg.get("api_key") - # Defense-in-depth: an empty ``api_base`` must not fall - # through to LiteLLM's global ``api_base`` (e.g. Azure). - api_base = resolve_api_base( - provider=cfg.get("provider"), - provider_prefix=provider_prefix, - config_api_base=cfg.get("api_base"), - ) - if api_base: - gen_kwargs["api_base"] = api_base - if cfg.get("api_version"): - gen_kwargs["api_version"] = cfg["api_version"] - if cfg.get("litellm_params"): - gen_kwargs.update(cfg["litellm_params"]) + gen_kwargs.update(resolved_kwargs) response = await aimage_generation( prompt=prompt, model=model_string, **gen_kwargs ) else: - # Positive ID = user-created ImageGenerationConfig + # Positive ID = Model + Connection cfg_result = await session.execute( - select(ImageGenerationConfig).filter( - ImageGenerationConfig.id == config_id - ) + select(Model) + .options(selectinload(Model.connection)) + .filter(Model.id == config_id, Model.enabled.is_(True)) ) - db_cfg = cfg_result.scalars().first() - if not db_cfg: - err = f"Image generation config {config_id} not found" + db_model = cfg_result.scalars().first() + if not db_model or not db_model.connection or not db_model.connection.enabled: + err = f"Image generation model {config_id} not found" + return _failed({"error": err}, error=err) + if not (db_model.capabilities or {}).get("image_gen"): + err = f"Model {config_id} is not image-generation capable" return _failed({"error": err}, error=err) - provider_prefix = _resolve_provider_prefix( - db_cfg.provider.value, db_cfg.custom_provider + model_string, resolved_kwargs = to_litellm( + db_model.connection, + db_model.model_id, ) - model_string = f"{provider_prefix}/{db_cfg.model_name}" - gen_kwargs["api_key"] = db_cfg.api_key - # Defense-in-depth: an empty ``api_base`` must not fall - # through to LiteLLM's global ``api_base`` (e.g. Azure). - api_base = resolve_api_base( - provider=db_cfg.provider.value, - provider_prefix=provider_prefix, - config_api_base=db_cfg.api_base, - ) - if api_base: - gen_kwargs["api_base"] = api_base - if db_cfg.api_version: - gen_kwargs["api_version"] = db_cfg.api_version - if db_cfg.litellm_params: - gen_kwargs.update(db_cfg.litellm_params) + gen_kwargs.update(resolved_kwargs) response = await aimage_generation( prompt=prompt, model=model_string, **gen_kwargs diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/index.py b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/index.py index b968c1701..8de95f2df 100644 --- a/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/index.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/subagents/builtins/deliverables/tools/index.py @@ -51,8 +51,6 @@ def load_tools( create_generate_image_tool( search_space_id=d["search_space_id"], db_session=d["db_session"], - image_generation_config_id_override=d.get( - "image_generation_config_id_override" - ), + image_gen_model_id_override=d.get("image_gen_model_id_override"), ), ] diff --git a/surfsense_backend/app/routes/image_generation_routes.py b/surfsense_backend/app/routes/image_generation_routes.py index 018234ad5..0de368d57 100644 --- a/surfsense_backend/app/routes/image_generation_routes.py +++ b/surfsense_backend/app/routes/image_generation_routes.py @@ -16,11 +16,13 @@ from litellm import aimage_generation from sqlalchemy import select from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload from app.config import config from app.db import ( ImageGeneration, ImageGenerationConfig, + Model, Permission, SearchSpace, SearchSpaceMembership, @@ -46,7 +48,7 @@ from app.services.image_gen_router_service import ( ImageGenRouterService, is_image_gen_auto_mode, ) -from app.services.provider_api_base import resolve_api_base +from app.services.model_resolver import native_connection_from_config, to_litellm from app.users import current_active_user from app.utils.rbac import check_permission from app.utils.signed_image_urls import verify_image_token @@ -54,22 +56,6 @@ from app.utils.signed_image_urls import verify_image_token router = APIRouter() logger = logging.getLogger(__name__) -# Provider mapping for building litellm model strings. -# Only includes providers that support image generation. -# See: https://docs.litellm.ai/docs/image_generation#supported-providers -_PROVIDER_MAP = { - "OPENAI": "openai", - "AZURE_OPENAI": "azure", - "GOOGLE": "gemini", # Google AI Studio - "VERTEX_AI": "vertex_ai", - "BEDROCK": "bedrock", # AWS Bedrock - "RECRAFT": "recraft", - "OPENROUTER": "openrouter", - "XINFERENCE": "xinference", - "NSCALE": "nscale", -} - - def _get_global_image_gen_config(config_id: int) -> dict | None: """Get a global image generation configuration by ID (negative IDs).""" if config_id == IMAGE_GEN_AUTO_MODE_ID: @@ -88,20 +74,6 @@ def _get_global_image_gen_config(config_id: int) -> dict | None: return None -def _resolve_provider_prefix(provider: str, custom_provider: str | None) -> str: - """Resolve the LiteLLM provider prefix used in model strings.""" - if custom_provider: - return custom_provider - return _PROVIDER_MAP.get(provider.upper(), provider.lower()) - - -def _build_model_string( - provider: str, model_name: str, custom_provider: str | None -) -> str: - """Build a litellm model string from provider + model_name.""" - return f"{_resolve_provider_prefix(provider, custom_provider)}/{model_name}" - - async def _resolve_billing_for_image_gen( session: AsyncSession, config_id: int | None, @@ -124,7 +96,7 @@ async def _resolve_billing_for_image_gen( """ resolved_id = config_id if resolved_id is None: - resolved_id = search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID + resolved_id = search_space.image_gen_model_id or IMAGE_GEN_AUTO_MODE_ID if is_image_gen_auto_mode(resolved_id): return ("free", "auto", DEFAULT_IMAGE_RESERVE_MICROS) @@ -132,11 +104,7 @@ async def _resolve_billing_for_image_gen( if resolved_id < 0: cfg = _get_global_image_gen_config(resolved_id) or {} billing_tier = str(cfg.get("billing_tier", "free")).lower() - base_model = _build_model_string( - cfg.get("provider", ""), - cfg.get("model_name", ""), - cfg.get("custom_provider"), - ) + base_model, _ = to_litellm(native_connection_from_config(cfg), cfg.get("model_name", "")) reserve_micros = int( cfg.get("quota_reserve_micros") or DEFAULT_IMAGE_RESERVE_MICROS ) @@ -161,7 +129,7 @@ async def _execute_image_generation( """ config_id = image_gen.image_generation_config_id if config_id is None: - config_id = search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID + config_id = search_space.image_gen_model_id or IMAGE_GEN_AUTO_MODE_ID image_gen.image_generation_config_id = config_id # Build kwargs @@ -192,22 +160,11 @@ async def _execute_image_generation( if not cfg: raise ValueError(f"Global image generation config {config_id} not found") - provider_prefix = _resolve_provider_prefix( - cfg.get("provider", ""), cfg.get("custom_provider") + model_string, resolved_kwargs = to_litellm( + native_connection_from_config(cfg), + cfg["model_name"], ) - model_string = f"{provider_prefix}/{cfg['model_name']}" - gen_kwargs["api_key"] = cfg.get("api_key") - api_base = resolve_api_base( - provider=cfg.get("provider"), - provider_prefix=provider_prefix, - config_api_base=cfg.get("api_base"), - ) - if api_base: - gen_kwargs["api_base"] = api_base - if cfg.get("api_version"): - gen_kwargs["api_version"] = cfg["api_version"] - if cfg.get("litellm_params"): - gen_kwargs.update(cfg["litellm_params"]) + gen_kwargs.update(resolved_kwargs) # User model override if image_gen.model: @@ -217,30 +174,23 @@ async def _execute_image_generation( prompt=image_gen.prompt, model=model_string, **gen_kwargs ) else: - # Positive ID = DB ImageGenerationConfig + # Positive ID = Model + Connection result = await session.execute( - select(ImageGenerationConfig).filter(ImageGenerationConfig.id == config_id) + select(Model) + .options(selectinload(Model.connection)) + .filter(Model.id == config_id, Model.enabled.is_(True)) ) - db_cfg = result.scalars().first() - if not db_cfg: - raise ValueError(f"Image generation config {config_id} not found") + db_model = result.scalars().first() + if not db_model or not db_model.connection or not db_model.connection.enabled: + raise ValueError(f"Image generation model {config_id} not found") + if not (db_model.capabilities or {}).get("image_gen"): + raise ValueError(f"Model {config_id} is not image-generation capable") - provider_prefix = _resolve_provider_prefix( - db_cfg.provider.value, db_cfg.custom_provider + model_string, resolved_kwargs = to_litellm( + db_model.connection, + db_model.model_id, ) - model_string = f"{provider_prefix}/{db_cfg.model_name}" - gen_kwargs["api_key"] = db_cfg.api_key - api_base = resolve_api_base( - provider=db_cfg.provider.value, - provider_prefix=provider_prefix, - config_api_base=db_cfg.api_base, - ) - if api_base: - gen_kwargs["api_base"] = api_base - if db_cfg.api_version: - gen_kwargs["api_version"] = db_cfg.api_version - if db_cfg.litellm_params: - gen_kwargs.update(db_cfg.litellm_params) + gen_kwargs.update(resolved_kwargs) # User model override if image_gen.model: