fix(chat): harden image generation model routing

This commit is contained in:
Anish Sarkar 2026-06-11 18:22:45 +05:30
parent c28c4f5785
commit 831ad23c6c
7 changed files with 156 additions and 171 deletions

View file

@@ -55,7 +55,6 @@ from app.services.openrouter_integration_service import ( # noqa: E402
_OPENROUTER_DYNAMIC_MARKER,
OpenRouterIntegrationService,
)
from app.services.provider_api_base import resolve_api_base # noqa: E402
from app.services.provider_capabilities import ( # noqa: E402
derive_supports_image_input,
is_known_text_only_chat_model,
@@ -154,13 +153,13 @@ def _probe_chat_capability(cfg: dict) -> tuple[bool, str]:
litellm_params.get("base_model") if isinstance(litellm_params, dict) else None
)
cap = derive_supports_image_input(
provider=cfg.get("provider"),
litellm_provider=cfg.get("litellm_provider"),
model_name=cfg.get("model_name"),
base_model=base_model,
custom_provider=cfg.get("custom_provider"),
)
block = is_known_text_only_chat_model(
provider=cfg.get("provider"),
litellm_provider=cfg.get("litellm_provider"),
model_name=cfg.get("model_name"),
base_model=base_model,
custom_provider=cfg.get("custom_provider"),
@@ -179,11 +178,7 @@ def _probe_chat_capability(cfg: dict) -> tuple[bool, str]:
def _build_chat_model_string(cfg: dict) -> str:
if cfg.get("custom_provider"):
return f"{cfg['custom_provider']}/{cfg['model_name']}"
from app.services.provider_capabilities import _PROVIDER_PREFIX_MAP
prefix = _PROVIDER_PREFIX_MAP.get(
(cfg.get("provider") or "").upper(), (cfg.get("provider") or "").lower()
)
prefix = cfg.get("litellm_provider") or "openai"
return f"{prefix}/{cfg['model_name']}"
@@ -195,11 +190,6 @@ def _build_chat_model_string(cfg: dict) -> str:
async def _live_chat_image_call(cfg: dict) -> tuple[bool, str]:
"""Send a 1x1 PNG + `reply with one word: ok` to the chat config."""
model_string = _build_chat_model_string(cfg)
api_base = resolve_api_base(
provider=cfg.get("provider"),
provider_prefix=model_string.split("/", 1)[0],
config_api_base=cfg.get("api_base") or None,
)
kwargs: dict[str, Any] = {
"model": model_string,
"api_key": cfg.get("api_key"),
@@ -218,8 +208,8 @@ async def _live_chat_image_call(cfg: dict) -> tuple[bool, str]:
"max_tokens": 16,
"timeout": 60,
}
if api_base:
kwargs["api_base"] = api_base
if cfg.get("api_base"):
kwargs["api_base"] = cfg["api_base"]
if cfg.get("litellm_params"):
# Strip pricing keys — they're tracking-only and confuse some
# provider validators (e.g. azure/openai reject unknown kwargs
@@ -257,20 +247,11 @@ _IMAGE_GEN_PROMPTS: tuple[str, ...] = (
async def _live_image_gen_call(cfg: dict) -> tuple[bool, str]:
"""Generate one tiny image to verify the deployment is reachable."""
from app.services.provider_capabilities import _PROVIDER_PREFIX_MAP
if cfg.get("custom_provider"):
prefix = cfg["custom_provider"]
else:
prefix = _PROVIDER_PREFIX_MAP.get(
(cfg.get("provider") or "").upper(), (cfg.get("provider") or "").lower()
)
prefix = cfg.get("litellm_provider") or "openai"
model_string = f"{prefix}/{cfg['model_name']}"
api_base = resolve_api_base(
provider=cfg.get("provider"),
provider_prefix=prefix,
config_api_base=cfg.get("api_base") or None,
)
base_kwargs: dict[str, Any] = {
"model": model_string,
"api_key": cfg.get("api_key"),
@@ -278,8 +259,8 @@ async def _live_image_gen_call(cfg: dict) -> tuple[bool, str]:
"size": "1024x1024",
"timeout": 120,
}
if api_base:
base_kwargs["api_base"] = api_base
if cfg.get("api_base"):
base_kwargs["api_base"] = cfg["api_base"]
if cfg.get("api_version"):
base_kwargs["api_version"] = cfg["api_version"]
if cfg.get("litellm_params"):