mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-22 21:28:12 +02:00
fix(chat): harden image generation model routing
This commit is contained in:
parent
c28c4f5785
commit
831ad23c6c
7 changed files with 156 additions and 171 deletions
|
|
@ -55,7 +55,6 @@ from app.services.openrouter_integration_service import ( # noqa: E402
|
|||
_OPENROUTER_DYNAMIC_MARKER,
|
||||
OpenRouterIntegrationService,
|
||||
)
|
||||
from app.services.provider_api_base import resolve_api_base # noqa: E402
|
||||
from app.services.provider_capabilities import ( # noqa: E402
|
||||
derive_supports_image_input,
|
||||
is_known_text_only_chat_model,
|
||||
|
|
@ -154,13 +153,13 @@ def _probe_chat_capability(cfg: dict) -> tuple[bool, str]:
|
|||
litellm_params.get("base_model") if isinstance(litellm_params, dict) else None
|
||||
)
|
||||
cap = derive_supports_image_input(
|
||||
provider=cfg.get("provider"),
|
||||
litellm_provider=cfg.get("litellm_provider"),
|
||||
model_name=cfg.get("model_name"),
|
||||
base_model=base_model,
|
||||
custom_provider=cfg.get("custom_provider"),
|
||||
)
|
||||
block = is_known_text_only_chat_model(
|
||||
provider=cfg.get("provider"),
|
||||
litellm_provider=cfg.get("litellm_provider"),
|
||||
model_name=cfg.get("model_name"),
|
||||
base_model=base_model,
|
||||
custom_provider=cfg.get("custom_provider"),
|
||||
|
|
@ -179,11 +178,7 @@ def _probe_chat_capability(cfg: dict) -> tuple[bool, str]:
|
|||
def _build_chat_model_string(cfg: dict) -> str:
|
||||
if cfg.get("custom_provider"):
|
||||
return f"{cfg['custom_provider']}/{cfg['model_name']}"
|
||||
from app.services.provider_capabilities import _PROVIDER_PREFIX_MAP
|
||||
|
||||
prefix = _PROVIDER_PREFIX_MAP.get(
|
||||
(cfg.get("provider") or "").upper(), (cfg.get("provider") or "").lower()
|
||||
)
|
||||
prefix = cfg.get("litellm_provider") or "openai"
|
||||
return f"{prefix}/{cfg['model_name']}"
|
||||
|
||||
|
||||
|
|
@ -195,11 +190,6 @@ def _build_chat_model_string(cfg: dict) -> str:
|
|||
async def _live_chat_image_call(cfg: dict) -> tuple[bool, str]:
|
||||
"""Send a 1x1 PNG + `reply with one word: ok` to the chat config."""
|
||||
model_string = _build_chat_model_string(cfg)
|
||||
api_base = resolve_api_base(
|
||||
provider=cfg.get("provider"),
|
||||
provider_prefix=model_string.split("/", 1)[0],
|
||||
config_api_base=cfg.get("api_base") or None,
|
||||
)
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": model_string,
|
||||
"api_key": cfg.get("api_key"),
|
||||
|
|
@ -218,8 +208,8 @@ async def _live_chat_image_call(cfg: dict) -> tuple[bool, str]:
|
|||
"max_tokens": 16,
|
||||
"timeout": 60,
|
||||
}
|
||||
if api_base:
|
||||
kwargs["api_base"] = api_base
|
||||
if cfg.get("api_base"):
|
||||
kwargs["api_base"] = cfg["api_base"]
|
||||
if cfg.get("litellm_params"):
|
||||
# Strip pricing keys — they're tracking-only and confuse some
|
||||
# provider validators (e.g. azure/openai reject unknown kwargs
|
||||
|
|
@ -257,20 +247,11 @@ _IMAGE_GEN_PROMPTS: tuple[str, ...] = (
|
|||
|
||||
async def _live_image_gen_call(cfg: dict) -> tuple[bool, str]:
|
||||
"""Generate one tiny image to verify the deployment is reachable."""
|
||||
from app.services.provider_capabilities import _PROVIDER_PREFIX_MAP
|
||||
|
||||
if cfg.get("custom_provider"):
|
||||
prefix = cfg["custom_provider"]
|
||||
else:
|
||||
prefix = _PROVIDER_PREFIX_MAP.get(
|
||||
(cfg.get("provider") or "").upper(), (cfg.get("provider") or "").lower()
|
||||
)
|
||||
prefix = cfg.get("litellm_provider") or "openai"
|
||||
model_string = f"{prefix}/{cfg['model_name']}"
|
||||
api_base = resolve_api_base(
|
||||
provider=cfg.get("provider"),
|
||||
provider_prefix=prefix,
|
||||
config_api_base=cfg.get("api_base") or None,
|
||||
)
|
||||
base_kwargs: dict[str, Any] = {
|
||||
"model": model_string,
|
||||
"api_key": cfg.get("api_key"),
|
||||
|
|
@ -278,8 +259,8 @@ async def _live_image_gen_call(cfg: dict) -> tuple[bool, str]:
|
|||
"size": "1024x1024",
|
||||
"timeout": 120,
|
||||
}
|
||||
if api_base:
|
||||
base_kwargs["api_base"] = api_base
|
||||
if cfg.get("api_base"):
|
||||
base_kwargs["api_base"] = cfg["api_base"]
|
||||
if cfg.get("api_version"):
|
||||
base_kwargs["api_version"] = cfg["api_version"]
|
||||
if cfg.get("litellm_params"):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue