fix(chat): harden image generation model routing

This commit is contained in:
Anish Sarkar 2026-06-11 18:22:45 +05:30
parent c28c4f5785
commit 831ad23c6c
7 changed files with 156 additions and 171 deletions

View file

@ -23,20 +23,26 @@ from app.db import (
)
from app.services.image_gen_router_service import (
IMAGE_GEN_AUTO_MODE_ID,
ImageGenRouterService,
is_image_gen_auto_mode,
)
from app.services.model_resolver import native_connection_from_config, to_litellm
from app.services.auto_model_pin_service import (
auto_model_candidates,
choose_auto_model_candidate,
)
from app.services.model_resolver import to_litellm
from app.utils.signed_image_urls import generate_image_token
logger = logging.getLogger(__name__)
def _get_global_image_gen_config(config_id: int) -> dict | None:
    """Look up a global image-generation config (negative ID) by its ID.

    Returns the first entry of ``config.GLOBAL_IMAGE_GEN_CONFIGS`` whose
    ``"id"`` equals ``config_id``, or ``None`` when no entry matches.
    """
    return next(
        (
            entry
            for entry in config.GLOBAL_IMAGE_GEN_CONFIGS
            if entry.get("id") == config_id
        ),
        None,
    )
def _get_global_model(model_id: int) -> dict | None:
    """Return the ``config.GLOBAL_MODELS`` entry whose ``"id"`` is ``model_id``.

    The first matching entry wins; ``None`` is returned when nothing matches.
    """
    for candidate in config.GLOBAL_MODELS:
        if candidate.get("id") == model_id:
            return candidate
    return None
def _get_global_connection(connection_id: int) -> dict | None:
    """Return the ``config.GLOBAL_CONNECTIONS`` entry with the given ``"id"``.

    The first matching entry wins; ``None`` is returned when nothing matches.
    """
    for candidate in config.GLOBAL_CONNECTIONS:
        if candidate.get("id") == connection_id:
            return candidate
    return None
def create_generate_image_tool(
@ -93,6 +99,16 @@ def create_generate_image_tool(
# task's session is shared across every tool; without isolation,
# autoflushes from a concurrent writer poison this tool too.
async with shielded_async_session() as session:
result = await session.execute(
select(SearchSpace).filter(SearchSpace.id == search_space_id)
)
search_space = result.scalars().first()
if not search_space:
return _failed(
{"error": "Search space not found"},
error="Search space not found",
)
if image_gen_model_id_override is not None:
# Automation run: use the captured image model, insulated from
# later search-space changes. No search-space read needed.
@ -100,16 +116,6 @@ def create_generate_image_tool(
image_gen_model_id_override or IMAGE_GEN_AUTO_MODE_ID
)
else:
result = await session.execute(
select(SearchSpace).filter(SearchSpace.id == search_space_id)
)
search_space = result.scalars().first()
if not search_space:
return _failed(
{"error": "Search space not found"},
error="Search space not found",
)
config_id = (
search_space.image_gen_model_id
or IMAGE_GEN_AUTO_MODE_ID
@ -122,24 +128,39 @@ def create_generate_image_tool(
gen_kwargs["n"] = n
if is_image_gen_auto_mode(config_id):
if not ImageGenRouterService.is_initialized():
candidates = await auto_model_candidates(
session,
search_space_id=search_space_id,
user_id=search_space.user_id,
capability="image_gen",
)
if not candidates:
err = (
"No image generation models configured. "
"No image generation models available. "
"Please add an image model in Settings > Image Models."
)
return _failed({"error": err}, error=err)
response = await ImageGenRouterService.aimage_generation(
prompt=prompt, model="auto", **gen_kwargs
config_id = int(
choose_auto_model_candidate(candidates, search_space_id)["id"]
)
elif config_id < 0:
cfg = _get_global_image_gen_config(config_id)
if not cfg:
err = f"Image generation config {config_id} not found"
if config_id < 0:
global_model = _get_global_model(config_id)
if not global_model or not (
global_model.get("capabilities") or {}
).get("image_gen"):
err = f"Image generation model {config_id} not found"
return _failed({"error": err}, error=err)
global_connection = _get_global_connection(
global_model["connection_id"]
)
if not global_connection:
err = f"Image generation connection for model {config_id} not found"
return _failed({"error": err}, error=err)
model_string, resolved_kwargs = to_litellm(
native_connection_from_config(cfg),
cfg["model_name"],
global_connection,
global_model["model_id"],
)
gen_kwargs.update(resolved_kwargs)
@ -157,6 +178,19 @@ def create_generate_image_tool(
if not db_model or not db_model.connection or not db_model.connection.enabled:
err = f"Image generation model {config_id} not found"
return _failed({"error": err}, error=err)
conn = db_model.connection
if (
conn.search_space_id is not None
and conn.search_space_id != search_space_id
):
err = f"Image generation model {config_id} not found"
return _failed({"error": err}, error=err)
if (
conn.user_id is not None
and conn.user_id != search_space.user_id
):
err = f"Image generation model {config_id} not found"
return _failed({"error": err}, error=err)
if not (db_model.capabilities or {}).get("image_gen"):
err = f"Model {config_id} is not image-generation capable"
return _failed({"error": err}, error=err)

View file

@ -45,10 +45,13 @@ from app.services.billable_calls import (
)
from app.services.image_gen_router_service import (
IMAGE_GEN_AUTO_MODE_ID,
ImageGenRouterService,
is_image_gen_auto_mode,
)
from app.services.model_resolver import native_connection_from_config, to_litellm
from app.services.auto_model_pin_service import (
auto_model_candidates,
choose_auto_model_candidate,
)
from app.services.model_resolver import to_litellm
from app.users import current_active_user
from app.utils.rbac import check_permission
from app.utils.signed_image_urls import verify_image_token
@ -56,22 +59,15 @@ from app.utils.signed_image_urls import verify_image_token
router = APIRouter()
logger = logging.getLogger(__name__)
def _get_global_image_gen_config(config_id: int) -> dict | None:
    """Resolve a global image-generation configuration by ID.

    * ``IMAGE_GEN_AUTO_MODE_ID`` yields a synthetic "Auto (Fastest)" entry.
    * Any other positive ID is not a global config, so ``None`` is returned.
    * Otherwise the first entry of ``config.GLOBAL_IMAGE_GEN_CONFIGS`` whose
      ``"id"`` equals ``config_id`` is returned, or ``None`` if none matches.
    """
    # Auto mode is checked first so its sentinel ID never reaches the scan.
    if config_id == IMAGE_GEN_AUTO_MODE_ID:
        auto_entry = {
            "id": IMAGE_GEN_AUTO_MODE_ID,
            "name": "Auto (Fastest)",
            "provider": "AUTO",
            "model_name": "auto",
            "is_auto_mode": True,
        }
        return auto_entry

    if config_id > 0:
        return None

    matches = (
        entry
        for entry in config.GLOBAL_IMAGE_GEN_CONFIGS
        if entry.get("id") == config_id
    )
    return next(matches, None)
def _get_global_model(model_id: int) -> dict | None:
    """Find the global model entry whose ``"id"`` equals ``model_id``.

    Scans ``config.GLOBAL_MODELS`` in order and returns the first match,
    or ``None`` when there is none.
    """
    for entry in config.GLOBAL_MODELS:
        if entry.get("id") == model_id:
            return entry
    return None
def _get_global_connection(connection_id: int) -> dict | None:
    """Find the global connection entry whose ``"id"`` equals ``connection_id``.

    Scans ``config.GLOBAL_CONNECTIONS`` in order and returns the first match,
    or ``None`` when there is none.
    """
    for entry in config.GLOBAL_CONNECTIONS:
        if entry.get("id") == connection_id:
            return entry
    return None
async def _resolve_billing_for_image_gen(
@ -87,30 +83,41 @@ async def _resolve_billing_for_image_gen(
config that will actually run, and so we don't open an
``ImageGeneration`` row for a request that's about to 402.
User-owned (positive ID) BYOK configs are always free — they cost
the user nothing on our side. Auto mode is currently treated as free
because the underlying router can dispatch to either premium or
free YAML configs and we don't surface the resolved deployment up
here yet. Bringing Auto under premium billing would require
threading the chosen deployment back from ``ImageGenRouterService``.
User-owned (positive ID) BYOK models are always free — they cost
the user nothing on our side. Auto mode resolves to one concrete
global or BYOK model before billing is calculated.
"""
resolved_id = config_id
if resolved_id is None:
resolved_id = search_space.image_gen_model_id or IMAGE_GEN_AUTO_MODE_ID
if is_image_gen_auto_mode(resolved_id):
return ("free", "auto", DEFAULT_IMAGE_RESERVE_MICROS)
candidates = await auto_model_candidates(
session,
search_space_id=search_space.id,
user_id=search_space.user_id,
capability="image_gen",
)
if not candidates:
return ("free", "auto", DEFAULT_IMAGE_RESERVE_MICROS)
selected = choose_auto_model_candidate(candidates, search_space.id)
resolved_id = int(selected["id"])
if resolved_id < 0:
cfg = _get_global_image_gen_config(resolved_id) or {}
billing_tier = str(cfg.get("billing_tier", "free")).lower()
base_model, _ = to_litellm(native_connection_from_config(cfg), cfg.get("model_name", ""))
global_model = _get_global_model(resolved_id) or {}
global_connection = _get_global_connection(global_model.get("connection_id", 0))
billing_tier = str(global_model.get("billing_tier", "free")).lower()
if global_connection and global_model.get("model_id"):
base_model, _ = to_litellm(global_connection, global_model["model_id"])
else:
base_model = "global_image_model"
catalog = global_model.get("catalog") or {}
reserve_micros = int(
cfg.get("quota_reserve_micros") or DEFAULT_IMAGE_RESERVE_MICROS
catalog.get("quota_reserve_micros") or DEFAULT_IMAGE_RESERVE_MICROS
)
return (billing_tier, base_model, reserve_micros)
# Positive ID = user-owned BYOK image-gen config — always free.
# Positive ID = user-owned BYOK image-gen model — always free.
return ("free", "user_byok", DEFAULT_IMAGE_RESERVE_MICROS)
@ -146,23 +153,28 @@ async def _execute_image_generation(
gen_kwargs["response_format"] = image_gen.response_format
if is_image_gen_auto_mode(config_id):
if not ImageGenRouterService.is_initialized():
raise ValueError(
"Auto mode requested but Image Generation Router not initialized. "
"Ensure global_llm_config.yaml has global_image_generation_configs."
)
response = await ImageGenRouterService.aimage_generation(
prompt=image_gen.prompt, model="auto", **gen_kwargs
candidates = await auto_model_candidates(
session,
search_space_id=search_space.id,
user_id=search_space.user_id,
capability="image_gen",
)
elif config_id < 0:
# Global config from YAML
cfg = _get_global_image_gen_config(config_id)
if not cfg:
raise ValueError(f"Global image generation config {config_id} not found")
if not candidates:
raise ValueError("No image-generation models are available for Auto mode")
config_id = int(choose_auto_model_candidate(candidates, search_space.id)["id"])
image_gen.image_generation_config_id = config_id
if config_id < 0:
global_model = _get_global_model(config_id)
if not global_model or not (global_model.get("capabilities") or {}).get("image_gen"):
raise ValueError(f"Global image generation model {config_id} not found")
global_connection = _get_global_connection(global_model["connection_id"])
if not global_connection:
raise ValueError(f"Global connection for image model {config_id} not found")
model_string, resolved_kwargs = to_litellm(
native_connection_from_config(cfg),
cfg["model_name"],
global_connection,
global_model["model_id"],
)
gen_kwargs.update(resolved_kwargs)
@ -183,6 +195,11 @@ async def _execute_image_generation(
db_model = result.scalars().first()
if not db_model or not db_model.connection or not db_model.connection.enabled:
raise ValueError(f"Image generation model {config_id} not found")
conn = db_model.connection
if conn.search_space_id is not None and conn.search_space_id != search_space.id:
raise ValueError(f"Image generation model {config_id} not found")
if conn.user_id is not None and conn.user_id != search_space.user_id:
raise ValueError(f"Image generation model {config_id} not found")
if not (db_model.capabilities or {}).get("image_gen"):
raise ValueError(f"Model {config_id} is not image-generation capable")
@ -255,7 +272,7 @@ async def get_global_image_gen_configs(
"id": cfg.get("id"),
"name": cfg.get("name"),
"description": cfg.get("description"),
"provider": cfg.get("provider"),
"provider": cfg.get("litellm_provider"),
"custom_provider": cfg.get("custom_provider"),
"model_name": cfg.get("model_name"),
"api_base": cfg.get("api_base") or None,

View file

@ -20,23 +20,13 @@ from typing import Any
from litellm import Router
from litellm.utils import ImageResponse
from app.services.model_resolver import (
NATIVE_PROVIDER_PREFIX,
native_connection_from_config,
to_litellm,
)
from app.services.model_resolver import native_connection_from_config, to_litellm
logger = logging.getLogger(__name__)
# Special ID for Auto mode - uses router for load balancing
IMAGE_GEN_AUTO_MODE_ID = 0
# Provider mapping for LiteLLM model string construction.
# Only includes providers that support image generation.
# See: https://docs.litellm.ai/docs/image_generation#supported-providers
IMAGE_GEN_PROVIDER_MAP = NATIVE_PROVIDER_PREFIX
class ImageGenRouterService:
"""
Singleton service for managing LiteLLM Router for image generation.

View file

@ -55,7 +55,6 @@ from app.services.openrouter_integration_service import ( # noqa: E402
_OPENROUTER_DYNAMIC_MARKER,
OpenRouterIntegrationService,
)
from app.services.provider_api_base import resolve_api_base # noqa: E402
from app.services.provider_capabilities import ( # noqa: E402
derive_supports_image_input,
is_known_text_only_chat_model,
@ -154,13 +153,13 @@ def _probe_chat_capability(cfg: dict) -> tuple[bool, str]:
litellm_params.get("base_model") if isinstance(litellm_params, dict) else None
)
cap = derive_supports_image_input(
provider=cfg.get("provider"),
litellm_provider=cfg.get("litellm_provider"),
model_name=cfg.get("model_name"),
base_model=base_model,
custom_provider=cfg.get("custom_provider"),
)
block = is_known_text_only_chat_model(
provider=cfg.get("provider"),
litellm_provider=cfg.get("litellm_provider"),
model_name=cfg.get("model_name"),
base_model=base_model,
custom_provider=cfg.get("custom_provider"),
@ -179,11 +178,7 @@ def _probe_chat_capability(cfg: dict) -> tuple[bool, str]:
def _build_chat_model_string(cfg: dict) -> str:
if cfg.get("custom_provider"):
return f"{cfg['custom_provider']}/{cfg['model_name']}"
from app.services.provider_capabilities import _PROVIDER_PREFIX_MAP
prefix = _PROVIDER_PREFIX_MAP.get(
(cfg.get("provider") or "").upper(), (cfg.get("provider") or "").lower()
)
prefix = cfg.get("litellm_provider") or "openai"
return f"{prefix}/{cfg['model_name']}"
@ -195,11 +190,6 @@ def _build_chat_model_string(cfg: dict) -> str:
async def _live_chat_image_call(cfg: dict) -> tuple[bool, str]:
"""Send a 1x1 PNG + `reply with one word: ok` to the chat config."""
model_string = _build_chat_model_string(cfg)
api_base = resolve_api_base(
provider=cfg.get("provider"),
provider_prefix=model_string.split("/", 1)[0],
config_api_base=cfg.get("api_base") or None,
)
kwargs: dict[str, Any] = {
"model": model_string,
"api_key": cfg.get("api_key"),
@ -218,8 +208,8 @@ async def _live_chat_image_call(cfg: dict) -> tuple[bool, str]:
"max_tokens": 16,
"timeout": 60,
}
if api_base:
kwargs["api_base"] = api_base
if cfg.get("api_base"):
kwargs["api_base"] = cfg["api_base"]
if cfg.get("litellm_params"):
# Strip pricing keys — they're tracking-only and confuse some
# provider validators (e.g. azure/openai reject unknown kwargs
@ -257,20 +247,11 @@ _IMAGE_GEN_PROMPTS: tuple[str, ...] = (
async def _live_image_gen_call(cfg: dict) -> tuple[bool, str]:
"""Generate one tiny image to verify the deployment is reachable."""
from app.services.provider_capabilities import _PROVIDER_PREFIX_MAP
if cfg.get("custom_provider"):
prefix = cfg["custom_provider"]
else:
prefix = _PROVIDER_PREFIX_MAP.get(
(cfg.get("provider") or "").upper(), (cfg.get("provider") or "").lower()
)
prefix = cfg.get("litellm_provider") or "openai"
model_string = f"{prefix}/{cfg['model_name']}"
api_base = resolve_api_base(
provider=cfg.get("provider"),
provider_prefix=prefix,
config_api_base=cfg.get("api_base") or None,
)
base_kwargs: dict[str, Any] = {
"model": model_string,
"api_key": cfg.get("api_key"),
@ -278,8 +259,8 @@ async def _live_image_gen_call(cfg: dict) -> tuple[bool, str]:
"size": "1024x1024",
"timeout": 120,
}
if api_base:
base_kwargs["api_base"] = api_base
if cfg.get("api_base"):
base_kwargs["api_base"] = cfg["api_base"]
if cfg.get("api_version"):
base_kwargs["api_version"] = cfg["api_version"]
if cfg.get("litellm_params"):

View file

@ -49,14 +49,14 @@ async def test_resolve_billing_for_premium_global_config(monkeypatch):
[
{
"id": -1,
"provider": "OPENAI",
"litellm_provider": "openai",
"model_name": "gpt-image-1",
"billing_tier": "premium",
"quota_reserve_micros": 75_000,
},
{
"id": -2,
"provider": "OPENROUTER",
"litellm_provider": "openrouter",
"model_name": "google/gemini-2.5-flash-image",
"billing_tier": "free",
},
@ -118,7 +118,7 @@ async def test_resolve_billing_falls_back_to_search_space_default(monkeypatch):
[
{
"id": -7,
"provider": "OPENAI",
"litellm_provider": "openai",
"model_name": "gpt-image-1",
"billing_tier": "premium",
}

View file

@ -1,19 +1,4 @@
"""Defense-in-depth: image-gen call sites must not let an empty
``api_base`` fall through to LiteLLM's module-global ``litellm.api_base``.
The bug repro: an OpenRouter image-gen config ships
``api_base=""``. The pre-fix call site in
``image_generation_routes._execute_image_generation`` did
``if cfg.get("api_base"): kwargs["api_base"] = cfg["api_base"]`` which
silently dropped the empty string. LiteLLM then fell back to
``litellm.api_base`` (commonly inherited from ``AZURE_OPENAI_ENDPOINT``)
and OpenRouter's ``image_generation/transformation`` appended
``/chat/completions`` to it, producing a 404 ``Resource not found``.
This test pins the post-fix behaviour: with an empty ``api_base`` in
the config, the call site MUST set ``api_base`` to OpenRouter's public
URL instead of leaving it unset.
"""
"""Image-gen call sites must pass each config's explicit ``api_base``."""
from __future__ import annotations
@ -26,20 +11,17 @@ pytestmark = pytest.mark.unit
@pytest.mark.asyncio
async def test_global_openrouter_image_gen_sets_api_base_when_config_empty():
"""The global-config branch (``config_id < 0``) of
``_execute_image_generation`` must apply the resolver and pin
``api_base`` to OpenRouter when the config ships an empty string.
"""
async def test_global_openrouter_image_gen_sets_explicit_api_base():
"""The global-config branch forwards the explicit OpenRouter base."""
from app.routes import image_generation_routes
cfg = {
"id": -20_001,
"name": "GPT Image 1 (OpenRouter)",
"provider": "OPENROUTER",
"litellm_provider": "openrouter",
"model_name": "openai/gpt-image-1",
"api_key": "sk-or-test",
"api_base": "", # the original bug shape
"api_base": "https://openrouter.ai/api/v1",
"api_version": None,
"litellm_params": {},
}
@ -80,16 +62,13 @@ async def test_global_openrouter_image_gen_sets_api_base_when_config_empty():
session=session, image_gen=image_gen, search_space=search_space
)
# The whole point of the fix: even with empty ``api_base`` in the
# config, we forward OpenRouter's public URL so the call doesn't
# inherit an Azure endpoint.
assert captured.get("api_base") == "https://openrouter.ai/api/v1"
assert captured["model"] == "openrouter/openai/gpt-image-1"
@pytest.mark.asyncio
async def test_generate_image_tool_global_sets_api_base_when_config_empty():
"""Same defense at the agent tool entry point — both surfaces share
async def test_generate_image_tool_global_sets_explicit_api_base():
"""Same explicit-base behavior at the agent tool entry point — both surfaces share
the same OpenRouter config payloads."""
from app.agents.chat.multi_agent_chat.subagents.builtins.deliverables.tools import (
generate_image as gi_module,
@ -98,10 +77,10 @@ async def test_generate_image_tool_global_sets_api_base_when_config_empty():
cfg = {
"id": -20_001,
"name": "GPT Image 1 (OpenRouter)",
"provider": "OPENROUTER",
"litellm_provider": "openrouter",
"model_name": "openai/gpt-image-1",
"api_key": "sk-or-test",
"api_base": "",
"api_base": "https://openrouter.ai/api/v1",
"api_version": None,
"litellm_params": {},
}
@ -171,20 +150,16 @@ async def test_generate_image_tool_global_sets_api_base_when_config_empty():
assert captured["model"] == "openrouter/openai/gpt-image-1"
def test_image_gen_router_deployment_sets_api_base_when_config_empty():
"""The Auto-mode router pool must also resolve ``api_base`` when an
OpenRouter config ships an empty string. The deployment dict is fed
straight to ``litellm.Router``, so a missing ``api_base`` would
leak the same way as the direct call sites.
"""
def test_image_gen_router_deployment_sets_explicit_api_base():
"""The Auto-mode router pool carries explicit api_base into deployments."""
from app.services.image_gen_router_service import ImageGenRouterService
deployment = ImageGenRouterService._config_to_deployment(
{
"model_name": "openai/gpt-image-1",
"provider": "OPENROUTER",
"litellm_provider": "openrouter",
"api_key": "sk-or-test",
"api_base": "",
"api_base": "https://openrouter.ai/api/v1",
}
)
assert deployment is not None

View file

@ -1,12 +1,4 @@
"""Defense-in-depth: vision-LLM resolution must not leak ``api_base``
defaults from ``litellm.api_base`` either.
Vision shares the same shape as image-gen: global YAML / OpenRouter
dynamic configs ship ``api_base=""`` and the pre-fix ``get_vision_llm``
call sites would silently drop the empty string and inherit
``AZURE_OPENAI_ENDPOINT``. ``ChatLiteLLM(...)`` doesn't 404 on
construction so we test the kwargs we hand to it instead.
"""
"""Vision LLM resolution must pass explicit per-config ``api_base``."""
from __future__ import annotations
@ -19,19 +11,16 @@ pytestmark = pytest.mark.unit
@pytest.mark.asyncio
async def test_get_vision_llm_global_openrouter_sets_api_base():
"""Global negative-ID branch: an OpenRouter vision config with
``api_base=""`` must end up calling ``SanitizedChatLiteLLM`` with
``api_base="https://openrouter.ai/api/v1"`` — never an empty string,
never silently absent."""
"""Global negative-ID branch forwards the explicit OpenRouter base."""
from app.services import llm_service
cfg = {
"id": -30_001,
"name": "GPT-4o Vision (OpenRouter)",
"provider": "OPENROUTER",
"litellm_provider": "openrouter",
"model_name": "openai/gpt-4o",
"api_key": "sk-or-test",
"api_base": "",
"api_base": "https://openrouter.ai/api/v1",
"api_version": None,
"litellm_params": {},
"billing_tier": "free",
@ -72,16 +61,15 @@ async def test_get_vision_llm_global_openrouter_sets_api_base():
def test_vision_router_deployment_sets_api_base_when_config_empty():
"""Auto-mode vision router: deployments are fed to ``litellm.Router``,
so the resolver has to apply at deployment construction time too."""
"""Auto-mode vision router carries explicit api_base into deployments."""
from app.services.vision_llm_router_service import VisionLLMRouterService
deployment = VisionLLMRouterService._config_to_deployment(
{
"model_name": "openai/gpt-4o",
"provider": "OPENROUTER",
"litellm_provider": "openrouter",
"api_key": "sk-or-test",
"api_base": "",
"api_base": "https://openrouter.ai/api/v1",
}
)
assert deployment is not None