feat: fixed vision/image provider specific errors and fixed podcast/video streaming
Some checks are pending
Build and Push Docker Images / tag_release (push) Waiting to run
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (backend, surfsense-backend) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (web, surfsense-web) (push) Blocked by required conditions

This commit is contained in:
DESKTOP-RTLN3BA\$punk 2026-05-02 19:18:53 -07:00
parent ae9d36d77f
commit 47b2994ec7
54 changed files with 4469 additions and 563 deletions

View file

@ -90,41 +90,18 @@ class SanitizedChatLiteLLM(ChatLiteLLM):
yield chunk
# Provider mapping for LiteLLM model string construction
PROVIDER_MAP = {
"OPENAI": "openai",
"ANTHROPIC": "anthropic",
"GROQ": "groq",
"COHERE": "cohere",
"GOOGLE": "gemini",
"OLLAMA": "ollama_chat",
"MISTRAL": "mistral",
"AZURE_OPENAI": "azure",
"OPENROUTER": "openrouter",
"XAI": "xai",
"BEDROCK": "bedrock",
"VERTEX_AI": "vertex_ai",
"TOGETHER_AI": "together_ai",
"FIREWORKS_AI": "fireworks_ai",
"DEEPSEEK": "openai",
"ALIBABA_QWEN": "openai",
"MOONSHOT": "openai",
"ZHIPU": "openai",
"GITHUB_MODELS": "github",
"REPLICATE": "replicate",
"PERPLEXITY": "perplexity",
"ANYSCALE": "anyscale",
"DEEPINFRA": "deepinfra",
"CEREBRAS": "cerebras",
"SAMBANOVA": "sambanova",
"AI21": "ai21",
"CLOUDFLARE": "cloudflare",
"DATABRICKS": "databricks",
"COMETAPI": "cometapi",
"HUGGINGFACE": "huggingface",
"MINIMAX": "openai",
"CUSTOM": "custom",
}
# Provider mapping for LiteLLM model string construction.
#
# Single source of truth lives in
# :mod:`app.services.provider_capabilities` so the YAML loader (which
# runs during ``app.config`` class-body init) can resolve provider
# prefixes without dragging the agent / tools tree into module load
# order. Re-exported here under the historical ``PROVIDER_MAP`` name
# so existing callers (``llm_router_service``, ``image_gen_router_service``,
# tests) keep working unchanged.
from app.services.provider_capabilities import ( # noqa: E402
_PROVIDER_PREFIX_MAP as PROVIDER_MAP,
)
def _attach_model_profile(llm: ChatLiteLLM, model_string: str) -> None:
@ -178,6 +155,17 @@ class AgentConfig:
anonymous_enabled: bool = False
quota_reserve_tokens: int | None = None
# Capability flag: best-effort True for the chat selector / catalog.
# Resolved via :func:`provider_capabilities.derive_supports_image_input`
# which prefers OpenRouter's ``architecture.input_modalities`` and
# otherwise consults LiteLLM's authoritative model map. Default True
# is the conservative-allow stance — the streaming-task safety net
# (``is_known_text_only_chat_model``) is the *only* place a False
# actually blocks a request. Setting this to False here without an
# authoritative source would silently hide vision-capable models
# (the regression we're fixing).
supports_image_input: bool = True
@classmethod
def from_auto_mode(cls) -> "AgentConfig":
"""
@ -203,6 +191,12 @@ class AgentConfig:
is_premium=False,
anonymous_enabled=False,
quota_reserve_tokens=None,
# Auto routes across the configured pool, which usually
# contains at least one vision-capable deployment; the router
# will surface a 404 from a non-vision deployment as a normal
# ``allowed_fails`` event and fail over rather than blocking
# the request outright.
supports_image_input=True,
)
@classmethod
@ -216,10 +210,24 @@ class AgentConfig:
Returns:
AgentConfig instance
"""
return cls(
provider=config.provider.value
# Lazy import to avoid pulling provider_capabilities (and its
# transitive litellm import) into module-init order.
from app.services.provider_capabilities import derive_supports_image_input
provider_value = (
config.provider.value
if hasattr(config.provider, "value")
else str(config.provider),
else str(config.provider)
)
litellm_params = config.litellm_params or {}
base_model = (
litellm_params.get("base_model")
if isinstance(litellm_params, dict)
else None
)
return cls(
provider=provider_value,
model_name=config.model_name,
api_key=config.api_key,
api_base=config.api_base,
@ -235,6 +243,16 @@ class AgentConfig:
is_premium=False,
anonymous_enabled=False,
quota_reserve_tokens=None,
# BYOK rows have no operator-curated capability flag, so we
# ask LiteLLM (default-allow on unknown). The streaming
# safety net still blocks if the model is *explicitly*
# marked text-only.
supports_image_input=derive_supports_image_input(
provider=provider_value,
model_name=config.model_name,
base_model=base_model,
custom_provider=config.custom_provider,
),
)
@classmethod
@ -253,15 +271,46 @@ class AgentConfig:
Returns:
AgentConfig instance
"""
# Lazy import to avoid pulling provider_capabilities (and its
# transitive litellm import) into module-init order.
from app.services.provider_capabilities import derive_supports_image_input
# Get system instructions from YAML, default to empty string
system_instructions = yaml_config.get("system_instructions", "")
provider = yaml_config.get("provider", "").upper()
model_name = yaml_config.get("model_name", "")
custom_provider = yaml_config.get("custom_provider")
litellm_params = yaml_config.get("litellm_params") or {}
base_model = (
litellm_params.get("base_model")
if isinstance(litellm_params, dict)
else None
)
# Explicit YAML override wins; otherwise derive from LiteLLM /
# OpenRouter modalities. The YAML loader already populates this
# field, but this method is also called from
# ``load_global_llm_config_by_id``'s file fallback (hot reload),
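# so we re-derive here for safety. The bool() coercion normalises
# whatever scalar PyYAML surfaces for an explicit override. NOTE:
# unquoted ``true`` / ``false`` parse as real booleans, but a quoted
# string such as ``"false"`` is truthy under bool() and resolves to True.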
if "supports_image_input" in yaml_config:
supports_image_input = bool(yaml_config.get("supports_image_input"))
else:
supports_image_input = derive_supports_image_input(
provider=provider,
model_name=model_name,
base_model=base_model,
custom_provider=custom_provider,
)
return cls(
provider=yaml_config.get("provider", "").upper(),
model_name=yaml_config.get("model_name", ""),
provider=provider,
model_name=model_name,
api_key=yaml_config.get("api_key", ""),
api_base=yaml_config.get("api_base"),
custom_provider=yaml_config.get("custom_provider"),
custom_provider=custom_provider,
litellm_params=yaml_config.get("litellm_params"),
# Prompt configuration from YAML (with defaults for backwards compatibility)
system_instructions=system_instructions if system_instructions else None,
@ -276,6 +325,7 @@ class AgentConfig:
is_premium=yaml_config.get("billing_tier", "free") == "premium",
anonymous_enabled=yaml_config.get("anonymous_enabled", False),
quota_reserve_tokens=yaml_config.get("quota_reserve_tokens"),
supports_image_input=supports_image_input,
)

View file

@ -31,6 +31,7 @@ from app.services.image_gen_router_service import (
ImageGenRouterService,
is_image_gen_auto_mode,
)
from app.services.provider_api_base import resolve_api_base
from app.utils.signed_image_urls import generate_image_token
logger = logging.getLogger(__name__)
@ -49,12 +50,16 @@ _PROVIDER_MAP = {
}
def _resolve_provider_prefix(provider: str, custom_provider: str | None) -> str:
if custom_provider:
return custom_provider
return _PROVIDER_MAP.get(provider.upper(), provider.lower())
def _build_model_string(
provider: str, model_name: str, custom_provider: str | None
) -> str:
if custom_provider:
return f"{custom_provider}/{model_name}"
prefix = _PROVIDER_MAP.get(provider.upper(), provider.lower())
prefix = _resolve_provider_prefix(provider, custom_provider)
return f"{prefix}/{model_name}"
@ -146,14 +151,18 @@ def create_generate_image_tool(
"error": f"Image generation config {config_id} not found"
}
model_string = _build_model_string(
cfg.get("provider", ""),
cfg["model_name"],
cfg.get("custom_provider"),
provider_prefix = _resolve_provider_prefix(
cfg.get("provider", ""), cfg.get("custom_provider")
)
model_string = f"{provider_prefix}/{cfg['model_name']}"
gen_kwargs["api_key"] = cfg.get("api_key")
if cfg.get("api_base"):
gen_kwargs["api_base"] = cfg["api_base"]
api_base = resolve_api_base(
provider=cfg.get("provider"),
provider_prefix=provider_prefix,
config_api_base=cfg.get("api_base"),
)
if api_base:
gen_kwargs["api_base"] = api_base
if cfg.get("api_version"):
gen_kwargs["api_version"] = cfg["api_version"]
if cfg.get("litellm_params"):
@ -175,14 +184,18 @@ def create_generate_image_tool(
"error": f"Image generation config {config_id} not found"
}
model_string = _build_model_string(
db_cfg.provider.value,
db_cfg.model_name,
db_cfg.custom_provider,
provider_prefix = _resolve_provider_prefix(
db_cfg.provider.value, db_cfg.custom_provider
)
model_string = f"{provider_prefix}/{db_cfg.model_name}"
gen_kwargs["api_key"] = db_cfg.api_key
if db_cfg.api_base:
gen_kwargs["api_base"] = db_cfg.api_base
api_base = resolve_api_base(
provider=db_cfg.provider.value,
provider_prefix=provider_prefix,
config_api_base=db_cfg.api_base,
)
if api_base:
gen_kwargs["api_base"] = api_base
if db_cfg.api_version:
gen_kwargs["api_version"] = db_cfg.api_version
if db_cfg.litellm_params:

View file

@ -47,11 +47,37 @@ def load_global_llm_configs():
data = yaml.safe_load(f)
configs = data.get("global_llm_configs", [])
# Lazy import keeps the `app.config` -> `app.services` edge one-way
# and matches the `provider_api_base` pattern used elsewhere.
from app.services.provider_capabilities import derive_supports_image_input
seen_slugs: dict[str, int] = {}
for cfg in configs:
cfg.setdefault("billing_tier", "free")
cfg.setdefault("anonymous_enabled", False)
cfg.setdefault("seo_enabled", False)
# Capability flag: explicit YAML override always wins. When the
# operator has not annotated the model, defer to LiteLLM's
# authoritative model map (`supports_vision`) which already
# knows GPT-5.x / GPT-4o / Claude 3.x / Gemini 2.x are
# vision-capable. Unknown / unmapped models default-allow so
# we don't lock the user out of a freshly added third-party
# entry; the streaming-task safety net (driven by
# `is_known_text_only_chat_model`) is the only place a False
# actually blocks a request.
if "supports_image_input" not in cfg:
litellm_params = cfg.get("litellm_params") or {}
base_model = (
litellm_params.get("base_model")
if isinstance(litellm_params, dict)
else None
)
cfg["supports_image_input"] = derive_supports_image_input(
provider=cfg.get("provider"),
model_name=cfg.get("model_name"),
base_model=base_model,
custom_provider=cfg.get("custom_provider"),
)
if cfg.get("seo_enabled") and cfg.get("seo_slug"):
slug = cfg["seo_slug"]

View file

@ -46,6 +46,7 @@ from app.services.image_gen_router_service import (
ImageGenRouterService,
is_image_gen_auto_mode,
)
from app.services.provider_api_base import resolve_api_base
from app.users import current_active_user
from app.utils.rbac import check_permission
from app.utils.signed_image_urls import verify_image_token
@ -87,14 +88,18 @@ def _get_global_image_gen_config(config_id: int) -> dict | None:
return None
def _resolve_provider_prefix(provider: str, custom_provider: str | None) -> str:
"""Resolve the LiteLLM provider prefix used in model strings."""
if custom_provider:
return custom_provider
return _PROVIDER_MAP.get(provider.upper(), provider.lower())
def _build_model_string(
provider: str, model_name: str, custom_provider: str | None
) -> str:
"""Build a litellm model string from provider + model_name."""
if custom_provider:
return f"{custom_provider}/{model_name}"
prefix = _PROVIDER_MAP.get(provider.upper(), provider.lower())
return f"{prefix}/{model_name}"
return f"{_resolve_provider_prefix(provider, custom_provider)}/{model_name}"
async def _resolve_billing_for_image_gen(
@ -187,12 +192,18 @@ async def _execute_image_generation(
if not cfg:
raise ValueError(f"Global image generation config {config_id} not found")
model_string = _build_model_string(
cfg.get("provider", ""), cfg["model_name"], cfg.get("custom_provider")
provider_prefix = _resolve_provider_prefix(
cfg.get("provider", ""), cfg.get("custom_provider")
)
model_string = f"{provider_prefix}/{cfg['model_name']}"
gen_kwargs["api_key"] = cfg.get("api_key")
if cfg.get("api_base"):
gen_kwargs["api_base"] = cfg["api_base"]
api_base = resolve_api_base(
provider=cfg.get("provider"),
provider_prefix=provider_prefix,
config_api_base=cfg.get("api_base"),
)
if api_base:
gen_kwargs["api_base"] = api_base
if cfg.get("api_version"):
gen_kwargs["api_version"] = cfg["api_version"]
if cfg.get("litellm_params"):
@ -214,12 +225,18 @@ async def _execute_image_generation(
if not db_cfg:
raise ValueError(f"Image generation config {config_id} not found")
model_string = _build_model_string(
db_cfg.provider.value, db_cfg.model_name, db_cfg.custom_provider
provider_prefix = _resolve_provider_prefix(
db_cfg.provider.value, db_cfg.custom_provider
)
model_string = f"{provider_prefix}/{db_cfg.model_name}"
gen_kwargs["api_key"] = db_cfg.api_key
if db_cfg.api_base:
gen_kwargs["api_base"] = db_cfg.api_base
api_base = resolve_api_base(
provider=db_cfg.provider.value,
provider_prefix=provider_prefix,
config_api_base=db_cfg.api_base,
)
if api_base:
gen_kwargs["api_base"] = api_base
if db_cfg.api_version:
gen_kwargs["api_version"] = db_cfg.api_version
if db_cfg.litellm_params:
@ -277,10 +294,12 @@ async def get_global_image_gen_configs(
# Auto mode currently treated as free until per-deployment
# billing-tier surfacing lands (see _resolve_billing_for_image_gen).
"billing_tier": "free",
"is_premium": False,
}
)
for cfg in global_configs:
billing_tier = str(cfg.get("billing_tier", "free")).lower()
safe_configs.append(
{
"id": cfg.get("id"),
@ -293,7 +312,11 @@ async def get_global_image_gen_configs(
"api_version": cfg.get("api_version") or None,
"litellm_params": cfg.get("litellm_params", {}),
"is_global": True,
"billing_tier": cfg.get("billing_tier", "free"),
"billing_tier": billing_tier,
# Mirror chat (``new_llm_config_routes``) so the new-chat
# selector's premium badge logic keys off the same
# field across chat / image / vision tabs.
"is_premium": billing_tier == "premium",
"quota_reserve_micros": cfg.get("quota_reserve_micros"),
}
)

View file

@ -29,6 +29,7 @@ from app.schemas import (
NewLLMConfigUpdate,
)
from app.services.llm_service import validate_llm_config
from app.services.provider_capabilities import derive_supports_image_input
from app.users import current_active_user
from app.utils.rbac import check_permission
@ -36,6 +37,39 @@ router = APIRouter()
logger = logging.getLogger(__name__)
def _serialize_byok_config(config: NewLLMConfig) -> NewLLMConfigRead:
    """Serialize a BYOK chat config row, adding the derived ``supports_image_input``.

    ``supports_image_input`` is not read from the row; it is computed here
    by calling ``derive_supports_image_input`` with the row's provider,
    model name, optional ``litellm_params["base_model"]`` and custom
    provider. Routing every list / detail / create / update response
    through this helper keeps the response shape identical across
    endpoints.
    """
    # The provider may be an enum (use its ``.value``) or already a plain value.
    provider = config.provider
    provider_name = provider.value if hasattr(provider, "value") else str(provider)

    # ``litellm_params`` can be missing or not a dict; only a dict can
    # carry a ``base_model`` hint.
    params = config.litellm_params or {}
    base_model = None
    if isinstance(params, dict):
        base_model = params.get("base_model")

    image_capable = derive_supports_image_input(
        provider=provider_name,
        model_name=config.model_name,
        base_model=base_model,
        custom_provider=config.custom_provider,
    )

    # Validate the row into the read schema first, then return a copy
    # carrying the derived flag so the validated object itself is not
    # mutated.
    return NewLLMConfigRead.model_validate(config).model_copy(
        update={"supports_image_input": image_capable}
    )
# =============================================================================
# Global Configs Routes
# =============================================================================
@ -84,11 +118,41 @@ async def get_global_new_llm_configs(
"seo_title": None,
"seo_description": None,
"quota_reserve_tokens": None,
# Auto routes across the configured pool, which usually
# includes at least one vision-capable deployment, so
# treat Auto as image-capable. The router itself will
# still pick a vision-capable deployment for messages
# carrying image_url blocks (LiteLLM Router falls back
# on ``404`` per its ``allowed_fails`` policy).
"supports_image_input": True,
}
)
# Add individual global configs
for cfg in global_configs:
# Capability resolution: explicit value (YAML override or OR
# `_supports_image_input(model)` payload baked in by the
# OpenRouter integration service) wins. Fall back to the
# LiteLLM-driven helper which default-allows on unknown so
# we don't hide vision-capable models that happen to lack a
# YAML annotation. The streaming task safety net is the
# only place a False ever blocks.
if "supports_image_input" in cfg:
supports_image_input = bool(cfg.get("supports_image_input"))
else:
cfg_litellm_params = cfg.get("litellm_params") or {}
cfg_base_model = (
cfg_litellm_params.get("base_model")
if isinstance(cfg_litellm_params, dict)
else None
)
supports_image_input = derive_supports_image_input(
provider=cfg.get("provider"),
model_name=cfg.get("model_name"),
base_model=cfg_base_model,
custom_provider=cfg.get("custom_provider"),
)
safe_config = {
"id": cfg.get("id"),
"name": cfg.get("name"),
@ -113,6 +177,7 @@ async def get_global_new_llm_configs(
"seo_title": cfg.get("seo_title"),
"seo_description": cfg.get("seo_description"),
"quota_reserve_tokens": cfg.get("quota_reserve_tokens"),
"supports_image_input": supports_image_input,
}
safe_configs.append(safe_config)
@ -171,7 +236,7 @@ async def create_new_llm_config(
await session.commit()
await session.refresh(db_config)
return db_config
return _serialize_byok_config(db_config)
except HTTPException:
raise
@ -213,7 +278,7 @@ async def list_new_llm_configs(
.limit(limit)
)
return result.scalars().all()
return [_serialize_byok_config(cfg) for cfg in result.scalars().all()]
except HTTPException:
raise
@ -268,7 +333,7 @@ async def get_new_llm_config(
"You don't have permission to view LLM configurations in this search space",
)
return config
return _serialize_byok_config(config)
except HTTPException:
raise
@ -360,7 +425,7 @@ async def update_new_llm_config(
await session.commit()
await session.refresh(config)
return config
return _serialize_byok_config(config)
except HTTPException:
raise

View file

@ -85,10 +85,12 @@ async def get_global_vision_llm_configs(
# Auto mode treated as free until per-deployment billing-tier
# surfacing lands; see ``get_vision_llm`` for parity.
"billing_tier": "free",
"is_premium": False,
}
)
for cfg in global_configs:
billing_tier = str(cfg.get("billing_tier", "free")).lower()
safe_configs.append(
{
"id": cfg.get("id"),
@ -101,7 +103,11 @@ async def get_global_vision_llm_configs(
"api_version": cfg.get("api_version") or None,
"litellm_params": cfg.get("litellm_params", {}),
"is_global": True,
"billing_tier": cfg.get("billing_tier", "free"),
"billing_tier": billing_tier,
# Mirror chat (``new_llm_config_routes``) so the new-chat
# selector's premium badge logic keys off the same
# field across chat / image / vision tabs.
"is_premium": billing_tier == "premium",
"quota_reserve_tokens": cfg.get("quota_reserve_tokens"),
"input_cost_per_token": cfg.get("input_cost_per_token"),
"output_cost_per_token": cfg.get("output_cost_per_token"),

View file

@ -241,6 +241,15 @@ class GlobalImageGenConfigRead(BaseModel):
default="free",
description="'free' or 'premium'. Premium debits the user's premium credit pool (USD-cost-based).",
)
is_premium: bool = Field(
default=False,
description=(
"Convenience boolean derived server-side from "
"``billing_tier == 'premium'``. The new-chat model selector "
"keys its Free/Premium badge off this field for parity with "
"chat (`GlobalLLMConfigRead.is_premium`)."
),
)
quota_reserve_micros: int | None = Field(
default=None,
description=(

View file

@ -92,6 +92,20 @@ class NewLLMConfigRead(NewLLMConfigBase):
created_at: datetime
search_space_id: int
user_id: uuid.UUID
# Capability flag derived at the API boundary (no DB column). Default
# True matches the conservative-allow stance — a BYOK row that the
# route forgot to augment is not pre-judged. The streaming-task
# safety net is the only place a False actually blocks a request.
supports_image_input: bool = Field(
default=True,
description=(
"Whether the BYOK chat config can accept image inputs. Derived "
"at the route boundary from LiteLLM's authoritative model map "
"(``litellm.supports_vision``) — there is no DB column. "
"Default True is the conservative-allow stance for unknown / "
"unmapped models."
),
)
model_config = ConfigDict(from_attributes=True)
@ -121,6 +135,15 @@ class NewLLMConfigPublic(BaseModel):
created_at: datetime
search_space_id: int
user_id: uuid.UUID
# Capability flag derived at the API boundary (see NewLLMConfigRead).
supports_image_input: bool = Field(
default=True,
description=(
"Whether the BYOK chat config can accept image inputs. Derived "
"at the route boundary from LiteLLM's authoritative model map. "
"Default True is the conservative-allow stance."
),
)
model_config = ConfigDict(from_attributes=True)
@ -172,6 +195,19 @@ class GlobalNewLLMConfigRead(BaseModel):
seo_title: str | None = None
seo_description: str | None = None
quota_reserve_tokens: int | None = None
supports_image_input: bool = Field(
default=True,
description=(
"Whether the model accepts image inputs (multimodal vision). "
"Derived server-side: OpenRouter dynamic configs use "
"``architecture.input_modalities``; YAML / BYOK use LiteLLM's "
"authoritative model map (``litellm.supports_vision``). The "
"new-chat selector hints with a 'No image' badge when this is "
"False and there are pending image attachments. The streaming "
"task fails fast only when LiteLLM *explicitly* marks a model "
"as text-only — unknown / unmapped models default-allow."
),
)
# =============================================================================

View file

@ -86,6 +86,15 @@ class GlobalVisionLLMConfigRead(BaseModel):
default="free",
description="'free' or 'premium'. Premium debits the user's premium credit pool (USD-cost-based).",
)
is_premium: bool = Field(
default=False,
description=(
"Convenience boolean derived server-side from "
"``billing_tier == 'premium'``. The new-chat model selector "
"keys its Free/Premium badge off this field for parity with "
"chat (`GlobalLLMConfigRead.is_premium`)."
),
)
quota_reserve_tokens: int | None = Field(
default=None,
description=(

View file

@ -163,13 +163,47 @@ def clear_healthy(config_id: int | None = None) -> None:
_healthy_until.pop(int(config_id), None)
def _global_candidates() -> list[dict]:
def _cfg_supports_image_input(cfg: dict) -> bool:
"""True if the global cfg can accept image inputs.
Prefers the explicit ``supports_image_input`` flag (set by the YAML
loader / OpenRouter integration). Falls back to a LiteLLM lookup so
a YAML entry whose flag was somehow stripped doesn't get wrongly
excluded. Default-allows on unknown the streaming-task safety net
is the actual block, not this filter.
"""
if "supports_image_input" in cfg:
return bool(cfg.get("supports_image_input"))
# Lazy import: provider_capabilities -> llm_config -> services chain;
# importing at module load would create an init-order cycle through
# ``app.config``.
from app.services.provider_capabilities import derive_supports_image_input
cfg_litellm_params = cfg.get("litellm_params") or {}
base_model = (
cfg_litellm_params.get("base_model")
if isinstance(cfg_litellm_params, dict)
else None
)
return derive_supports_image_input(
provider=cfg.get("provider"),
model_name=cfg.get("model_name"),
base_model=base_model,
custom_provider=cfg.get("custom_provider"),
)
def _global_candidates(*, requires_image_input: bool = False) -> list[dict]:
"""Return Auto-eligible global cfgs.
Drops cfgs flagged ``health_gated`` (best non-null OpenRouter uptime
below ``_HEALTH_GATE_UPTIME_PCT``) so chronically broken providers
can't be picked as the thread's pin. Also excludes configs currently
in runtime cooldown (e.g. temporary 429 bursts).
When ``requires_image_input`` is True (image turn), additionally
filters out configs whose ``supports_image_input`` resolves to False
so a text-only deployment can't be pinned for an image request.
"""
candidates = [
cfg
@ -177,6 +211,7 @@ def _global_candidates() -> list[dict]:
if _is_usable_global_config(cfg)
and not cfg.get("health_gated")
and not _is_runtime_cooled_down(int(cfg.get("id", 0)))
and (not requires_image_input or _cfg_supports_image_input(cfg))
]
return sorted(candidates, key=lambda c: int(c.get("id", 0)))
@ -237,11 +272,20 @@ async def resolve_or_get_pinned_llm_config_id(
selected_llm_config_id: int,
force_repin_free: bool = False,
exclude_config_ids: set[int] | None = None,
requires_image_input: bool = False,
) -> AutoPinResolution:
"""Resolve Auto (Fastest) to one concrete config id and persist the pin.
For non-auto selections, this function clears any existing pin and returns
the selected id as-is.
When ``requires_image_input`` is True (the current turn carries an
``image_url`` block), the candidate pool is filtered to vision-capable
cfgs and any existing pin that can't accept image input is treated as
invalid (force re-pin). If no vision-capable cfg is available the
function raises ``ValueError`` so the streaming task surfaces the same
friendly ``MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT`` error instead of
silently routing the image to a text-only deployment.
"""
thread = (
(
@ -274,14 +318,24 @@ async def resolve_or_get_pinned_llm_config_id(
excluded_ids = {int(cid) for cid in (exclude_config_ids or set())}
candidates = [
c for c in _global_candidates() if int(c.get("id", 0)) not in excluded_ids
c
for c in _global_candidates(requires_image_input=requires_image_input)
if int(c.get("id", 0)) not in excluded_ids
]
if not candidates:
if requires_image_input:
# Distinguish the "no vision-capable cfg" case from generic
# "no usable cfg" so the streaming task can map this to the
# MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT SSE error.
raise ValueError(
"No vision-capable global LLM configs are available for Auto mode"
)
raise ValueError("No usable global LLM configs are available for Auto mode")
candidate_by_id = {int(c["id"]): c for c in candidates}
# Reuse an existing valid pin without re-checking current quota (no silent
# tier switch), unless the caller explicitly requests a forced repin to free.
# tier switch), unless the caller explicitly requests a forced repin to free
# *or* the turn requires image input but the pin can't handle it.
pinned_id = thread.pinned_llm_config_id
if (
not force_repin_free
@ -311,6 +365,29 @@ async def resolve_or_get_pinned_llm_config_id(
from_existing_pin=True,
)
if pinned_id is not None:
# If the pin is *only* invalid because it can't handle the image
# turn (it's still a healthy, usable config in the broader pool),
# log that explicitly so operators can correlate the re-pin with
# the user's image attachment instead of suspecting a cooldown.
if requires_image_input:
try:
pinned_global = next(
c
for c in config.GLOBAL_LLM_CONFIGS
if int(c.get("id", 0)) == int(pinned_id)
)
except StopIteration:
pinned_global = None
if pinned_global is not None and not _cfg_supports_image_input(
pinned_global
):
logger.info(
"auto_pin_repinned_for_image thread_id=%s search_space_id=%s "
"previous_config_id=%s",
thread_id,
search_space_id,
pinned_id,
)
logger.info(
"auto_pin_invalid thread_id=%s search_space_id=%s pinned_config_id=%s",
thread_id,
@ -327,6 +404,10 @@ async def resolve_or_get_pinned_llm_config_id(
eligible = [c for c in candidates if _tier_of(c) != "premium"]
if not eligible:
if requires_image_input:
raise ValueError(
"Auto mode could not find a vision-capable LLM config for this user and quota state"
)
raise ValueError(
"Auto mode could not find an eligible LLM config for this user and quota state"
)

View file

@ -10,12 +10,14 @@ vision-LLM wrapper used during indexing) don't have to re-implement it.
KEY DESIGN POINTS (issue A, B):
1. **Session isolation.** ``billable_call`` takes *no* ``db_session``
argument. All ``TokenQuotaService.premium_*`` calls and the audit-row
insert each run inside their own ``shielded_async_session()``. This
guarantees that a quota commit/rollback can never accidentally flush or
roll back rows the caller has staged in the request's main session
(e.g. a freshly-created ``ImageGeneration`` row).
1. **Session isolation.** ``billable_call`` takes no caller transaction.
All ``TokenQuotaService.premium_*`` calls and the audit-row insert run
inside their own session context. Route callers use
``shielded_async_session()`` by default; Celery callers can provide a
worker-loop-safe session factory. This guarantees that quota
commit/rollback can never accidentally flush or roll back rows the caller
has staged in its main session (e.g. a freshly-created
``ImageGeneration`` row).
2. **ContextVar safety.** The accumulator is scoped via
:func:`scoped_turn` (which uses ``ContextVar.reset(token)``), so a
@ -36,9 +38,10 @@ KEY DESIGN POINTS (issue A, B):
from __future__ import annotations
import asyncio
import logging
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from collections.abc import AsyncIterator, Callable
from contextlib import AbstractAsyncContextManager, asynccontextmanager, suppress
from typing import Any
from uuid import UUID, uuid4
@ -58,6 +61,12 @@ from app.services.token_tracking_service import (
logger = logging.getLogger(__name__)
AUDIT_TIMEOUT_SECONDS = 10.0
BACKGROUND_ARTIFACT_USAGE_TYPES = frozenset(
{"video_presentation_generation", "podcast_generation"}
)
BillableSessionFactory = Callable[[], AbstractAsyncContextManager[AsyncSession]]
class QuotaInsufficientError(Exception):
"""Raised when ``TokenQuotaService.premium_reserve`` denies a billable
@ -88,6 +97,124 @@ class QuotaInsufficientError(Exception):
)
class BillingSettlementError(Exception):
    """Raised when a premium call completed but credit settlement failed.

    Attributes:
        usage_type: The usage type passed by the raiser.
        user_id: The user whose credit settlement failed.
        cause: The underlying exception. It is kept on the instance so
            handlers can inspect it even when the raiser did not chain it
            with ``raise ... from``.
    """

    def __init__(self, *, usage_type: str, user_id: UUID, cause: Exception) -> None:
        self.usage_type = usage_type
        self.user_id = user_id
        self.cause = cause
        super().__init__(
            f"Failed to settle premium credit for {usage_type} user={user_id}: {cause}"
        )
async def _rollback_safely(session: AsyncSession) -> None:
rollback = getattr(session, "rollback", None)
if rollback is not None:
with suppress(Exception):
await rollback()
async def _record_audit_best_effort(
    *,
    session_factory: BillableSessionFactory,
    usage_type: str,
    search_space_id: int,
    user_id: UUID,
    prompt_tokens: int,
    completion_tokens: int,
    total_tokens: int,
    cost_micros: int,
    model_breakdown: dict[str, Any],
    call_details: dict[str, Any] | None,
    thread_id: int | None,
    message_id: int | None,
    audit_label: str,
    timeout_seconds: float = AUDIT_TIMEOUT_SECONDS,
) -> None:
    """Write a TokenUsage audit row; never raise on timeout or ``Exception``.

    The row is written via ``record_token_usage`` inside a session obtained
    from ``session_factory`` and then committed. The whole write is bounded
    by ``timeout_seconds``; a timeout is logged as a warning and any other
    ``Exception`` is logged with its traceback, and in both cases the
    function returns normally so the caller can proceed.

    For usage types listed in ``BACKGROUND_ARTIFACT_USAGE_TYPES`` the row is
    stored with ``thread_id=None`` regardless of the ``thread_id`` argument.
    """
    if usage_type in BACKGROUND_ARTIFACT_USAGE_TYPES:
        row_thread_id = None
    else:
        row_thread_id = thread_id

    # Keyword arguments forwarded unchanged to ``record_token_usage``.
    usage_fields = {
        "usage_type": usage_type,
        "search_space_id": search_space_id,
        "user_id": user_id,
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": total_tokens,
        "cost_micros": cost_micros,
        "model_breakdown": model_breakdown,
        "call_details": call_details,
        "thread_id": row_thread_id,
        "message_id": message_id,
    }

    async def _write_audit_row() -> None:
        logger.info(
            "[billable_call] audit start label=%s usage_type=%s user=%s thread=%s "
            "total_tokens=%d cost_micros=%d",
            audit_label,
            usage_type,
            user_id,
            row_thread_id,
            total_tokens,
            cost_micros,
        )
        async with session_factory() as session:
            try:
                await record_token_usage(session, **usage_fields)
                logger.info(
                    "[billable_call] audit row staged label=%s usage_type=%s user=%s thread=%s",
                    audit_label,
                    usage_type,
                    user_id,
                    row_thread_id,
                )
                await session.commit()
                logger.info(
                    "[billable_call] audit commit OK label=%s usage_type=%s user=%s thread=%s",
                    audit_label,
                    usage_type,
                    user_id,
                    row_thread_id,
                )
            except BaseException:
                # BaseException so the rollback is also attempted when the
                # task is cancelled (e.g. by the timeout below); the
                # original error is always re-raised.
                await _rollback_safely(session)
                raise

    try:
        await asyncio.wait_for(_write_audit_row(), timeout=timeout_seconds)
    except TimeoutError:
        logger.warning(
            "[billable_call] audit timed out label=%s usage_type=%s user=%s thread=%s "
            "timeout=%.1fs total_tokens=%d cost_micros=%d",
            audit_label,
            usage_type,
            user_id,
            row_thread_id,
            timeout_seconds,
            total_tokens,
            cost_micros,
        )
    except Exception:
        logger.exception(
            "[billable_call] audit failed label=%s usage_type=%s user=%s thread=%s "
            "total_tokens=%d cost_micros=%d",
            audit_label,
            usage_type,
            user_id,
            row_thread_id,
            total_tokens,
            cost_micros,
        )
@asynccontextmanager
async def billable_call(
*,
@ -101,6 +228,8 @@ async def billable_call(
thread_id: int | None = None,
message_id: int | None = None,
call_details: dict[str, Any] | None = None,
billable_session_factory: BillableSessionFactory | None = None,
audit_timeout_seconds: float = AUDIT_TIMEOUT_SECONDS,
) -> AsyncIterator[TurnTokenAccumulator]:
"""Wrap a single billable LLM/image call.
@ -124,6 +253,13 @@ async def billable_call(
thread_id, message_id: Optional FK columns on ``TokenUsage``.
call_details: Optional per-call metadata (model name, parameters)
forwarded to ``record_token_usage``.
billable_session_factory: Optional async context factory used for
reserve/finalize/release/audit sessions. Defaults to
``shielded_async_session`` for route callers; Celery callers pass
a worker-loop-safe session factory.
audit_timeout_seconds: Upper bound for TokenUsage audit persistence.
Audit failure is best-effort and does not undo successful
settlement.
Yields:
The ``TurnTokenAccumulator`` scoped to this call. The caller invokes
@ -134,6 +270,7 @@ async def billable_call(
QuotaInsufficientError: when premium and ``premium_reserve`` denies.
"""
is_premium = billing_tier == "premium"
session_factory = billable_session_factory or shielded_async_session
async with scoped_turn() as acc:
# ---------- Free path: just audit -------------------------------
@ -143,30 +280,22 @@ async def billable_call(
finally:
# Always audit, even on exception, so we capture cost when
# provider returns successfully but the caller raises later.
try:
async with shielded_async_session() as audit_session:
await record_token_usage(
audit_session,
usage_type=usage_type,
search_space_id=search_space_id,
user_id=user_id,
prompt_tokens=acc.total_prompt_tokens,
completion_tokens=acc.total_completion_tokens,
total_tokens=acc.grand_total,
cost_micros=acc.total_cost_micros,
model_breakdown=acc.per_message_summary(),
call_details=call_details,
thread_id=thread_id,
message_id=message_id,
)
await audit_session.commit()
except Exception:
logger.exception(
"[billable_call] free-path audit insert failed for "
"usage_type=%s user_id=%s",
usage_type,
user_id,
)
await _record_audit_best_effort(
session_factory=session_factory,
usage_type=usage_type,
search_space_id=search_space_id,
user_id=user_id,
prompt_tokens=acc.total_prompt_tokens,
completion_tokens=acc.total_completion_tokens,
total_tokens=acc.grand_total,
cost_micros=acc.total_cost_micros,
model_breakdown=acc.per_message_summary(),
call_details=call_details,
thread_id=thread_id,
message_id=message_id,
audit_label="free",
timeout_seconds=audit_timeout_seconds,
)
return
# ---------- Premium path: reserve → execute → finalize ----------
@ -180,7 +309,7 @@ async def billable_call(
request_id = str(uuid4())
async with shielded_async_session() as quota_session:
async with session_factory() as quota_session:
reserve_result = await TokenQuotaService.premium_reserve(
db_session=quota_session,
user_id=user_id,
@ -222,7 +351,7 @@ async def billable_call(
# from a downstream call, asyncio cancellation, etc.). We use
# BaseException so cancellation also releases.
try:
async with shielded_async_session() as quota_session:
async with session_factory() as quota_session:
await TokenQuotaService.premium_release(
db_session=quota_session,
user_id=user_id,
@ -241,7 +370,16 @@ async def billable_call(
# ---------- Success: finalize + audit ----------------------------
actual_micros = acc.total_cost_micros
try:
async with shielded_async_session() as quota_session:
logger.info(
"[billable_call] finalize start user=%s usage_type=%s actual=%d "
"reserved=%d thread=%s",
user_id,
usage_type,
actual_micros,
reserve_micros,
thread_id,
)
async with session_factory() as quota_session:
final_result = await TokenQuotaService.premium_finalize(
db_session=quota_session,
user_id=user_id,
@ -260,7 +398,7 @@ async def billable_call(
final_result.limit,
final_result.remaining,
)
except Exception:
except Exception as finalize_exc:
# Last-ditch: if finalize itself fails, we must at least release
# so the reservation doesn't leak.
logger.exception(
@ -269,7 +407,7 @@ async def billable_call(
user_id,
)
try:
async with shielded_async_session() as quota_session:
async with session_factory() as quota_session:
await TokenQuotaService.premium_release(
db_session=quota_session,
user_id=user_id,
@ -281,31 +419,28 @@ async def billable_call(
"for user=%s",
user_id,
)
raise BillingSettlementError(
usage_type=usage_type,
user_id=user_id,
cause=finalize_exc,
) from finalize_exc
try:
async with shielded_async_session() as audit_session:
await record_token_usage(
audit_session,
usage_type=usage_type,
search_space_id=search_space_id,
user_id=user_id,
prompt_tokens=acc.total_prompt_tokens,
completion_tokens=acc.total_completion_tokens,
total_tokens=acc.grand_total,
cost_micros=actual_micros,
model_breakdown=acc.per_message_summary(),
call_details=call_details,
thread_id=thread_id,
message_id=message_id,
)
await audit_session.commit()
except Exception:
logger.exception(
"[billable_call] premium-path audit insert failed for "
"usage_type=%s user_id=%s (debit was applied)",
usage_type,
user_id,
)
await _record_audit_best_effort(
session_factory=session_factory,
usage_type=usage_type,
search_space_id=search_space_id,
user_id=user_id,
prompt_tokens=acc.total_prompt_tokens,
completion_tokens=acc.total_completion_tokens,
total_tokens=acc.grand_total,
cost_micros=actual_micros,
model_breakdown=acc.per_message_summary(),
call_details=call_details,
thread_id=thread_id,
message_id=message_id,
audit_label="premium",
timeout_seconds=audit_timeout_seconds,
)
async def _resolve_agent_billing_for_search_space(
@ -419,6 +554,7 @@ async def _resolve_agent_billing_for_search_space(
__all__ = [
"BillingSettlementError",
"QuotaInsufficientError",
"_resolve_agent_billing_for_search_space",
"billable_call",

View file

@ -20,6 +20,8 @@ from typing import Any
from litellm import Router
from litellm.utils import ImageResponse
from app.services.provider_api_base import resolve_api_base
logger = logging.getLogger(__name__)
# Special ID for Auto mode - uses router for load balancing
@ -152,12 +154,12 @@ class ImageGenRouterService:
return None
# Build model string
provider = config.get("provider", "").upper()
if config.get("custom_provider"):
model_string = f"{config['custom_provider']}/{config['model_name']}"
provider_prefix = config["custom_provider"]
else:
provider = config.get("provider", "").upper()
provider_prefix = IMAGE_GEN_PROVIDER_MAP.get(provider, provider.lower())
model_string = f"{provider_prefix}/{config['model_name']}"
model_string = f"{provider_prefix}/{config['model_name']}"
# Build litellm params
litellm_params: dict[str, Any] = {
@ -165,9 +167,16 @@ class ImageGenRouterService:
"api_key": config.get("api_key"),
}
# Add optional api_base
if config.get("api_base"):
litellm_params["api_base"] = config["api_base"]
# Resolve ``api_base`` so deployments don't silently inherit
# ``AZURE_OPENAI_ENDPOINT`` / ``OPENAI_API_BASE`` and 404 against
# the wrong provider (see ``provider_api_base`` docstring).
api_base = resolve_api_base(
provider=provider,
provider_prefix=provider_prefix,
config_api_base=config.get("api_base"),
)
if api_base:
litellm_params["api_base"] = api_base
# Add api_version (required for Azure)
if config.get("api_version"):

View file

@ -140,8 +140,6 @@ PROVIDER_MAP = {
# 404-ing against an inherited Azure endpoint). Re-exported here for
# backward compatibility with any external import.
from app.services.provider_api_base import ( # noqa: E402
PROVIDER_DEFAULT_API_BASE,
PROVIDER_KEY_DEFAULT_API_BASE,
resolve_api_base,
)

View file

@ -16,6 +16,7 @@ from app.services.llm_router_service import (
get_auto_mode_llm,
is_auto_mode,
)
from app.services.provider_api_base import resolve_api_base
from app.services.token_tracking_service import token_tracker
# Configure litellm to automatically drop unsupported parameters
@ -556,22 +557,26 @@ async def get_vision_llm(
return None
if global_cfg.get("custom_provider"):
model_string = (
f"{global_cfg['custom_provider']}/{global_cfg['model_name']}"
)
provider_prefix = global_cfg["custom_provider"]
model_string = f"{provider_prefix}/{global_cfg['model_name']}"
else:
prefix = VISION_PROVIDER_MAP.get(
provider_prefix = VISION_PROVIDER_MAP.get(
global_cfg["provider"].upper(),
global_cfg["provider"].lower(),
)
model_string = f"{prefix}/{global_cfg['model_name']}"
model_string = f"{provider_prefix}/{global_cfg['model_name']}"
litellm_kwargs = {
"model": model_string,
"api_key": global_cfg["api_key"],
}
if global_cfg.get("api_base"):
litellm_kwargs["api_base"] = global_cfg["api_base"]
api_base = resolve_api_base(
provider=global_cfg.get("provider"),
provider_prefix=provider_prefix,
config_api_base=global_cfg.get("api_base"),
)
if api_base:
litellm_kwargs["api_base"] = api_base
if global_cfg.get("litellm_params"):
litellm_kwargs.update(global_cfg["litellm_params"])
@ -606,20 +611,26 @@ async def get_vision_llm(
return None
if vision_cfg.custom_provider:
model_string = f"{vision_cfg.custom_provider}/{vision_cfg.model_name}"
provider_prefix = vision_cfg.custom_provider
model_string = f"{provider_prefix}/{vision_cfg.model_name}"
else:
prefix = VISION_PROVIDER_MAP.get(
provider_prefix = VISION_PROVIDER_MAP.get(
vision_cfg.provider.value.upper(),
vision_cfg.provider.value.lower(),
)
model_string = f"{prefix}/{vision_cfg.model_name}"
model_string = f"{provider_prefix}/{vision_cfg.model_name}"
litellm_kwargs = {
"model": model_string,
"api_key": vision_cfg.api_key,
}
if vision_cfg.api_base:
litellm_kwargs["api_base"] = vision_cfg.api_base
api_base = resolve_api_base(
provider=vision_cfg.provider.value,
provider_prefix=provider_prefix,
config_api_base=vision_cfg.api_base,
)
if api_base:
litellm_kwargs["api_base"] = api_base
if vision_cfg.litellm_params:
litellm_kwargs.update(vision_cfg.litellm_params)

View file

@ -122,6 +122,24 @@ def _is_vision_input_model(model: dict) -> bool:
return "image" in input_mods and "text" in output_mods
def _supports_image_input(model: dict) -> bool:
"""Return True if the model accepts ``image`` in its input modalities.
Differs from :func:`_is_vision_input_model` in that it does NOT
require text output chat-tab models always emit text already (the
chat catalog filters by ``_is_text_output_model``), so the only
extra capability we need to track per chat config is whether the
model can ingest user-attached images. The chat selector and the
streaming task both key off this flag to prevent hitting an
OpenRouter 404 ``"No endpoints found that support image input"``
when the user uploads an image and selects a text-only model
(DeepSeek V3, Llama 3.x base, etc.).
"""
arch = model.get("architecture", {}) or {}
input_mods = arch.get("input_modalities", []) or []
return "image" in input_mods
def _supports_tool_calling(model: dict) -> bool:
"""Return True if the model supports function/tool calling."""
supported = model.get("supported_parameters") or []
@ -321,6 +339,13 @@ def _generate_configs(
# account-wide quota, so per-deployment routing can't spread load
# there — it just drains the shared bucket faster.
"router_pool_eligible": tier == "premium",
# Capability flag derived from ``architecture.input_modalities``.
# Read by the new-chat selector to dim image-incompatible models
# when the user has pending image attachments, and by
# ``stream_new_chat`` as a fail-fast safety net before the
# OpenRouter request would otherwise 404 with
# ``"No endpoints found that support image input"``.
"supports_image_input": _supports_image_input(model),
_OPENROUTER_DYNAMIC_MARKER: True,
# Auto (Fastest) ranking metadata. ``quality_score`` is initialised
# to the static score and gets re-blended with health on the next
@ -398,7 +423,12 @@ def _generate_image_gen_configs(
"provider": "OPENROUTER",
"model_name": model_id,
"api_key": api_key,
"api_base": "",
# Pin to OpenRouter's public base URL so a downstream call site
# that forgets ``resolve_api_base`` still doesn't inherit
# ``AZURE_OPENAI_ENDPOINT`` and 404 on
# ``image_generation/transformation`` (defense-in-depth, see
# ``provider_api_base`` docstring).
"api_base": "https://openrouter.ai/api/v1",
"api_version": None,
"rpm": free_rpm if tier == "free" else rpm,
"litellm_params": dict(litellm_params),
@ -477,7 +507,11 @@ def _generate_vision_llm_configs(
"provider": "OPENROUTER",
"model_name": model_id,
"api_key": api_key,
"api_base": "",
# Pin to OpenRouter's public base URL so a downstream call site
# that forgets ``resolve_api_base`` still doesn't inherit
# ``AZURE_OPENAI_ENDPOINT`` (defense-in-depth, see
# ``provider_api_base`` docstring).
"api_base": "https://openrouter.ai/api/v1",
"api_version": None,
"rpm": free_rpm if tier == "free" else rpm,
"tpm": free_tpm if tier == "free" else tpm,

View file

@ -17,7 +17,6 @@ source of truth without an inter-service circular import.
from __future__ import annotations
PROVIDER_DEFAULT_API_BASE: dict[str, str] = {
"openrouter": "https://openrouter.ai/api/v1",
"groq": "https://api.groq.com/openai/v1",

View file

@ -0,0 +1,280 @@
"""Capability resolution shared by chat / image / vision call sites.
Why this exists
---------------
The chat catalog (YAML + dynamic OpenRouter + BYOK DB rows + Auto) needs a
single, authoritative answer to one question: *can this chat config accept
``image_url`` content blocks?* Without it, the new-chat selector can't badge
incompatible models and the streaming task can't fail fast with a friendly
error before sending an image to a text-only provider.
Two functions, two intents:
- :func:`derive_supports_image_input` — best-effort *True* for catalog and
UI surfacing. Default-allow: an unknown / unmapped model is treated as
capable so we never lock the user out of a freshly added or
third-party-hosted vision model.
- :func:`is_known_text_only_chat_model` — strict opt-out for the streaming
task's safety net. Returns True only when LiteLLM's model map *explicitly*
sets ``supports_vision=False`` (or its bare-name variant does). Anything
else — missing key, lookup exception, ``supports_vision=True`` — returns
False so the request flows through to the provider.
Implementation rule: only public LiteLLM symbols
------------------------------------------------
``litellm.supports_vision`` and ``litellm.get_model_info`` are part of the
typed module surface (see ``litellm.__init__`` lazy stubs) and are stable
across releases. The private ``_is_explicitly_disabled_factory`` and
``_get_model_info_helper`` are intentionally avoided so a LiteLLM upgrade
can't silently break us.
Why the previous round's strict YAML opt-in flag failed
-------------------------------------------------------
``supports_image_input: false`` was the YAML loader's setdefault. Operators
maintaining ``global_llm_config.yaml`` never set it, so every Azure / OpenAI
YAML chat model — including vision-capable GPT-5.x and GPT-4o — resolved to
False and the streaming gate rejected every image turn. Sourcing capability
from LiteLLM's authoritative model map (which already says
``azure/gpt-5.4 -> supports_vision=true``) removes that operator toil.
"""
from __future__ import annotations
import logging
from collections.abc import Iterable
import litellm
logger = logging.getLogger(__name__)
# Provider-name → LiteLLM model-prefix map.
#
# Keys are upper-case provider names; ``_candidate_model_strings`` upper-cases
# the provider before looking it up and falls back to the lower-cased provider
# name when there is no entry.
#
# Owned here because ``app.services.provider_capabilities`` is the
# only edge that's safe to call from ``app.config``'s YAML loader at
# class-body init time. ``app.agents.new_chat.llm_config`` re-exports
# this constant under the historical ``PROVIDER_MAP`` name; placing the
# map there directly would re-introduce the
# ``app.config -> ... -> app.agents.new_chat.tools.generate_image ->
# app.config`` cycle that prompted the move.
_PROVIDER_PREFIX_MAP: dict[str, str] = {
"OPENAI": "openai",
"ANTHROPIC": "anthropic",
"GROQ": "groq",
"COHERE": "cohere",
"GOOGLE": "gemini",
"OLLAMA": "ollama_chat",
"MISTRAL": "mistral",
"AZURE_OPENAI": "azure",
"OPENROUTER": "openrouter",
"XAI": "xai",
"BEDROCK": "bedrock",
"VERTEX_AI": "vertex_ai",
"TOGETHER_AI": "together_ai",
"FIREWORKS_AI": "fireworks_ai",
# The next four (and MINIMAX below) share the ``openai`` prefix —
# presumably because they are reached through OpenAI-compatible
# endpoints; confirm before changing.
"DEEPSEEK": "openai",
"ALIBABA_QWEN": "openai",
"MOONSHOT": "openai",
"ZHIPU": "openai",
"GITHUB_MODELS": "github",
"REPLICATE": "replicate",
"PERPLEXITY": "perplexity",
"ANYSCALE": "anyscale",
"DEEPINFRA": "deepinfra",
"CEREBRAS": "cerebras",
"SAMBANOVA": "sambanova",
"AI21": "ai21",
"CLOUDFLARE": "cloudflare",
"DATABRICKS": "databricks",
"COMETAPI": "cometapi",
"HUGGINGFACE": "huggingface",
"MINIMAX": "openai",
"CUSTOM": "custom",
}
def _candidate_model_strings(
*,
provider: str | None,
model_name: str | None,
base_model: str | None,
custom_provider: str | None,
) -> list[tuple[str, str | None]]:
"""Return ``[(model_string, custom_llm_provider), ...]`` lookup candidates.
LiteLLM's capability lookup is keyed by ``model`` + (optional)
``custom_llm_provider``. Different config sources give us different
levels of detail, so we try the most-specific keys first and fall back
to bare model names so unannotated entries (e.g. an Azure deployment
pointing at ``gpt-5.4`` via ``litellm_params.base_model``) still hit the
map. Order matters — callers probe the candidates in the returned order.
Args:
provider: Provider name; mapped to a LiteLLM prefix through
``_PROVIDER_PREFIX_MAP`` (case-insensitive), falling back to the
lower-cased name when unmapped.
model_name: Configured model name.
base_model: Optional canonical model; preferred over ``model_name``
as the primary lookup key when set.
custom_provider: When set, replaces the prefix derived from
``provider``.
Returns:
De-duplicated ``(model_string, custom_llm_provider)`` pairs in probe
order. Empty when neither ``base_model`` nor ``model_name`` is set.
"""
candidates: list[tuple[str, str | None]] = []
seen: set[tuple[str, str | None]] = set()
# Append a candidate once; empty model strings are ignored.
def _add(model: str | None, llm_provider: str | None) -> None:
if not model:
return
key = (model, llm_provider)
if key in seen:
return
seen.add(key)
candidates.append(key)
provider_prefix: str | None = None
if provider:
provider_prefix = _PROVIDER_PREFIX_MAP.get(provider.upper(), provider.lower())
if custom_provider:
# ``custom_provider`` overrides everything for CUSTOM/proxy setups.
provider_prefix = custom_provider
primary_model = base_model or model_name
bare_model = model_name
# Most-specific first: provider-prefixed identifier with explicit
# custom_llm_provider so LiteLLM won't have to guess the provider via
# ``get_llm_provider``.
if primary_model and provider_prefix:
# e.g. "azure/gpt-5.4" + custom_llm_provider="azure"
# A model that already contains "/" is used as-is rather than being
# prefixed a second time.
if "/" in primary_model:
_add(primary_model, provider_prefix)
else:
_add(f"{provider_prefix}/{primary_model}", provider_prefix)
# Bare base_model (or model_name) with provider hint — handles entries
# the upstream map keys without a provider prefix (most ``gpt-*`` and
# ``claude-*`` entries do this).
if primary_model:
_add(primary_model, provider_prefix)
# Fallback to model_name when base_model differs (e.g. an Azure
# deployment whose model_name is the deployment id but base_model is the
# canonical OpenAI sku).
if bare_model and bare_model != primary_model:
if provider_prefix and "/" not in bare_model:
_add(f"{provider_prefix}/{bare_model}", provider_prefix)
_add(bare_model, provider_prefix)
_add(bare_model, None)
return candidates
def derive_supports_image_input(
    *,
    provider: str | None = None,
    model_name: str | None = None,
    base_model: str | None = None,
    custom_provider: str | None = None,
    openrouter_input_modalities: Iterable[str] | None = None,
) -> bool:
    """Best-effort capability flag for the new-chat selector and catalog.

    Resolution order:

    1. ``openrouter_input_modalities``, whenever it is not ``None``.
       OpenRouter exposes ``architecture.input_modalities`` per model and
       that is the authoritative source for OR dynamic configs. The result
       is True iff ``"image"`` is listed; an explicitly empty iterable
       therefore yields False.
    2. Otherwise ``True`` — the default-allow stance. An unknown /
       newly-added / third-party-hosted model is *not* pre-judged, because a
       False here would just hide a usable model from the user. The strict
       gate for blocking is :func:`is_known_text_only_chat_model`.

    LiteLLM is deliberately not consulted here: with default-allow, both a
    positive and a non-positive lookup end in the same ``True``, so probing
    the model map could never change the outcome.

    Args:
        provider, model_name, base_model, custom_provider: Accepted for
            interface compatibility with
            :func:`is_known_text_only_chat_model`; they do not influence the
            result.
        openrouter_input_modalities: Input modalities published by
            OpenRouter for this model, or ``None`` when not applicable.

    Returns:
        False only when OpenRouter modalities were supplied and do not
        contain ``"image"``; True in every other case.
    """
    if openrouter_input_modalities is None:
        # No authoritative modality list: default-allow.
        return True
    # Materialise once so one-shot iterables (generators) are handled too.
    return "image" in list(openrouter_input_modalities)
def is_known_text_only_chat_model(
    *,
    provider: str | None = None,
    model_name: str | None = None,
    base_model: str | None = None,
    custom_provider: str | None = None,
) -> bool:
    """Strict opt-out probe for the streaming-task safety net.

    Candidates from :func:`_candidate_model_strings` are probed in order
    (most specific first) with ``litellm.get_model_info``. The first
    candidate whose ``supports_vision`` is an explicit boolean decides:

    - explicit ``False`` → returns True (known text-only, block);
    - explicit ``True``  → returns False (known vision-capable, let through).

    Stopping at an explicit ``True`` matters: without it, a less specific
    candidate later in the list (for example a bare deployment name) that
    reports ``False`` would override the more specific positive answer and
    block a vision-capable model.

    Missing key, ``None``, or a lookup exception are not definitive; the
    next candidate is tried. When no candidate is definitive the function
    returns False so the request flows through to the provider. This is the
    inverse default of :func:`derive_supports_image_input`.

    Why two functions
    -----------------
    The selector wants "show me everything that's plausibly capable" —
    default-allow. The safety net wants "block only when I'm certain it
    can't" — default-pass. Mixing the two intents in a single function
    leads to the regression this split fixes.
    """
    for model_string, custom_llm_provider in _candidate_model_strings(
        provider=provider,
        model_name=model_name,
        base_model=base_model,
        custom_provider=custom_provider,
    ):
        try:
            info = litellm.get_model_info(
                model=model_string, custom_llm_provider=custom_llm_provider
            )
        except Exception as exc:
            logger.debug(
                "litellm.get_model_info raised for model=%s provider=%s: %s",
                model_string,
                custom_llm_provider,
                exc,
            )
            continue
        # ``ModelInfo`` is a TypedDict (dict at runtime). ``supports_vision``
        # may be missing, None, True, or False.
        try:
            value = info.get("supports_vision")  # type: ignore[union-attr]
        except AttributeError:
            value = None
        if value is True:
            # Most specific definitive answer says vision-capable: never block.
            return False
        if value is False:
            return True
    return False
__all__ = [
"derive_supports_image_input",
"is_known_text_only_chat_model",
]

View file

@ -1,10 +1,25 @@
"""Celery tasks package."""
"""Celery tasks package.
Also hosts the small helpers every async celery task should use to
spin up its event loop. See :func:`run_async_celery_task` for the
canonical pattern.
"""
from __future__ import annotations
import asyncio
import contextlib
import logging
from collections.abc import Awaitable, Callable
from typing import TypeVar
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.pool import NullPool
from app.config import config
logger = logging.getLogger(__name__)
_celery_engine = None
_celery_session_maker = None
@ -26,3 +41,86 @@ def get_celery_session_maker() -> async_sessionmaker:
_celery_engine, expire_on_commit=False
)
return _celery_session_maker
def _dispose_shared_db_engine(loop: asyncio.AbstractEventLoop) -> None:
"""Drop the shared ``app.db.engine`` connection pool synchronously.
The shared engine (used by ``shielded_async_session`` and most
routes / services) is a module-level singleton with a real pool.
Each celery task creates a fresh ``asyncio`` event loop; asyncpg
connections cache a reference to whichever loop opened them. When
a subsequent task's loop pulls a stale connection from the pool,
SQLAlchemy's ``pool_pre_ping`` checkout crashes with::
AttributeError: 'NoneType' object has no attribute 'send'
File ".../asyncio/proactor_events.py", line 402, in _loop_writing
self._write_fut = self._loop._proactor.send(self._sock, data)
or hangs forever inside the asyncpg ``Connection._cancel`` cleanup
coroutine that can never run because its loop is gone.
Disposing the engine forces the pool to drop every cached
connection so the next checkout opens a fresh one on the current
loop. Safe to call from a task's finally block; failure is logged
but never propagated.
"""
try:
from app.db import engine as shared_engine
loop.run_until_complete(shared_engine.dispose())
except Exception:
logger.warning("Shared DB engine dispose() failed", exc_info=True)
# NOTE(review): the ``[T]`` on ``run_async_celery_task`` below declares its
# own PEP 695 type parameter, so this module-level TypeVar looks unused by
# that function — confirm no other user before removing it.
T = TypeVar("T")
def run_async_celery_task[T](coro_factory: Callable[[], Awaitable[T]]) -> T:
"""Run an async coroutine inside a fresh event loop with proper
DB-engine cleanup.
This is the canonical entry point for every async celery task.
It performs three responsibilities that were previously copy-pasted
(incorrectly) across each task module:
1. Create a fresh ``asyncio`` loop and install it on the current
thread (celery's ``--pool=solo`` runs every task on the main
thread, but other pool types don't).
2. Dispose the shared ``app.db.engine`` BEFORE the task runs so
any stale connections left over from a previous task's loop
are dropped — defends against tasks that crashed without
cleaning up.
3. Dispose the shared engine AFTER the task runs so the
connections we opened on this loop are released before the
loop closes (avoids ``coroutine 'Connection._cancel' was
never awaited`` warnings and the next-task hang).
Args:
coro_factory: Zero-argument callable returning the awaitable to run.
It is invoked only after the new loop is installed and the
pre-run dispose has happened.
Returns:
Whatever the awaitable resolves to. An exception raised by the
awaitable propagates to the caller after the cleanup in ``finally``.
Use as::
@celery_app.task(name="my_task", bind=True)
def my_task(self, *args):
return run_async_celery_task(lambda: _my_task_impl(*args))
"""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
# Defense-in-depth: prior task may have crashed before
# disposing. Idempotent — no-op if pool is already empty.
_dispose_shared_db_engine(loop)
return loop.run_until_complete(coro_factory())
finally:
# Drop any connections this task opened so they don't leak
# into the next task's loop.
_dispose_shared_db_engine(loop)
# Finalize any async generators still open on this loop; errors here
# must not mask the task's own result or exception.
with contextlib.suppress(Exception):
loop.run_until_complete(loop.shutdown_asyncgens())
# Uninstall the loop from the thread before closing it.
with contextlib.suppress(Exception):
asyncio.set_event_loop(None)
loop.close()
__all__ = [
"get_celery_session_maker",
"run_async_celery_task",
]

View file

@ -4,7 +4,7 @@ import logging
import traceback
from app.celery_app import celery_app
from app.tasks.celery_tasks import get_celery_session_maker
from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
logger = logging.getLogger(__name__)
@ -49,22 +49,15 @@ def index_notion_pages_task(
end_date: str,
):
"""Celery task to index Notion pages."""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_notion_pages(
return run_async_celery_task(
lambda: _index_notion_pages(
connector_id, search_space_id, user_id, start_date, end_date
)
)
except Exception as e:
_handle_greenlet_error(e, "index_notion_pages", connector_id)
raise
finally:
loop.close()
async def _index_notion_pages(
@ -95,19 +88,11 @@ def index_github_repos_task(
end_date: str,
):
"""Celery task to index GitHub repositories."""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_github_repos(
connector_id, search_space_id, user_id, start_date, end_date
)
return run_async_celery_task(
lambda: _index_github_repos(
connector_id, search_space_id, user_id, start_date, end_date
)
finally:
loop.close()
)
async def _index_github_repos(
@ -138,19 +123,11 @@ def index_confluence_pages_task(
end_date: str,
):
"""Celery task to index Confluence pages."""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_confluence_pages(
connector_id, search_space_id, user_id, start_date, end_date
)
return run_async_celery_task(
lambda: _index_confluence_pages(
connector_id, search_space_id, user_id, start_date, end_date
)
finally:
loop.close()
)
async def _index_confluence_pages(
@ -181,22 +158,15 @@ def index_google_calendar_events_task(
end_date: str,
):
"""Celery task to index Google Calendar events."""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_google_calendar_events(
return run_async_celery_task(
lambda: _index_google_calendar_events(
connector_id, search_space_id, user_id, start_date, end_date
)
)
except Exception as e:
_handle_greenlet_error(e, "index_google_calendar_events", connector_id)
raise
finally:
loop.close()
async def _index_google_calendar_events(
@ -227,19 +197,11 @@ def index_google_gmail_messages_task(
end_date: str,
):
"""Celery task to index Google Gmail messages."""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_google_gmail_messages(
connector_id, search_space_id, user_id, start_date, end_date
)
return run_async_celery_task(
lambda: _index_google_gmail_messages(
connector_id, search_space_id, user_id, start_date, end_date
)
finally:
loop.close()
)
async def _index_google_gmail_messages(
@ -269,22 +231,14 @@ def index_google_drive_files_task(
items_dict: dict, # Dictionary with 'folders', 'files', and 'indexing_options'
):
"""Celery task to index Google Drive folders and files."""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_google_drive_files(
connector_id,
search_space_id,
user_id,
items_dict,
)
return run_async_celery_task(
lambda: _index_google_drive_files(
connector_id,
search_space_id,
user_id,
items_dict,
)
finally:
loop.close()
)
async def _index_google_drive_files(
@ -317,22 +271,14 @@ def index_onedrive_files_task(
items_dict: dict,
):
"""Celery task to index OneDrive folders and files."""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_onedrive_files(
connector_id,
search_space_id,
user_id,
items_dict,
)
return run_async_celery_task(
lambda: _index_onedrive_files(
connector_id,
search_space_id,
user_id,
items_dict,
)
finally:
loop.close()
)
async def _index_onedrive_files(
@ -365,22 +311,14 @@ def index_dropbox_files_task(
items_dict: dict,
):
"""Celery task to index Dropbox folders and files."""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_dropbox_files(
connector_id,
search_space_id,
user_id,
items_dict,
)
return run_async_celery_task(
lambda: _index_dropbox_files(
connector_id,
search_space_id,
user_id,
items_dict,
)
finally:
loop.close()
)
async def _index_dropbox_files(
@ -414,19 +352,11 @@ def index_elasticsearch_documents_task(
end_date: str,
):
"""Celery task to index Elasticsearch documents."""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_elasticsearch_documents(
connector_id, search_space_id, user_id, start_date, end_date
)
return run_async_celery_task(
lambda: _index_elasticsearch_documents(
connector_id, search_space_id, user_id, start_date, end_date
)
finally:
loop.close()
)
async def _index_elasticsearch_documents(
@ -457,22 +387,15 @@ def index_crawled_urls_task(
end_date: str,
):
"""Celery task to index Web page Urls."""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_crawled_urls(
return run_async_celery_task(
lambda: _index_crawled_urls(
connector_id, search_space_id, user_id, start_date, end_date
)
)
except Exception as e:
_handle_greenlet_error(e, "index_crawled_urls", connector_id)
raise
finally:
loop.close()
async def _index_crawled_urls(
@ -503,19 +426,11 @@ def index_bookstack_pages_task(
end_date: str,
):
"""Celery task to index BookStack pages."""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_bookstack_pages(
connector_id, search_space_id, user_id, start_date, end_date
)
return run_async_celery_task(
lambda: _index_bookstack_pages(
connector_id, search_space_id, user_id, start_date, end_date
)
finally:
loop.close()
)
async def _index_bookstack_pages(
@ -546,19 +461,11 @@ def index_composio_connector_task(
end_date: str | None,
):
"""Celery task to index Composio connector content (Google Drive, Gmail, Calendar via Composio)."""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_composio_connector(
connector_id, search_space_id, user_id, start_date, end_date
)
return run_async_celery_task(
lambda: _index_composio_connector(
connector_id, search_space_id, user_id, start_date, end_date
)
finally:
loop.close()
)
async def _index_composio_connector(

View file

@ -11,7 +11,7 @@ from app.db import Document
from app.indexing_pipeline.adapters.file_upload_adapter import UploadDocumentAdapter
from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
from app.tasks.celery_tasks import get_celery_session_maker
from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
logger = logging.getLogger(__name__)
@ -25,15 +25,7 @@ def reindex_document_task(self, document_id: int, user_id: str):
document_id: ID of document to reindex
user_id: ID of user who edited the document
"""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(_reindex_document(document_id, user_id))
finally:
loop.close()
return run_async_celery_task(lambda: _reindex_document(document_id, user_id))
async def _reindex_document(document_id: int, user_id: str):

View file

@ -11,7 +11,7 @@ from app.celery_app import celery_app
from app.config import config
from app.services.notification_service import NotificationService
from app.services.task_logging_service import TaskLoggingService
from app.tasks.celery_tasks import get_celery_session_maker
from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
from app.tasks.connector_indexers.local_folder_indexer import (
index_local_folder,
index_uploaded_files,
@ -105,12 +105,7 @@ async def _run_heartbeat_loop(notification_id: int):
)
def delete_document_task(self, document_id: int):
"""Celery task to delete a document and its chunks in batches."""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(_delete_document_background(document_id))
finally:
loop.close()
return run_async_celery_task(lambda: _delete_document_background(document_id))
async def _delete_document_background(document_id: int) -> None:
@ -153,14 +148,9 @@ def delete_folder_documents_task(
folder_subtree_ids: list[int] | None = None,
):
"""Celery task to delete documents first, then the folder rows."""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_delete_folder_documents(document_ids, folder_subtree_ids)
)
finally:
loop.close()
return run_async_celery_task(
lambda: _delete_folder_documents(document_ids, folder_subtree_ids)
)
async def _delete_folder_documents(
@ -209,12 +199,9 @@ async def _delete_folder_documents(
)
def delete_search_space_task(self, search_space_id: int):
"""Celery task to delete a search space and heavy child rows in batches."""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(_delete_search_space_background(search_space_id))
finally:
loop.close()
return run_async_celery_task(
lambda: _delete_search_space_background(search_space_id)
)
async def _delete_search_space_background(search_space_id: int) -> None:
@ -269,18 +256,11 @@ def process_extension_document_task(
search_space_id: ID of the search space
user_id: ID of the user
"""
# Create a new event loop for this task
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_process_extension_document(
individual_document_dict, search_space_id, user_id
)
return run_async_celery_task(
lambda: _process_extension_document(
individual_document_dict, search_space_id, user_id
)
finally:
loop.close()
)
async def _process_extension_document(
@ -419,13 +399,9 @@ def process_youtube_video_task(self, url: str, search_space_id: int, user_id: st
search_space_id: ID of the search space
user_id: ID of the user
"""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(_process_youtube_video(url, search_space_id, user_id))
finally:
loop.close()
return run_async_celery_task(
lambda: _process_youtube_video(url, search_space_id, user_id)
)
async def _process_youtube_video(url: str, search_space_id: int, user_id: str):
@ -573,12 +549,9 @@ def process_file_upload_task(
except Exception as e:
logger.warning(f"[process_file_upload] Could not get file size: {e}")
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_process_file_upload(file_path, filename, search_space_id, user_id)
run_async_celery_task(
lambda: _process_file_upload(file_path, filename, search_space_id, user_id)
)
logger.info(
f"[process_file_upload] Task completed successfully for: {filename}"
@ -589,8 +562,6 @@ def process_file_upload_task(
f"Traceback:\n{traceback.format_exc()}"
)
raise
finally:
loop.close()
async def _process_file_upload(
@ -811,25 +782,17 @@ def process_file_upload_with_document_task(
"File may have been removed before syncing could start."
)
# Mark document as failed since file is missing
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_mark_document_failed(
document_id,
"File not found. Please re-upload the file.",
)
run_async_celery_task(
lambda: _mark_document_failed(
document_id,
"File not found. Please re-upload the file.",
)
finally:
loop.close()
)
return
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_process_file_with_document(
run_async_celery_task(
lambda: _process_file_with_document(
document_id,
temp_path,
filename,
@ -849,8 +812,6 @@ def process_file_upload_with_document_task(
f"Traceback:\n{traceback.format_exc()}"
)
raise
finally:
loop.close()
async def _mark_document_failed(document_id: int, reason: str):
@ -1119,22 +1080,16 @@ def process_circleback_meeting_task(
search_space_id: ID of the search space
connector_id: ID of the Circleback connector (for deletion support)
"""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_process_circleback_meeting(
meeting_id,
meeting_name,
markdown_content,
metadata,
search_space_id,
connector_id,
)
return run_async_celery_task(
lambda: _process_circleback_meeting(
meeting_id,
meeting_name,
markdown_content,
metadata,
search_space_id,
connector_id,
)
finally:
loop.close()
)
async def _process_circleback_meeting(
@ -1291,25 +1246,19 @@ def index_local_folder_task(
target_file_paths: list[str] | None = None,
):
"""Celery task to index a local folder. Config is passed directly — no connector row."""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_local_folder_async(
search_space_id=search_space_id,
user_id=user_id,
folder_path=folder_path,
folder_name=folder_name,
exclude_patterns=exclude_patterns,
file_extensions=file_extensions,
root_folder_id=root_folder_id,
enable_summary=enable_summary,
target_file_paths=target_file_paths,
)
return run_async_celery_task(
lambda: _index_local_folder_async(
search_space_id=search_space_id,
user_id=user_id,
folder_path=folder_path,
folder_name=folder_name,
exclude_patterns=exclude_patterns,
file_extensions=file_extensions,
root_folder_id=root_folder_id,
enable_summary=enable_summary,
target_file_paths=target_file_paths,
)
finally:
loop.close()
)
async def _index_local_folder_async(
@ -1441,23 +1390,18 @@ def index_uploaded_folder_files_task(
processing_mode: str = "basic",
):
"""Celery task to index files uploaded from the desktop app."""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_uploaded_folder_files_async(
search_space_id=search_space_id,
user_id=user_id,
folder_name=folder_name,
root_folder_id=root_folder_id,
enable_summary=enable_summary,
file_mappings=file_mappings,
use_vision_llm=use_vision_llm,
processing_mode=processing_mode,
)
return run_async_celery_task(
lambda: _index_uploaded_folder_files_async(
search_space_id=search_space_id,
user_id=user_id,
folder_name=folder_name,
root_folder_id=root_folder_id,
enable_summary=enable_summary,
file_mappings=file_mappings,
use_vision_llm=use_vision_llm,
processing_mode=processing_mode,
)
finally:
loop.close()
)
async def _index_uploaded_folder_files_async(
@ -1584,12 +1528,9 @@ def _ai_sort_lock_key(search_space_id: int) -> str:
@celery_app.task(name="ai_sort_search_space", bind=True, max_retries=1)
def ai_sort_search_space_task(self, search_space_id: int, user_id: str):
"""Full AI sort for all documents in a search space."""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(_ai_sort_search_space_async(search_space_id, user_id))
finally:
loop.close()
return run_async_celery_task(
lambda: _ai_sort_search_space_async(search_space_id, user_id)
)
async def _ai_sort_search_space_async(search_space_id: int, user_id: str):
@ -1639,14 +1580,9 @@ async def _ai_sort_search_space_async(search_space_id: int, user_id: str):
)
def ai_sort_document_task(self, search_space_id: int, user_id: str, document_id: int):
"""Incremental AI sort for a single document after indexing."""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_ai_sort_document_async(search_space_id, user_id, document_id)
)
finally:
loop.close()
return run_async_celery_task(
lambda: _ai_sort_document_async(search_space_id, user_id, document_id)
)
async def _ai_sort_document_async(search_space_id: int, user_id: str, document_id: int):

View file

@ -2,14 +2,13 @@
from __future__ import annotations
import asyncio
import logging
from app.celery_app import celery_app
from app.db import SearchSourceConnector
from app.schemas.obsidian_plugin import NotePayload
from app.services.obsidian_plugin_indexer import upsert_note
from app.tasks.celery_tasks import get_celery_session_maker
from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
logger = logging.getLogger(__name__)
@ -22,18 +21,13 @@ def index_obsidian_attachment_task(
user_id: str,
) -> None:
"""Process one Obsidian non-markdown attachment asynchronously."""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_obsidian_attachment(
connector_id=connector_id,
payload_data=payload_data,
user_id=user_id,
)
return run_async_celery_task(
lambda: _index_obsidian_attachment(
connector_id=connector_id,
payload_data=payload_data,
user_id=user_id,
)
finally:
loop.close()
)
async def _index_obsidian_attachment(

View file

@ -3,6 +3,7 @@
import asyncio
import logging
import sys
from contextlib import asynccontextmanager
from sqlalchemy import select
@ -12,11 +13,12 @@ from app.celery_app import celery_app
from app.config import config as app_config
from app.db import Podcast, PodcastStatus
from app.services.billable_calls import (
BillingSettlementError,
QuotaInsufficientError,
_resolve_agent_billing_for_search_space,
billable_call,
)
from app.tasks.celery_tasks import get_celery_session_maker
from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
logger = logging.getLogger(__name__)
@ -34,6 +36,13 @@ if sys.platform.startswith("win"):
# =============================================================================
@asynccontextmanager
async def _celery_billable_session():
    """Async context manager yielding a session from the Celery session maker.

    Supplied to ``billable_call`` through its ``billable_session_factory``
    argument so that billing work done inside the Celery worker loop uses a
    session obtained from ``get_celery_session_maker()``. The session's own
    ``async with`` scope ends when the caller leaves this context.
    """
    session_maker = get_celery_session_maker()
    async with session_maker() as billing_session:
        yield billing_session
@celery_app.task(name="generate_content_podcast", bind=True)
def generate_content_podcast_task(
self,
@ -46,27 +55,22 @@ def generate_content_podcast_task(
Celery task to generate podcast from source content.
Updates existing podcast record created by the tool.
"""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
result = loop.run_until_complete(
_generate_content_podcast(
return run_async_celery_task(
lambda: _generate_content_podcast(
podcast_id,
source_content,
search_space_id,
user_prompt,
)
)
loop.run_until_complete(loop.shutdown_asyncgens())
return result
except Exception as e:
logger.error(f"Error generating content podcast: {e!s}")
loop.run_until_complete(_mark_podcast_failed(podcast_id))
try:
run_async_celery_task(lambda: _mark_podcast_failed(podcast_id))
except Exception:
logger.exception("Failed to mark podcast %s as failed", podcast_id)
return {"status": "failed", "podcast_id": podcast_id}
finally:
asyncio.set_event_loop(None)
loop.close()
async def _mark_podcast_failed(podcast_id: int) -> None:
@ -148,11 +152,12 @@ async def _generate_content_podcast(
base_model=base_model,
quota_reserve_micros_override=app_config.QUOTA_DEFAULT_PODCAST_RESERVE_MICROS,
usage_type="podcast_generation",
thread_id=podcast.thread_id,
call_details={
"podcast_id": podcast.id,
"title": podcast.title,
"thread_id": podcast.thread_id,
},
billable_session_factory=_celery_billable_session,
):
graph_result = await podcaster_graph.ainvoke(
initial_state, config=graph_config
@ -173,6 +178,18 @@ async def _generate_content_podcast(
"podcast_id": podcast.id,
"reason": "premium_quota_exhausted",
}
except BillingSettlementError:
logger.exception(
"Podcast %s: premium billing settlement failed",
podcast.id,
)
podcast.status = PodcastStatus.FAILED
await session.commit()
return {
"status": "failed",
"podcast_id": podcast.id,
"reason": "billing_settlement_failed",
}
podcast_transcript = graph_result.get("podcast_transcript", [])
file_path = graph_result.get("final_podcast_file_path", "")
@ -194,7 +211,14 @@ async def _generate_content_podcast(
podcast.podcast_transcript = serializable_transcript
podcast.file_location = file_path
podcast.status = PodcastStatus.READY
logger.info(
"Podcast %s: committing READY transcript_entries=%d file=%s",
podcast.id,
len(serializable_transcript),
file_path,
)
await session.commit()
logger.info("Podcast %s: READY commit complete", podcast.id)
logger.info(f"Successfully generated podcast: {podcast.id}")

View file

@ -7,7 +7,7 @@ from sqlalchemy.future import select
from app.celery_app import celery_app
from app.db import Notification, SearchSourceConnector, SearchSourceConnectorType
from app.tasks.celery_tasks import get_celery_session_maker
from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
from app.utils.indexing_locks import is_connector_indexing_locked
logger = logging.getLogger(__name__)
@ -20,15 +20,7 @@ def check_periodic_schedules_task():
This task runs every minute and triggers indexing for any connector
whose next_scheduled_at time has passed.
"""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(_check_and_trigger_schedules())
finally:
loop.close()
return run_async_celery_task(_check_and_trigger_schedules)
async def _check_and_trigger_schedules():

View file

@ -34,7 +34,7 @@ from sqlalchemy.future import select
from app.celery_app import celery_app
from app.config import config
from app.db import Document, DocumentStatus, Notification
from app.tasks.celery_tasks import get_celery_session_maker
from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
logger = logging.getLogger(__name__)
@ -69,16 +69,12 @@ def cleanup_stale_indexing_notifications_task():
Detection: Redis heartbeat key with 2-min TTL. Missing key = stale task.
Also marks associated pending/processing documents as failed.
"""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
async def _both() -> None:
await _cleanup_stale_notifications()
await _cleanup_stale_document_processing_notifications()
try:
loop.run_until_complete(_cleanup_stale_notifications())
loop.run_until_complete(_cleanup_stale_document_processing_notifications())
finally:
loop.close()
return run_async_celery_task(_both)
async def _cleanup_stale_notifications():

View file

@ -2,7 +2,6 @@
from __future__ import annotations
import asyncio
import logging
from datetime import UTC, datetime, timedelta
@ -18,7 +17,7 @@ from app.db import (
PremiumTokenPurchaseStatus,
)
from app.routes import stripe_routes
from app.tasks.celery_tasks import get_celery_session_maker
from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
logger = logging.getLogger(__name__)
@ -36,13 +35,7 @@ def get_stripe_client() -> StripeClient | None:
@celery_app.task(name="reconcile_pending_stripe_page_purchases")
def reconcile_pending_stripe_page_purchases_task():
"""Recover paid purchases that were left pending due to missed webhook handling."""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(_reconcile_pending_page_purchases())
finally:
loop.close()
return run_async_celery_task(_reconcile_pending_page_purchases)
async def _reconcile_pending_page_purchases() -> None:
@ -141,13 +134,7 @@ async def _reconcile_pending_page_purchases() -> None:
@celery_app.task(name="reconcile_pending_stripe_token_purchases")
def reconcile_pending_stripe_token_purchases_task():
"""Recover paid token purchases that were left pending due to missed webhook handling."""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(_reconcile_pending_token_purchases())
finally:
loop.close()
return run_async_celery_task(_reconcile_pending_token_purchases)
async def _reconcile_pending_token_purchases() -> None:

View file

@ -3,6 +3,7 @@
import asyncio
import logging
import sys
from contextlib import asynccontextmanager
from sqlalchemy import select
@ -12,11 +13,12 @@ from app.celery_app import celery_app
from app.config import config as app_config
from app.db import VideoPresentation, VideoPresentationStatus
from app.services.billable_calls import (
BillingSettlementError,
QuotaInsufficientError,
_resolve_agent_billing_for_search_space,
billable_call,
)
from app.tasks.celery_tasks import get_celery_session_maker
from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
logger = logging.getLogger(__name__)
@ -29,6 +31,13 @@ if sys.platform.startswith("win"):
)
@asynccontextmanager
async def _celery_billable_session():
    """Async context manager yielding a session from the Celery session maker.

    Supplied to ``billable_call`` through its ``billable_session_factory``
    argument so that billing work done inside the Celery worker loop uses a
    session obtained from ``get_celery_session_maker()``. The session's own
    ``async with`` scope ends when the caller leaves this context.
    """
    session_maker = get_celery_session_maker()
    async with session_maker() as billing_session:
        yield billing_session
@celery_app.task(name="generate_video_presentation", bind=True)
def generate_video_presentation_task(
self,
@ -41,27 +50,30 @@ def generate_video_presentation_task(
Celery task to generate video presentation from source content.
Updates existing video presentation record created by the tool.
"""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
result = loop.run_until_complete(
_generate_video_presentation(
return run_async_celery_task(
lambda: _generate_video_presentation(
video_presentation_id,
source_content,
search_space_id,
user_prompt,
)
)
loop.run_until_complete(loop.shutdown_asyncgens())
return result
except Exception as e:
logger.error(f"Error generating video presentation: {e!s}")
loop.run_until_complete(_mark_video_presentation_failed(video_presentation_id))
# Mark FAILED in a fresh loop — the previous loop is closed.
# Swallow secondary failures; the row will simply stay in
# GENERATING and be flushed by the periodic stale cleanup.
try:
run_async_celery_task(
lambda: _mark_video_presentation_failed(video_presentation_id)
)
except Exception:
logger.exception(
"Failed to mark video presentation %s as failed",
video_presentation_id,
)
return {"status": "failed", "video_presentation_id": video_presentation_id}
finally:
asyncio.set_event_loop(None)
loop.close()
async def _mark_video_presentation_failed(video_presentation_id: int) -> None:
@ -150,11 +162,12 @@ async def _generate_video_presentation(
base_model=base_model,
quota_reserve_micros_override=app_config.QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS,
usage_type="video_presentation_generation",
thread_id=video_pres.thread_id,
call_details={
"video_presentation_id": video_pres.id,
"title": video_pres.title,
"thread_id": video_pres.thread_id,
},
billable_session_factory=_celery_billable_session,
):
graph_result = await video_presentation_graph.ainvoke(
initial_state, config=graph_config
@ -175,6 +188,18 @@ async def _generate_video_presentation(
"video_presentation_id": video_pres.id,
"reason": "premium_quota_exhausted",
}
except BillingSettlementError:
logger.exception(
"VideoPresentation %s: premium billing settlement failed",
video_pres.id,
)
video_pres.status = VideoPresentationStatus.FAILED
await session.commit()
return {
"status": "failed",
"video_presentation_id": video_pres.id,
"reason": "billing_settlement_failed",
}
# Serialize slides (parsed content + audio info merged)
slides_raw = graph_result.get("slides", [])
@ -205,7 +230,14 @@ async def _generate_video_presentation(
video_pres.slides = serializable_slides
video_pres.scene_codes = serializable_scene_codes
video_pres.status = VideoPresentationStatus.READY
logger.info(
"VideoPresentation %s: committing READY slides=%d scene_codes=%d",
video_pres.id,
len(serializable_slides),
len(serializable_scene_codes),
)
await session.commit()
logger.info("VideoPresentation %s: READY commit complete", video_pres.id)
logger.info(f"Successfully generated video presentation: {video_pres.id}")

View file

@ -1506,10 +1506,10 @@ async def _stream_agent_events(
if isinstance(tool_output, dict)
else "Podcast"
)
if podcast_status == "processing":
if podcast_status in ("pending", "generating", "processing"):
completed_items = [
f"Title: {podcast_title}",
"Audio generation started",
"Podcast generation started",
"Processing in background...",
]
elif podcast_status == "already_generating":
@ -1518,7 +1518,7 @@ async def _stream_agent_events(
"Podcast already in progress",
"Please wait for it to complete",
]
elif podcast_status == "error":
elif podcast_status in ("failed", "error"):
error_msg = (
tool_output.get("error", "Unknown error")
if isinstance(tool_output, dict)
@ -1528,6 +1528,11 @@ async def _stream_agent_events(
f"Title: {podcast_title}",
f"Error: {error_msg[:50]}",
]
elif podcast_status in ("ready", "success"):
completed_items = [
f"Title: {podcast_title}",
"Podcast ready",
]
else:
completed_items = last_active_step_items
yield streaming_service.format_thinking_step(
@ -1710,20 +1715,28 @@ async def _stream_agent_events(
if isinstance(tool_output, dict)
else {"result": tool_output},
)
if (
isinstance(tool_output, dict)
and tool_output.get("status") == "success"
if isinstance(tool_output, dict) and tool_output.get("status") in (
"pending",
"generating",
"processing",
):
yield streaming_service.format_terminal_info(
f"Podcast queued: {tool_output.get('title', 'Podcast')}",
"success",
)
elif isinstance(tool_output, dict) and tool_output.get("status") in (
"ready",
"success",
):
yield streaming_service.format_terminal_info(
f"Podcast generated successfully: {tool_output.get('title', 'Podcast')}",
"success",
)
else:
error_msg = (
tool_output.get("error", "Unknown error")
if isinstance(tool_output, dict)
else "Unknown error"
)
elif isinstance(tool_output, dict) and tool_output.get("status") in (
"failed",
"error",
):
error_msg = tool_output.get("error", "Unknown error")
yield streaming_service.format_terminal_info(
f"Podcast generation failed: {error_msg}",
"error",
@ -2292,6 +2305,11 @@ async def stream_new_chat(
)
_t0 = time.perf_counter()
# Image-bearing turns force the Auto-pin resolver to filter the
# candidate pool to vision-capable cfgs (and force-repin a
# text-only existing pin). For explicit selections this flag is
# a no-op — the resolver returns the user's chosen id unchanged.
_requires_image_input = bool(user_image_data_urls)
try:
llm_config_id = (
await resolve_or_get_pinned_llm_config_id(
@ -2300,13 +2318,29 @@ async def stream_new_chat(
search_space_id=search_space_id,
user_id=user_id,
selected_llm_config_id=llm_config_id,
requires_image_input=_requires_image_input,
)
).resolved_llm_config_id
except ValueError as pin_error:
# Auto-pin's "no vision-capable cfg" path raises a ValueError
# whose message we map to the friendly image-input SSE error
# so the user sees the same message regardless of whether
# the gate fired in Auto-mode or in the agent_config check
# below.
error_code = (
"MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT"
if _requires_image_input and "vision-capable" in str(pin_error)
else "SERVER_ERROR"
)
error_kind = (
"user_error"
if error_code == "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT"
else "server_error"
)
yield _emit_stream_error(
message=str(pin_error),
error_kind="server_error",
error_code="SERVER_ERROR",
error_kind=error_kind,
error_code=error_code,
)
yield streaming_service.format_done()
return
@ -2326,6 +2360,50 @@ async def stream_new_chat(
llm_config_id,
)
# Capability safety net: a turn carrying user-uploaded images
# cannot be routed to a chat config that LiteLLM's authoritative
# model map *explicitly* marks as text-only (``supports_vision``
# set to False). The check is intentionally narrow — it only
# fires when LiteLLM is *certain* the model can't accept image
# input. Unknown / unmapped / vision-capable models pass
# through. Without this guard a known-text-only model would 404
# at the provider with ``"No endpoints found that support image
# input"``, surfacing as an opaque ``SERVER_ERROR`` SSE chunk;
# failing here lets us return a friendly message that tells the
# user what to change.
if user_image_data_urls and agent_config is not None:
from app.services.provider_capabilities import (
is_known_text_only_chat_model,
)
agent_litellm_params = agent_config.litellm_params or {}
agent_base_model = (
agent_litellm_params.get("base_model")
if isinstance(agent_litellm_params, dict)
else None
)
if is_known_text_only_chat_model(
provider=agent_config.provider,
model_name=agent_config.model_name,
base_model=agent_base_model,
custom_provider=agent_config.custom_provider,
):
model_label = (
agent_config.config_name or agent_config.model_name or "model"
)
yield _emit_stream_error(
message=(
f"The selected model ({model_label}) does not support "
"image input. Switch to a vision-capable model "
"(e.g. GPT-4o, Claude, Gemini) or remove the image "
"attachment and try again."
),
error_kind="user_error",
error_code="MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT",
)
yield streaming_service.format_done()
return
# Premium quota reservation for pinned premium model only.
_needs_premium_quota = (
agent_config is not None and user_id and agent_config.is_premium
@ -2366,6 +2444,7 @@ async def stream_new_chat(
user_id=user_id,
selected_llm_config_id=0,
force_repin_free=True,
requires_image_input=_requires_image_input,
)
).resolved_llm_config_id
except ValueError as pin_error:
@ -2470,6 +2549,7 @@ async def stream_new_chat(
user_id=user_id,
selected_llm_config_id=0,
exclude_config_ids={previous_config_id},
requires_image_input=_requires_image_input,
)
).resolved_llm_config_id
except ValueError as pin_error:
@ -2804,6 +2884,7 @@ async def stream_new_chat(
from litellm import acompletion
from app.services.llm_router_service import LLMRouterService
from app.services.provider_api_base import resolve_api_base
from app.services.token_tracking_service import _turn_accumulator
_turn_accumulator.set(None)
@ -2824,11 +2905,32 @@ async def stream_new_chat(
model="auto", messages=messages
)
else:
# Apply the same ``api_base`` cascade chat / vision /
# image-gen call sites use so we never inherit
# ``litellm.api_base`` (commonly set by
# ``AZURE_OPENAI_ENDPOINT``) when the chat config
# itself ships an empty ``api_base``. Without this
# the title-gen on an OpenRouter chat config would
# 404 against the inherited Azure endpoint — see
# ``provider_api_base`` docstring for the same
# bug repro on the image-gen / vision paths.
raw_model = getattr(llm, "model", "") or ""
provider_prefix = (
raw_model.split("/", 1)[0] if "/" in raw_model else None
)
provider_value = (
agent_config.provider if agent_config is not None else None
)
title_api_base = resolve_api_base(
provider=provider_value,
provider_prefix=provider_prefix,
config_api_base=getattr(llm, "api_base", None),
)
response = await acompletion(
model=llm.model,
model=raw_model,
messages=messages,
api_key=getattr(llm, "api_key", None),
api_base=getattr(llm, "api_base", None),
api_base=title_api_base,
)
usage_info = None
@ -2953,6 +3055,7 @@ async def stream_new_chat(
user_id=user_id,
selected_llm_config_id=0,
exclude_config_ids={previous_config_id},
requires_image_input=_requires_image_input,
)
).resolved_llm_config_id

View file

@ -0,0 +1,558 @@
"""End-to-end smoke test for vision / image config wiring.
Loads the live ``global_llm_config.yaml`` (no mocking, no fixtures) and
exercises every chat / vision / image-generation config + the OpenRouter
dynamic catalog. For each config the script:
1. Reports the resolver classification (catalog-allow vs strict-block).
2. Optionally fires a tiny live API call against the provider:
- Chat configs: ``litellm.acompletion`` with a 1x1 PNG and the prompt
``"reply with one word: ok"``.
- Vision configs: same, against the dedicated vision router pool.
- Image-gen configs: ``litellm.aimage_generation`` with a single tiny
prompt and ``n=1``.
- OpenRouter integration: samples one chat, one vision, one image-gen
model from the dynamically fetched catalog.
Usage::
python -m scripts.verify_chat_image_capability # capability + connectivity
python -m scripts.verify_chat_image_capability --no-live # capability resolver only
The script is meant to be runnable from the repository root or from
``surfsense_backend/`` and prints a short PASS/FAIL/SKIP summary at the
end so it's usable as a CI smoke check too.
Live-mode caveat: each successful call costs a small amount of provider
credit (a few tokens or one tiny generated image per config). The
default size for image generation is ``1024x1024`` because Azure
GPT-image deployments reject smaller sizes; OpenRouter image-gen models
generally accept the same size.
"""
from __future__ import annotations
import argparse
import asyncio
import logging
import os
import sys
import time
from dataclasses import dataclass, field
from typing import Any
# Bootstrap the surfsense_backend package on sys.path so the script runs
# from the repo root or from `surfsense_backend/` interchangeably.
_HERE = os.path.dirname(os.path.abspath(__file__))
_BACKEND_ROOT = os.path.dirname(_HERE)
if _BACKEND_ROOT not in sys.path:
sys.path.insert(0, _BACKEND_ROOT)
import litellm # noqa: E402
from app.config import config # noqa: E402
from app.services.openrouter_integration_service import ( # noqa: E402
_OPENROUTER_DYNAMIC_MARKER,
OpenRouterIntegrationService,
)
from app.services.provider_api_base import resolve_api_base # noqa: E402
from app.services.provider_capabilities import ( # noqa: E402
derive_supports_image_input,
is_known_text_only_chat_model,
)
logging.basicConfig(
level=logging.WARNING,
format="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s",
)
# Quiet down LiteLLM's verbose router/cost logs so the script output is
# scannable.
logging.getLogger("LiteLLM").setLevel(logging.ERROR)
logging.getLogger("litellm").setLevel(logging.ERROR)
logging.getLogger("httpx").setLevel(logging.ERROR)
# 1x1 transparent PNG — used as the cheapest possible vision payload.
_TINY_PNG_B64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
_TINY_PNG_DATA_URL = f"data:image/png;base64,{_TINY_PNG_B64}"
# ---------------------------------------------------------------------------
# Result accounting
# ---------------------------------------------------------------------------
@dataclass
class ProbeResult:
    """Outcome of probing one config on one surface.

    ``capability_ok`` and ``live_ok`` are tri-state: ``None`` means the
    check was not run (rendered as ``skip`` in the report), otherwise the
    boolean is the verdict of that check.
    """

    # Text shown in the last column of the report.
    label: str
    # Report section the probe belongs to (e.g. "chat-yaml", "vision").
    surface: str
    # ``id`` of the probed config, or a placeholder string.
    config_id: int | str
    capability_ok: bool | None = None
    capability_note: str = ""
    live_ok: bool | None = None
    live_note: str = ""
    # Seconds spent in the live call; stays 0.0 when no live call was made.
    duration_s: float = 0.0
@dataclass
class Report:
    """Accumulates probe results and prints the summary table."""

    results: list[ProbeResult] = field(default_factory=list)

    def add(self, r: ProbeResult) -> None:
        """Append one probe result to the report."""
        self.results.append(r)

    def render(self) -> int:
        """Print one row per probe plus totals and return the failure count.

        A probe counts as failed when either check is ``False``, as skipped
        when both checks are ``None``, and as passed otherwise.
        """

        def _flag(value: bool | None) -> str:
            if value is None:
                return "skip"
            return "ok" if value else "fail"

        heavy_rule = "=" * 92
        light_rule = "-" * 92
        passed = 0
        failed = 0
        skipped = 0

        print()
        print(heavy_rule)
        print(
            f"{'Surface':<14}{'ID':>8} {'Cap':>5} {'Live':>5} {'Time':>6} Label / notes"
        )
        print(light_rule)
        for probe in self.results:
            verdicts = (probe.capability_ok, probe.live_ok)
            if any(v is False for v in verdicts):
                failed += 1
            elif all(v is None for v in verdicts):
                skipped += 1
            else:
                passed += 1
            cap, live = (_flag(v) for v in verdicts)
            print(
                f"{probe.surface:<14}{probe.config_id!s:>8} {cap:>5} {live:>5} "
                f"{probe.duration_s:>5.2f}s {probe.label}"
            )
            if probe.capability_note:
                print(f" cap: {probe.capability_note}")
            if probe.live_note:
                print(f" live: {probe.live_note}")
        print(light_rule)
        print(
            f"Total: {passed} ok / {failed} fail / {skipped} skip "
            f"(of {len(self.results)} probes)"
        )
        print(heavy_rule)
        return failed
# ---------------------------------------------------------------------------
# Capability probes (no network)
# ---------------------------------------------------------------------------
def _probe_chat_capability(cfg: dict) -> tuple[bool, str]:
    """Classify a chat config with both image-input resolvers (no network).

    Runs ``derive_supports_image_input`` and
    ``is_known_text_only_chat_model`` on the same provider / model fields
    and reports both values in the note so drift between them is visible.

    Returns:
        ``(ok, note)`` where ``ok`` is truthy when the derive result is
        truthy or the strict gate does not block the model.
    """
    params = cfg.get("litellm_params") or {}
    resolver_args = {
        "provider": cfg.get("provider"),
        "model_name": cfg.get("model_name"),
        "base_model": params.get("base_model") if isinstance(params, dict) else None,
        "custom_provider": cfg.get("custom_provider"),
    }
    derived = derive_supports_image_input(**resolver_args)
    strict_block = is_known_text_only_chat_model(**resolver_args)

    summary = f"derive={derived} strict_block={strict_block}"
    if not (derived or strict_block):
        # The resolver said False while the strict gate stayed open: per the
        # original author's note this means OR modalities explicitly
        # published [text]. Surface it.
        summary += " (OR modality says text-only)"
    # A True derive, or a False derive with no strict block, both count as
    # 'capability ok' -- either way the streaming task will flow through.
    return derived or not strict_block, summary
def _build_chat_model_string(cfg: dict) -> str:
if cfg.get("custom_provider"):
return f"{cfg['custom_provider']}/{cfg['model_name']}"
from app.services.provider_capabilities import _PROVIDER_PREFIX_MAP
prefix = _PROVIDER_PREFIX_MAP.get(
(cfg.get("provider") or "").upper(), (cfg.get("provider") or "").lower()
)
return f"{prefix}/{cfg['model_name']}"
# ---------------------------------------------------------------------------
# Live probes (network calls)
# ---------------------------------------------------------------------------
async def _live_chat_image_call(cfg: dict) -> tuple[bool, str]:
    """Send a 1x1 PNG + `reply with one word: ok` to the chat config.

    Args:
        cfg: Chat or vision config dict. Reads ``provider``,
            ``custom_provider``, ``model_name``, ``api_key``, ``api_base``,
            ``api_version`` and ``litellm_params``.

    Returns:
        ``(True, note)`` with the first 40 characters of the reply when the
        call returns; ``(False, note)`` with the exception type and message
        when ``litellm.acompletion`` raises. Exceptions never propagate.
    """
    # Tracking-only pricing keys: they confuse some provider validators
    # (e.g. azure/openai reject unknown kwargs in strict mode).
    pricing_keys = {
        "input_cost_per_token",
        "output_cost_per_token",
        "input_cost_per_pixel",
        "output_cost_per_pixel",
    }
    model_string = _build_chat_model_string(cfg)
    api_base = resolve_api_base(
        provider=cfg.get("provider"),
        provider_prefix=model_string.split("/", 1)[0],
        config_api_base=cfg.get("api_base") or None,
    )
    kwargs: dict[str, Any] = {
        "model": model_string,
        "api_key": cfg.get("api_key"),
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "reply with one word: ok"},
                    {
                        "type": "image_url",
                        "image_url": {"url": _TINY_PNG_DATA_URL},
                    },
                ],
            }
        ],
        "max_tokens": 16,
        "timeout": 60,
    }
    if api_base:
        kwargs["api_base"] = api_base
    # Forward a top-level ``api_version`` exactly like the image-generation
    # probe does. Without it, a config that keeps the version outside
    # ``litellm_params`` is called without one.
    # NOTE(review): presumably relevant for Azure-style vision configs --
    # confirm against the YAML schema.
    if cfg.get("api_version"):
        kwargs["api_version"] = cfg["api_version"]
    if cfg.get("litellm_params"):
        # Merged last so explicit ``litellm_params`` win over the defaults
        # assembled above.
        kwargs.update(
            {
                k: v
                for k, v in dict(cfg["litellm_params"]).items()
                if k not in pricing_keys
            }
        )
    try:
        resp = await litellm.acompletion(**kwargs)
    except Exception as exc:
        return False, f"{type(exc).__name__}: {exc}"
    text = resp.choices[0].message.content if resp.choices else ""
    return True, f"got reply ({(text or '').strip()[:40]!r})"
# Prompts tried, in order, by ``_live_image_gen_call``.
#
# Gemini image models occasionally return zero-length ``data`` for the
# minimal "red dot on white" prompt (a provider-side safety / empty-output
# quirk reproducible against ``google/gemini-2.5-flash-image`` even when
# the request itself succeeds). Hence a more naturalistic first prompt,
# with a second, different prompt as the single retry before giving up.
_IMAGE_GEN_PROMPTS: tuple[str, ...] = (
    "A simple icon of a coffee cup, flat illustration",
    "A small green leaf on a white background",
)
async def _live_image_gen_call(cfg: dict) -> tuple[bool, str]:
    """Request one 1024x1024 image to check the deployment is reachable.

    Each prompt in ``_IMAGE_GEN_PROMPTS`` is tried in order. The first
    response carrying at least one image wins. An exception or an empty
    ``data`` list moves on to the next prompt.

    Returns:
        ``(True, note)`` on the first non-empty response, otherwise
        ``(False, note)`` where ``note`` describes the last attempt only.
    """
    from app.services.provider_capabilities import _PROVIDER_PREFIX_MAP

    provider_name = cfg.get("provider") or ""
    prefix = cfg.get("custom_provider") or _PROVIDER_PREFIX_MAP.get(
        provider_name.upper(), provider_name.lower()
    )
    api_base = resolve_api_base(
        provider=cfg.get("provider"),
        provider_prefix=prefix,
        config_api_base=cfg.get("api_base") or None,
    )

    call_kwargs: dict[str, Any] = {
        "model": f"{prefix}/{cfg['model_name']}",
        "api_key": cfg.get("api_key"),
        "n": 1,
        "size": "1024x1024",
        "timeout": 120,
    }
    if api_base:
        call_kwargs["api_base"] = api_base
    if cfg.get("api_version"):
        call_kwargs["api_version"] = cfg["api_version"]
    if cfg.get("litellm_params"):
        # Pricing keys are dropped before the merge; everything else in
        # ``litellm_params`` overrides the values assembled above.
        pricing_keys = {
            "input_cost_per_token",
            "output_cost_per_token",
            "input_cost_per_pixel",
            "output_cost_per_pixel",
        }
        for key, value in dict(cfg["litellm_params"]).items():
            if key not in pricing_keys:
                call_kwargs[key] = value

    note = ""
    attempt = 0
    for prompt in _IMAGE_GEN_PROMPTS:
        attempt += 1
        try:
            resp = await litellm.aimage_generation(prompt=prompt, **call_kwargs)
        except Exception as exc:
            note = f"{type(exc).__name__}: {exc}"
        else:
            data_count = len(getattr(resp, "data", None) or [])
            if data_count > 0:
                return True, (
                    f"received {data_count} image(s) on attempt {attempt} "
                    f"(prompt={prompt!r})"
                )
            note = (
                f"call ok but received 0 images on attempt {attempt} (prompt={prompt!r})"
            )
    return False, note
# ---------------------------------------------------------------------------
# Probe drivers
# ---------------------------------------------------------------------------
def _is_or_dynamic(cfg: dict) -> bool:
    """Return True when ``cfg`` carries a truthy ``_OPENROUTER_DYNAMIC_MARKER`` entry."""
    marker_value = cfg.get(_OPENROUTER_DYNAMIC_MARKER)
    return True if marker_value else False
async def probe_chat_configs(report: Report, *, live: bool) -> None:
    """Probe every non-OpenRouter-dynamic entry of ``config.GLOBAL_LLM_CONFIGS``.

    The capability resolver always runs. The live image call (and its
    timing) only runs when ``live`` is true. One ``ProbeResult`` per config
    is added to ``report``.
    """
    print("\n[chat configs from global_llm_configs (YAML-static)]")
    # OR dynamic entries are handled in the OR section so the YAML / OR
    # split stays clear in the report.
    yaml_cfgs = (c for c in config.GLOBAL_LLM_CONFIGS if not _is_or_dynamic(c))
    for cfg in yaml_cfgs:
        outcome = ProbeResult(
            label=str(cfg.get("name") or cfg.get("model_name")),
            surface="chat-yaml",
            config_id=cfg.get("id"),
        )
        outcome.capability_ok, outcome.capability_note = _probe_chat_capability(cfg)
        if live:
            started = time.perf_counter()
            outcome.live_ok, outcome.live_note = await _live_chat_image_call(cfg)
            outcome.duration_s = time.perf_counter() - started
        report.add(outcome)
async def probe_vision_configs(report: Report, *, live: bool) -> None:
    """Probe every non-OpenRouter-dynamic entry of ``config.GLOBAL_VISION_LLM_CONFIGS``.

    Uses the same capability resolver and live image call as the chat
    probe, and records one ``ProbeResult`` per config under the ``vision``
    surface.
    """
    print("\n[vision configs from global_vision_llm_configs (YAML-static)]")
    yaml_cfgs = (c for c in config.GLOBAL_VISION_LLM_CONFIGS if not _is_or_dynamic(c))
    for cfg in yaml_cfgs:
        outcome = ProbeResult(
            label=str(cfg.get("name") or cfg.get("model_name")),
            surface="vision",
            config_id=cfg.get("id"),
        )
        # Capability is implied for the dedicated vision pool; the resolver
        # still runs so any surprise disagreement shows up in the report.
        outcome.capability_ok, outcome.capability_note = _probe_chat_capability(cfg)
        if live:
            started = time.perf_counter()
            outcome.live_ok, outcome.live_note = await _live_chat_image_call(cfg)
            outcome.duration_s = time.perf_counter() - started
        report.add(outcome)
async def probe_image_gen_configs(report: Report, *, live: bool) -> None:
    """Probe every non-OpenRouter-dynamic entry of ``config.GLOBAL_IMAGE_GEN_CONFIGS``.

    No capability check is run for this surface, so ``capability_ok`` stays
    ``None``. With ``live`` false the probe is reported as skipped.
    """
    print(
        "\n[image generation configs from global_image_generation_configs (YAML-static)]"
    )
    yaml_cfgs = (c for c in config.GLOBAL_IMAGE_GEN_CONFIGS if not _is_or_dynamic(c))
    for cfg in yaml_cfgs:
        # Image-gen configs have no "supports_image_input" flag (the catalog
        # tracks output, not input), hence no capability fields are set.
        outcome = ProbeResult(
            label=str(cfg.get("name") or cfg.get("model_name")),
            surface="image-gen",
            config_id=cfg.get("id"),
        )
        if live:
            started = time.perf_counter()
            outcome.live_ok, outcome.live_note = await _live_image_gen_call(cfg)
            outcome.duration_s = time.perf_counter() - started
        report.add(outcome)
async def probe_openrouter_catalog(report: Report, *, live: bool) -> None:
    """Probe a few representative OpenRouter models per surface.

    Does not iterate the full pool (would be hundreds of probes). Instead
    it picks, per surface, the first config whose ``model_name`` starts
    with each of a short list of provider-family prefixes: up to three
    chat, three vision and two image-gen models.

    When ``config.OPENROUTER_INTEGRATION_SETTINGS`` is falsy, a single
    skipped result is recorded and nothing else happens. A prefix with no
    matching config yields a skipped result whose label names both the
    surface and the prefix, so missing provider families can be told apart.
    """
    print("\n[OpenRouter integration: sampled probes]")
    settings = config.OPENROUTER_INTEGRATION_SETTINGS
    if not settings:
        report.add(
            ProbeResult(
                label="OpenRouter integration",
                surface="openrouter",
                config_id="settings",
                capability_ok=None,
                capability_note="openrouter_integration disabled in YAML — skipping",
                live_ok=None,
            )
        )
        return
    service = OpenRouterIntegrationService.get_instance()
    or_chat = [
        c
        for c in config.GLOBAL_LLM_CONFIGS
        if c.get("provider") == "OPENROUTER" and c.get("supports_image_input")
    ]
    or_vision = [
        c for c in config.GLOBAL_VISION_LLM_CONFIGS if c.get("provider") == "OPENROUTER"
    ]
    or_image_gen = [
        c for c in config.GLOBAL_IMAGE_GEN_CONFIGS if c.get("provider") == "OPENROUTER"
    ]

    # Pick one representative per provider family per surface so a single
    # broken vendor (e.g. Anthropic key revoked, Google quota exceeded)
    # surfaces independently of the others. Each needle matches the
    # OpenRouter ``model_name`` prefix; the first match wins.
    def _pick_first(pool: list[dict], needle: str) -> dict | None:
        for c in pool:
            if (c.get("model_name") or "").lower().startswith(needle):
                return c
        return None

    chat_like_needles = (
        "openai/gpt-4o",
        "anthropic/claude",
        "google/gemini-2.5-flash",
    )
    image_needles = (
        "google/gemini-2.5-flash-image",
        # OpenRouter publishes OpenAI image models as ``openai/gpt-5-image*``
        # / ``openai/gpt-5.4-image-2`` (no ``gpt-image`` literal). Match
        # the actual prefix.
        "openai/gpt-5-image",
    )
    # (surface, needle, matched config or None), in report order:
    # chat first, then vision, then image-gen.
    picks: list[tuple[str, str, dict | None]] = [
        (surface, needle, _pick_first(pool, needle))
        for surface, pool, needles in (
            ("or-chat", or_chat, chat_like_needles),
            ("or-vision", or_vision, chat_like_needles),
            ("or-image", or_image_gen, image_needles),
        )
        for needle in needles
    ]
    print(
        f" catalog: chat={len(or_chat)} vision={len(or_vision)} image_gen={len(or_image_gen)} "
        f"(service initialized={service.is_initialized() if hasattr(service, 'is_initialized') else 'n/a'})"
    )
    for surface, needle, picked in picks:
        if not picked:
            report.add(
                ProbeResult(
                    # Name the needle: several picks share one surface, and
                    # without it the skipped rows are indistinguishable.
                    label=f"<no candidate for {surface}: {needle}*>",
                    surface=surface,
                    config_id="-",
                    capability_ok=None,
                    capability_note="no candidate found in OR catalog",
                )
            )
            continue
        runner = (
            _live_image_gen_call if surface == "or-image" else _live_chat_image_call
        )
        result = ProbeResult(
            label=str(picked.get("model_name")),
            surface=surface,
            config_id=picked.get("id"),
        )
        # Image-gen picks have no image-input capability to resolve.
        if surface != "or-image":
            cap_ok, cap_note = _probe_chat_capability(picked)
            result.capability_ok = cap_ok
            result.capability_note = cap_note
        if live:
            t0 = time.perf_counter()
            ok, note = await runner(picked)
            result.live_ok = ok
            result.live_note = note
            result.duration_s = time.perf_counter() - t0
        report.add(result)
# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
async def main(args: argparse.Namespace) -> int:
    """Run the selected probe groups and print the report.

    Args:
        args: Parsed CLI flags; reads ``live`` and the four ``skip_*``
            switches.

    Returns:
        ``1`` when at least one probe failed, ``0`` otherwise.
    """
    print("Loaded global configs:")
    print(f" chat: {len(config.GLOBAL_LLM_CONFIGS)} entries")
    print(f" vision: {len(config.GLOBAL_VISION_LLM_CONFIGS)} entries")
    print(f" image-gen: {len(config.GLOBAL_IMAGE_GEN_CONFIGS)} entries")
    print(f" OR settings present: {bool(config.OPENROUTER_INTEGRATION_SETTINGS)}")

    # Initialize the OpenRouter integration so the catalog is populated
    # (per the original author this mirrors app startup and is idempotent).
    # A failure is reported as a warning and the probes still run.
    if config.OPENROUTER_INTEGRATION_SETTINGS:
        try:
            from app.config import initialize_openrouter_integration

            initialize_openrouter_integration()
        except Exception as exc:
            print(f" WARNING: OpenRouter integration init failed: {exc}")

    print(
        f"\nMode: {'LIVE (will hit providers)' if args.live else 'DRY (capability only)'}"
    )

    report = Report()
    stages = (
        (args.skip_chat, probe_chat_configs),
        (args.skip_vision, probe_vision_configs),
        (args.skip_image_gen, probe_image_gen_configs),
        (args.skip_openrouter, probe_openrouter_catalog),
    )
    for skipped, probe in stages:
        if not skipped:
            await probe(report, live=args.live)

    return 1 if report.render() else 0
def _parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--no-live",
dest="live",
action="store_false",
help="Skip live API calls — capability resolver only.",
)
parser.set_defaults(live=True)
parser.add_argument("--skip-chat", action="store_true")
parser.add_argument("--skip-vision", action="store_true")
parser.add_argument("--skip-image-gen", action="store_true")
parser.add_argument("--skip-openrouter", action="store_true")
return parser.parse_args()
if __name__ == "__main__":
    # The process exit status is main()'s return value (1 on any failure).
    args = _parse_args()
    exit_status = asyncio.run(main(args))
    sys.exit(exit_status)

View file

@ -0,0 +1,110 @@
"""Unit tests for ``supports_image_input`` derivation on BYOK chat config
endpoints (``GET /new-llm-configs`` list, ``GET /new-llm-configs/{id}``).
There is no DB column for ``supports_image_input`` on
``NewLLMConfig`` — the value is resolved at the API boundary by
``derive_supports_image_input`` so the new-chat selector / streaming
task can read the same field shape regardless of source (BYOK vs YAML
vs OpenRouter dynamic). Default-allow on unknown so we don't lock the
user out of their own model choice.
"""
from __future__ import annotations
from datetime import UTC, datetime
from types import SimpleNamespace
from uuid import uuid4
import pytest
from app.db import LiteLLMProvider
from app.routes import new_llm_config_routes
pytestmark = pytest.mark.unit
def _byok_row(
    *,
    id_: int,
    model_name: str,
    base_model: str | None = None,
    provider: LiteLLMProvider = LiteLLMProvider.OPENAI,
    custom_provider: str | None = None,
) -> object:
    """Build a stand-in for a BYOK config ORM row.

    A ``SimpleNamespace`` exposing the row's attributes is enough because
    ``model_validate`` reads them via ``from_attributes=True``. ``provider``
    is a real ``LiteLLMProvider`` member so Pydantic's enum validation
    accepts it, same as it would for an ORM row. ``litellm_params`` is
    ``{"base_model": ...}`` only when ``base_model`` is truthy, else
    ``None``.
    """
    litellm_params = {"base_model": base_model} if base_model else None
    attrs = {
        "id": id_,
        "name": f"BYOK-{id_}",
        "description": None,
        "provider": provider,
        "custom_provider": custom_provider,
        "model_name": model_name,
        "api_key": "sk-byok",
        "api_base": None,
        "litellm_params": litellm_params,
        "system_instructions": "",
        "use_default_system_instructions": True,
        "citations_enabled": True,
        "created_at": datetime.now(tz=UTC),
        "search_space_id": 42,
        "user_id": uuid4(),
    }
    return SimpleNamespace(**attrs)
def test_serialize_byok_known_vision_model_resolves_true():
    """``gpt-4o`` is expected to resolve to True, and that value must reach
    the serialized config together with the row's ``id`` and
    ``model_name``."""
    byok = _byok_row(id_=1, model_name="gpt-4o")

    result = new_llm_config_routes._serialize_byok_config(byok)

    assert result.supports_image_input is True
    assert result.id == 1
    assert result.model_name == "gpt-4o"
def test_serialize_byok_unknown_model_default_allows():
    """Unknown / unmapped model: default-allow.

    The streaming-task safety net is the actual block, and it only fires
    when LiteLLM *explicitly* says text-only -- so a brand-new BYOK model
    must not be pre-judged here.
    """
    byok = _byok_row(
        id_=2,
        model_name="brand-new-model-x9-unmapped",
        provider=LiteLLMProvider.CUSTOM,
        custom_provider="brand_new_proxy",
    )

    result = new_llm_config_routes._serialize_byok_config(byok)

    assert result.supports_image_input is True
def test_serialize_byok_uses_base_model_when_present():
    """Azure-style row: ``model_name`` is a deployment id, while
    ``litellm_params["base_model"]`` holds the canonical sku LiteLLM knows.

    The helper has to consult ``base_model`` first; otherwise an
    unrecognised deployment id would shadow the real capability.
    """
    byok = _byok_row(
        id_=3,
        model_name="my-azure-deployment-id-no-litellm-knows-this",
        base_model="gpt-4o",
        provider=LiteLLMProvider.AZURE_OPENAI,
    )

    result = new_llm_config_routes._serialize_byok_config(byok)

    assert result.supports_image_input is True
def test_serialize_byok_returns_pydantic_read_model():
    """The serializer must hand back a ``NewLLMConfigRead`` instance rather
    than the raw ORM row, so the schema additions are guaranteed to be part
    of the API surface. Guards against a regression where the augmentation
    step is removed in favour of ORM passthrough."""
    from app.schemas import NewLLMConfigRead

    byok = _byok_row(id_=4, model_name="gpt-4o")

    result = new_llm_config_routes._serialize_byok_config(byok)

    assert isinstance(result, NewLLMConfigRead)

View file

@ -0,0 +1,184 @@
"""Unit tests for ``is_premium`` derivation on the global image-gen and
vision-LLM list endpoints.
Chat globals (``GET /global-llm-configs``) already emit
``is_premium = (billing_tier == "premium")``. Image and vision did not,
which made the new-chat ``model-selector`` render the Free/Premium badge
on the Chat tab but skip it on the Image and Vision tabs (the selector
keys its badge logic off ``is_premium``). These tests pin parity:
* YAML free entry -> ``is_premium=False``
* YAML premium entry -> ``is_premium=True``
* OpenRouter dynamic premium entry -> ``is_premium=True``
* Auto stub (always emitted when at least one config is present) ->
``is_premium=False``
"""
from __future__ import annotations
import pytest
pytestmark = pytest.mark.unit
# Global image-generation configs patched into ``config`` by the tests
# below: one free entry, one premium entry, and one premium OpenRouter entry.
_IMAGE_FIXTURE: list[dict] = [
    # Free-tier entry.
    {
        "id": -1,
        "name": "DALL-E 3",
        "provider": "OPENAI",
        "model_name": "dall-e-3",
        "api_key": "sk-test",
        "billing_tier": "free",
    },
    # Premium-tier entry.
    {
        "id": -2,
        "name": "GPT-Image 1 (premium)",
        "provider": "OPENAI",
        "model_name": "gpt-image-1",
        "api_key": "sk-test",
        "billing_tier": "premium",
    },
    # Premium-tier entry served through OpenRouter.
    {
        "id": -20_001,
        "name": "google/gemini-2.5-flash-image (OpenRouter)",
        "provider": "OPENROUTER",
        "model_name": "google/gemini-2.5-flash-image",
        "api_key": "sk-or-test",
        "api_base": "https://openrouter.ai/api/v1",
        "billing_tier": "premium",
    },
]
# Global vision-LLM configs patched into ``config`` by the tests below:
# one free entry, one premium entry, and one premium OpenRouter entry.
_VISION_FIXTURE: list[dict] = [
    # Free-tier entry.
    {
        "id": -1,
        "name": "GPT-4o Vision",
        "provider": "OPENAI",
        "model_name": "gpt-4o",
        "api_key": "sk-test",
        "billing_tier": "free",
    },
    # Premium-tier entry.
    {
        "id": -2,
        "name": "Claude 3.5 Sonnet (premium)",
        "provider": "ANTHROPIC",
        "model_name": "claude-3-5-sonnet",
        "api_key": "sk-ant-test",
        "billing_tier": "premium",
    },
    # Premium-tier entry served through OpenRouter.
    {
        "id": -30_001,
        "name": "openai/gpt-4o (OpenRouter)",
        "provider": "OPENROUTER",
        "model_name": "openai/gpt-4o",
        "api_key": "sk-or-test",
        "api_base": "https://openrouter.ai/api/v1",
        "billing_tier": "premium",
    },
]
# =============================================================================
# Image generation
# =============================================================================
@pytest.mark.asyncio
async def test_global_image_gen_configs_emit_is_premium(monkeypatch):
    """Every emitted image-gen config carries a boolean ``is_premium`` that
    matches its ``billing_tier``; the Auto stub (id 0) is always free.
    """
    from app.config import config
    from app.routes import image_generation_routes

    monkeypatch.setattr(
        config, "GLOBAL_IMAGE_GEN_CONFIGS", _IMAGE_FIXTURE, raising=False
    )
    payload = await image_generation_routes.get_global_image_gen_configs(user=None)
    by_id = {c["id"]: c for c in payload}

    # The Auto stub is emitted whenever at least one global config exists
    # (Auto-mode billing-tier surfacing is a separate follow-up).
    assert 0 in by_id, "Auto stub should be emitted when at least one config exists"

    # id -> (is_premium, billing_tier): Auto stub, YAML free, YAML premium,
    # OpenRouter dynamic premium -- same field, same derivation.
    expected = {
        0: (False, "free"),
        -1: (False, "free"),
        -2: (True, "premium"),
        -20_001: (True, "premium"),
    }
    for cfg_id, (premium, tier) in expected.items():
        assert by_id[cfg_id]["is_premium"] is premium
        assert by_id[cfg_id]["billing_tier"] == tier

    # The field must be present, and a bool, on every dict (Auto included).
    for cfg in payload:
        assert "is_premium" in cfg, f"is_premium missing from {cfg.get('id')}"
        assert isinstance(cfg["is_premium"], bool)
@pytest.mark.asyncio
async def test_global_image_gen_configs_no_globals_no_auto_stub(monkeypatch):
    """With no global configs at all the endpoint returns an empty list
    (no Auto stub) -- Auto mode would have nothing to route to.
    """
    from app.config import config
    from app.routes import image_generation_routes

    monkeypatch.setattr(config, "GLOBAL_IMAGE_GEN_CONFIGS", [], raising=False)

    emitted = await image_generation_routes.get_global_image_gen_configs(user=None)

    assert emitted == []
# =============================================================================
# Vision LLM
# =============================================================================
@pytest.mark.asyncio
async def test_global_vision_llm_configs_emit_is_premium(monkeypatch):
    """Vision counterpart of the image-gen test: every emitted config has a
    boolean ``is_premium`` matching ``billing_tier``, and the Auto stub
    (id 0) is free."""
    from app.config import config
    from app.routes import vision_llm_routes

    monkeypatch.setattr(
        config, "GLOBAL_VISION_LLM_CONFIGS", _VISION_FIXTURE, raising=False
    )
    payload = await vision_llm_routes.get_global_vision_llm_configs(user=None)
    by_id = {c["id"]: c for c in payload}

    assert 0 in by_id, "Auto stub should be emitted when at least one config exists"

    # id -> (is_premium, billing_tier)
    expected = {
        0: (False, "free"),
        -1: (False, "free"),
        -2: (True, "premium"),
        -30_001: (True, "premium"),
    }
    for cfg_id, (premium, tier) in expected.items():
        assert by_id[cfg_id]["is_premium"] is premium
        assert by_id[cfg_id]["billing_tier"] == tier

    for cfg in payload:
        assert "is_premium" in cfg, f"is_premium missing from {cfg.get('id')}"
        assert isinstance(cfg["is_premium"], bool)
@pytest.mark.asyncio
async def test_global_vision_llm_configs_no_globals_no_auto_stub(monkeypatch):
    """An empty global vision pool yields an empty payload (no Auto stub)."""
    from app.config import config
    from app.routes import vision_llm_routes

    monkeypatch.setattr(config, "GLOBAL_VISION_LLM_CONFIGS", [], raising=False)

    emitted = await vision_llm_routes.get_global_vision_llm_configs(user=None)

    assert emitted == []

View file

@ -0,0 +1,106 @@
"""Unit tests for ``supports_image_input`` derivation on the chat global
config endpoint (``GET /global-new-llm-configs``).
Resolution order (matches ``new_llm_config_routes.get_global_new_llm_configs``):
1. Explicit ``supports_image_input`` on the cfg dict (set by the YAML
loader for operator overrides, or by the OpenRouter integration from
``architecture.input_modalities``) wins.
2. ``derive_supports_image_input`` helper — default-allow on unknown
models, only False when LiteLLM / OR modalities are definitive.
The flag is purely informational at the API boundary. The streaming
task safety net (``is_known_text_only_chat_model``) is the actual block,
and it requires LiteLLM to *explicitly* mark the model as text-only.
"""
from __future__ import annotations
import pytest
pytestmark = pytest.mark.unit
# Global chat configs patched into ``config`` by the test below. Two entries
# set ``supports_image_input`` explicitly (True / False) and two omit the
# key so the endpoint has to resolve it.
_FIXTURE: list[dict] = [
    # Explicit True.
    {
        "id": -1,
        "name": "GPT-4o (explicit true)",
        "description": "vision-capable, explicit YAML override",
        "provider": "OPENAI",
        "model_name": "gpt-4o",
        "api_key": "sk-test",
        "billing_tier": "free",
        "supports_image_input": True,
    },
    # Explicit False.
    {
        "id": -2,
        "name": "DeepSeek V3 (explicit false)",
        "description": "OpenRouter dynamic — modality-derived false",
        "provider": "OPENROUTER",
        "model_name": "deepseek/deepseek-v3.2-exp",
        "api_key": "sk-or-test",
        "api_base": "https://openrouter.ai/api/v1",
        "billing_tier": "free",
        "supports_image_input": False,
    },
    # Key omitted, known model name.
    {
        "id": -10_010,
        "name": "Unannotated GPT-4o",
        "description": "no flag set — resolver should derive True via LiteLLM",
        "provider": "OPENAI",
        "model_name": "gpt-4o",
        "api_key": "sk-test",
        "billing_tier": "free",
        # ``supports_image_input`` is deliberately left out of this entry.
    },
    # Key omitted, custom provider with an unmapped model name.
    {
        "id": -10_011,
        "name": "Unannotated unknown model",
        "description": "unmapped — default-allow True",
        "provider": "CUSTOM",
        "custom_provider": "brand_new_proxy",
        "model_name": "brand-new-model-x9",
        "api_key": "sk-test",
        "billing_tier": "free",
    },
]
@pytest.mark.asyncio
async def test_global_new_llm_configs_emit_supports_image_input(monkeypatch):
    """Every emitted chat config exposes a boolean ``supports_image_input``.

    Explicit values in the cfg dict are preserved; entries without the key
    are resolved by the helper, which default-allows (True).
    """
    from app.config import config
    from app.routes import new_llm_config_routes

    monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", _FIXTURE, raising=False)
    payload = await new_llm_config_routes.get_global_new_llm_configs(user=None)
    by_id = {c["id"]: c for c in payload}

    # Auto stub: optimistic True so the user can keep Auto selected with
    # vision-capable deployments somewhere in the pool.
    assert 0 in by_id, "Auto stub should be emitted when configs exist"
    assert by_id[0]["supports_image_input"] is True
    assert by_id[0]["is_auto_mode"] is True

    expected_flags = {
        # Explicit True is preserved.
        -1: True,
        # Explicit False is preserved (the exact failure mode the safety net
        # guards against -- DeepSeek V3 over OpenRouter would 404 with "No
        # endpoints found that support image input").
        -2: False,
        # Unannotated GPT-4o: the resolver consults LiteLLM, which says vision.
        -10_010: True,
        # Unknown / unmapped model: default-allow rather than pre-judge.
        -10_011: True,
    }
    for cfg_id, flag in expected_flags.items():
        assert by_id[cfg_id]["supports_image_input"] is flag

    for cfg in payload:
        assert "supports_image_input" in cfg, (
            f"supports_image_input missing from {cfg.get('id')}"
        )
        assert isinstance(cfg["supports_image_input"], bool)

View file

@ -0,0 +1,286 @@
"""Image-aware extension of the Auto-pin resolver.
When the current chat turn carries an ``image_url`` block, the pin
resolver must:
1. Filter the candidate pool to vision-capable cfgs so a freshly
selected pin can never be text-only.
2. Treat any existing pin whose capability is False as invalid (force
re-pin), even when it would otherwise be reused as the thread's
stable model.
3. Raise ``ValueError`` (mapped to the friendly
``MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT`` SSE error in the streaming
task) when no vision-capable cfg is available instead of silently
pinning text-only and 404-ing at the provider.
"""
from __future__ import annotations
from dataclasses import dataclass
from types import SimpleNamespace
import pytest
from app.services.auto_model_pin_service import (
clear_healthy,
clear_runtime_cooldown,
resolve_or_get_pinned_llm_config_id,
)
pytestmark = pytest.mark.unit
@pytest.fixture(autouse=True)
def _reset_caches():
    """Call ``clear_runtime_cooldown`` and ``clear_healthy`` before and
    after every test in this module."""

    def _wipe() -> None:
        clear_runtime_cooldown()
        clear_healthy()

    _wipe()
    yield
    _wipe()
@dataclass
class _FakeQuotaResult:
allowed: bool
class _FakeExecResult:
def __init__(self, thread):
self._thread = thread
def unique(self):
return self
def scalar_one_or_none(self):
return self._thread
class _FakeSession:
def __init__(self, thread):
self.thread = thread
self.commit_count = 0
async def execute(self, _stmt):
return _FakeExecResult(self.thread)
async def commit(self):
self.commit_count += 1
def _thread(*, pinned: int | None = None):
return SimpleNamespace(id=1, search_space_id=10, pinned_llm_config_id=pinned)
def _vision_cfg(id_: int, *, tier: str = "free", quality: int = 80) -> dict:
return {
"id": id_,
"provider": "OPENAI",
"model_name": f"vision-{id_}",
"api_key": "k",
"billing_tier": tier,
"supports_image_input": True,
"auto_pin_tier": "A",
"quality_score": quality,
}
def _text_only_cfg(id_: int, *, tier: str = "free", quality: int = 90) -> dict:
return {
"id": id_,
"provider": "OPENAI",
"model_name": f"text-{id_}",
"api_key": "k",
"billing_tier": tier,
# Higher quality than the vision cfgs — so a bug that ignores
# the image flag would surface as the resolver picking this one.
"supports_image_input": False,
"auto_pin_tier": "A",
"quality_score": quality,
}
async def _premium_allowed(*_args, **_kwargs):
    """Quota-lookup replacement: accepts any arguments and always reports
    that the request is allowed."""
    granted = _FakeQuotaResult(allowed=True)
    return granted
@pytest.mark.asyncio
async def test_image_turn_filters_out_text_only_candidates(monkeypatch):
    """On an image turn with no existing pin, the resolver must pick the
    vision cfg even though the text-only cfg has the higher quality score."""
    from app.config import config

    pool = [_text_only_cfg(-1), _vision_cfg(-2)]
    monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", pool)
    monkeypatch.setattr(
        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
        _premium_allowed,
    )
    fake_session = _FakeSession(_thread())

    resolution = await resolve_or_get_pinned_llm_config_id(
        fake_session,
        thread_id=1,
        search_space_id=10,
        user_id=None,
        selected_llm_config_id=0,
        requires_image_input=True,
    )

    assert resolution.resolved_llm_config_id == -2
    # The thread itself must end up pinned to the vision cfg.
    assert fake_session.thread.pinned_llm_config_id == -2
@pytest.mark.asyncio
async def test_image_turn_force_repins_stale_text_only_pin(monkeypatch):
    """A thread already pinned to a text-only cfg must be re-pinned when the
    next turn requires image input (the non-image path would happily reuse
    that pin)."""
    from app.config import config

    pool = [_text_only_cfg(-1), _vision_cfg(-2)]
    monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", pool)
    monkeypatch.setattr(
        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
        _premium_allowed,
    )
    fake_session = _FakeSession(_thread(pinned=-1))

    resolution = await resolve_or_get_pinned_llm_config_id(
        fake_session,
        thread_id=1,
        search_space_id=10,
        user_id=None,
        selected_llm_config_id=0,
        requires_image_input=True,
    )

    assert resolution.resolved_llm_config_id == -2
    assert resolution.from_existing_pin is False
    assert fake_session.thread.pinned_llm_config_id == -2
@pytest.mark.asyncio
async def test_image_turn_reuses_existing_vision_pin(monkeypatch):
    """A thread already pinned to a vision-capable cfg keeps that pin --
    same as the non-image path. Image-aware filtering must not force
    spurious re-pins."""
    from app.config import config

    pool = [_text_only_cfg(-1), _vision_cfg(-2), _vision_cfg(-3, quality=70)]
    monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", pool)
    monkeypatch.setattr(
        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
        _premium_allowed,
    )
    fake_session = _FakeSession(_thread(pinned=-2))

    resolution = await resolve_or_get_pinned_llm_config_id(
        fake_session,
        thread_id=1,
        search_space_id=10,
        user_id=None,
        selected_llm_config_id=0,
        requires_image_input=True,
    )

    assert resolution.resolved_llm_config_id == -2
    assert resolution.from_existing_pin is True
@pytest.mark.asyncio
async def test_image_turn_with_no_vision_candidates_raises(monkeypatch):
    """With only text-only cfgs in the pool an image turn raises ``ValueError``.

    The message must contain ``vision-capable``; per the original test
    notes, the streaming task maps it to
    ``MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT``.
    """
    from app.config import config

    fake_session = _FakeSession(_thread())
    text_only_pool = [_text_only_cfg(-1), _text_only_cfg(-2)]
    monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", text_only_pool)
    monkeypatch.setattr(
        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
        _premium_allowed,
    )

    with pytest.raises(ValueError, match="vision-capable"):
        await resolve_or_get_pinned_llm_config_id(
            fake_session,
            thread_id=1,
            search_space_id=10,
            user_id=None,
            selected_llm_config_id=0,
            requires_image_input=True,
        )
@pytest.mark.asyncio
async def test_non_image_turn_keeps_text_only_in_pool(monkeypatch):
    """Regression guard for the default (non-image) path.

    ``requires_image_input`` is deliberately omitted here: a normal
    text-only turn must still be able to select a text-only cfg.
    """
    from app.config import config

    fake_session = _FakeSession(_thread())
    monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [_text_only_cfg(-1)])
    monkeypatch.setattr(
        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
        _premium_allowed,
    )

    resolution = await resolve_or_get_pinned_llm_config_id(
        fake_session,
        thread_id=1,
        search_space_id=10,
        user_id=None,
        selected_llm_config_id=0,
    )

    assert resolution.resolved_llm_config_id == -1
@pytest.mark.asyncio
async def test_image_turn_unannotated_cfg_resolves_via_helper(monkeypatch):
    """A cfg without ``supports_image_input`` is still a valid image candidate.

    Per the original test notes, the missing key makes the resolver fall
    back to ``derive_supports_image_input`` (LiteLLM-driven), which is
    expected to report ``gpt-4o`` as vision-capable.
    """
    from app.config import config

    fake_session = _FakeSession(_thread())
    # Deliberately carries no ``supports_image_input`` key.
    unannotated_vision_cfg = {
        "id": -2,
        "provider": "OPENAI",
        "model_name": "gpt-4o",  # known vision model in LiteLLM map
        "api_key": "k",
        "billing_tier": "free",
        "auto_pin_tier": "A",
        "quality_score": 80,
    }
    monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [unannotated_vision_cfg])
    monkeypatch.setattr(
        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
        _premium_allowed,
    )

    resolution = await resolve_or_get_pinned_llm_config_id(
        fake_session,
        thread_id=1,
        search_space_id=10,
        user_id=None,
        selected_llm_config_id=0,
        requires_image_input=True,
    )

    assert resolution.resolved_llm_config_id == -2

View file

@ -15,6 +15,7 @@ vision LLM extraction:
from __future__ import annotations
import asyncio
import contextlib
from typing import Any
from uuid import uuid4
@ -57,6 +58,9 @@ class _FakeSession:
async def commit(self) -> None:
self.committed = True
async def rollback(self) -> None:
pass
async def close(self) -> None:
pass
@ -71,7 +75,9 @@ async def _fake_shielded_session():
_SESSIONS_USED: list[_FakeSession] = []
def _patch_isolation_layer(monkeypatch, *, reserve_result, finalize_result=None):
def _patch_isolation_layer(
monkeypatch, *, reserve_result, finalize_result=None, finalize_exc=None
):
"""Wire fake reserve/finalize/release/session helpers."""
_SESSIONS_USED.clear()
reserve_calls: list[dict[str, Any]] = []
@ -91,6 +97,8 @@ def _patch_isolation_layer(monkeypatch, *, reserve_result, finalize_result=None)
async def _fake_finalize(
*, db_session, user_id, request_id, actual_micros, reserved_micros
):
if finalize_exc is not None:
raise finalize_exc
finalize_calls.append(
{
"user_id": user_id,
@ -343,6 +351,125 @@ async def test_premium_uses_estimator_when_no_micros_override(monkeypatch):
assert spies["reserve"][0]["reserve_micros"] == 12_345
@pytest.mark.asyncio
async def test_premium_finalize_failure_propagates_and_releases(monkeypatch):
    """A failing finalize surfaces as ``BillingSettlementError``.

    The fake finalize helper raises, so leaving the ``billable_call``
    context must raise ``BillingSettlementError``. The reservation is
    expected to be released exactly once and no audit row recorded.
    """
    from app.services.billable_calls import BillingSettlementError, billable_call

    class _FinalizeError(RuntimeError):
        pass

    calls = _patch_isolation_layer(
        monkeypatch,
        reserve_result=_FakeQuotaResult(allowed=True),
        finalize_exc=_FinalizeError("db finalize failed"),
    )
    user_id = uuid4()

    with pytest.raises(BillingSettlementError):
        async with billable_call(
            user_id=user_id,
            search_space_id=42,
            billing_tier="premium",
            base_model="openai/gpt-image-1",
            quota_reserve_micros_override=50_000,
            usage_type="image_generation",
        ) as accumulator:
            accumulator.add(
                model="openai/gpt-image-1",
                prompt_tokens=0,
                completion_tokens=0,
                total_tokens=0,
                cost_micros=40_000,
                call_kind="image_generation",
            )

    assert len(calls["reserve"]) == 1
    assert len(calls["release"]) == 1
    assert calls["record"] == []
@pytest.mark.asyncio
async def test_premium_audit_commit_hang_times_out_after_finalize(monkeypatch):
    """A hanging audit commit must not block ``billable_call`` from exiting.

    The injected session's ``commit`` sleeps for 60 seconds while
    ``audit_timeout_seconds`` is 0.01. The context is expected to exit
    without raising, with finalize and record each called once and no
    release issued.
    """
    from app.services.billable_calls import billable_call

    calls = _patch_isolation_layer(
        monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
    )
    user_id = uuid4()

    class _HangingCommitSession(_FakeSession):
        async def commit(self) -> None:
            await asyncio.sleep(60)

    @contextlib.asynccontextmanager
    async def _hanging_session_factory():
        hanging = _HangingCommitSession()
        _SESSIONS_USED.append(hanging)
        yield hanging

    async with billable_call(
        user_id=user_id,
        search_space_id=42,
        billing_tier="premium",
        base_model="openai/gpt-image-1",
        quota_reserve_micros_override=50_000,
        usage_type="image_generation",
        billable_session_factory=_hanging_session_factory,
        audit_timeout_seconds=0.01,
    ) as accumulator:
        accumulator.add(
            model="openai/gpt-image-1",
            prompt_tokens=0,
            completion_tokens=0,
            total_tokens=0,
            cost_micros=40_000,
            call_kind="image_generation",
        )

    assert len(calls["reserve"]) == 1
    assert len(calls["finalize"]) == 1
    assert len(calls["record"]) == 1
    assert calls["release"] == []
@pytest.mark.asyncio
async def test_free_audit_failure_is_best_effort(monkeypatch):
    """On the free tier a failing audit insert must not break the call.

    ``record_token_usage`` is replaced with a coroutine that raises; the
    ``billable_call`` context is still expected to exit cleanly, and no
    reserve or finalize call is made.
    """
    from app.services.billable_calls import billable_call

    calls = _patch_isolation_layer(
        monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
    )

    async def _failing_record(_session, **_kwargs):
        raise RuntimeError("audit insert failed")

    monkeypatch.setattr(
        "app.services.billable_calls.record_token_usage",
        _failing_record,
        raising=False,
    )

    async with billable_call(
        user_id=uuid4(),
        search_space_id=42,
        billing_tier="free",
        base_model="openai/gpt-image-1",
        usage_type="image_generation",
        audit_timeout_seconds=0.01,
    ) as accumulator:
        accumulator.add(
            model="openai/gpt-image-1",
            prompt_tokens=0,
            completion_tokens=0,
            total_tokens=0,
            cost_micros=37_000,
            call_kind="image_generation",
        )

    assert calls["reserve"] == []
    assert calls["finalize"] == []
# ---------------------------------------------------------------------------
# Podcast / video-presentation usage_type coverage
# ---------------------------------------------------------------------------
@ -387,7 +514,7 @@ async def test_free_podcast_path_audits_with_podcast_usage_type(monkeypatch):
assert len(spies["record"]) == 1
row = spies["record"][0]
assert row["usage_type"] == "podcast_generation"
assert row["thread_id"] == 99
assert row["thread_id"] is None
assert row["search_space_id"] == 42
assert row["call_details"] == {"podcast_id": 7, "title": "Test Podcast"}

View file

@ -0,0 +1,177 @@
"""Defense-in-depth: image-gen call sites must not let an empty
``api_base`` fall through to LiteLLM's module-global ``litellm.api_base``.
The bug repro: an OpenRouter image-gen config ships
``api_base=""``. The pre-fix call site in
``image_generation_routes._execute_image_generation`` did
``if cfg.get("api_base"): kwargs["api_base"] = cfg["api_base"]`` which
silently dropped the empty string. LiteLLM then fell back to
``litellm.api_base`` (commonly inherited from ``AZURE_OPENAI_ENDPOINT``)
and OpenRouter's ``image_generation/transformation`` appended
``/chat/completions`` to it, yielding a 404 ``Resource not found``.
This test pins the post-fix behaviour: with an empty ``api_base`` in
the config, the call site MUST set ``api_base`` to OpenRouter's public
URL instead of leaving it unset.
"""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
pytestmark = pytest.mark.unit
@pytest.mark.asyncio
async def test_global_openrouter_image_gen_sets_api_base_when_config_empty():
    """Global-config branch of ``_execute_image_generation`` pins ``api_base``.

    The config uses a negative id and ships ``api_base=""``. The kwargs
    captured from the patched ``aimage_generation`` must carry
    OpenRouter's public URL and the ``openrouter/``-prefixed model string.
    """
    from app.routes import image_generation_routes

    cfg = {
        "id": -20_001,
        "name": "GPT Image 1 (OpenRouter)",
        "provider": "OPENROUTER",
        "model_name": "openai/gpt-image-1",
        "api_key": "sk-or-test",
        "api_base": "",  # the original bug shape
        "api_version": None,
        "litellm_params": {},
    }
    captured: dict = {}

    async def fake_aimage_generation(**kwargs):
        captured.update(kwargs)
        return MagicMock(model_dump=lambda: {"data": []}, _hidden_params={})

    image_gen = MagicMock()
    image_gen.image_generation_config_id = cfg["id"]
    image_gen.prompt = "test"
    image_gen.n = 1
    # Every optional generation parameter is left unset.
    for optional_field in ("quality", "size", "style", "response_format", "model"):
        setattr(image_gen, optional_field, None)

    search_space = MagicMock()
    search_space.image_generation_config_id = cfg["id"]
    session = MagicMock()

    config_patch = patch.object(
        image_generation_routes,
        "_get_global_image_gen_config",
        return_value=cfg,
    )
    generation_patch = patch.object(
        image_generation_routes,
        "aimage_generation",
        side_effect=fake_aimage_generation,
    )
    with config_patch, generation_patch:
        await image_generation_routes._execute_image_generation(
            session=session, image_gen=image_gen, search_space=search_space
        )

    # Even with an empty ``api_base`` in the config, OpenRouter's public
    # URL must be forwarded so the call cannot inherit an Azure endpoint.
    assert captured.get("api_base") == "https://openrouter.ai/api/v1"
    assert captured["model"] == "openrouter/openai/gpt-image-1"
@pytest.mark.asyncio
async def test_generate_image_tool_global_sets_api_base_when_config_empty():
    """The agent ``generate_image`` tool applies the same ``api_base`` defense.

    Uses the same OpenRouter config payload with an empty ``api_base`` and
    checks the kwargs captured from the patched ``aimage_generation``.
    """
    from app.agents.new_chat.tools import generate_image as gi_module

    cfg = {
        "id": -20_001,
        "name": "GPT Image 1 (OpenRouter)",
        "provider": "OPENROUTER",
        "model_name": "openai/gpt-image-1",
        "api_key": "sk-or-test",
        "api_base": "",
        "api_version": None,
        "litellm_params": {},
    }
    captured: dict = {}

    async def fake_aimage_generation(**kwargs):
        captured.update(kwargs)
        response = MagicMock()
        response.model_dump.return_value = {
            "data": [{"url": "https://example.com/x.png"}]
        }
        response._hidden_params = {"model": "openrouter/openai/gpt-image-1"}
        return response

    search_space = MagicMock()
    search_space.id = 1
    search_space.image_generation_config_id = cfg["id"]

    # Fake async session: ``execute(...).scalars().first()`` yields the
    # search space above.
    session = AsyncMock()
    session_cm = AsyncMock()
    session_cm.__aenter__.return_value = session
    scalars = MagicMock()
    scalars.first.return_value = search_space
    exec_result = MagicMock()
    exec_result.scalars.return_value = scalars
    session.execute.return_value = exec_result
    session.add = MagicMock()
    session.commit = AsyncMock()

    # ``refresh(db_image_gen)`` needs to populate ``id`` for token URL fallback.
    async def _refresh(obj):
        obj.id = 1

    session.refresh = AsyncMock()
    session.refresh.side_effect = _refresh

    with (
        patch.object(gi_module, "shielded_async_session", return_value=session_cm),
        patch.object(gi_module, "_get_global_image_gen_config", return_value=cfg),
        patch.object(
            gi_module, "aimage_generation", side_effect=fake_aimage_generation
        ),
        patch.object(
            gi_module, "is_image_gen_auto_mode", side_effect=lambda cid: cid == 0
        ),
    ):
        tool = gi_module.create_generate_image_tool(
            search_space_id=1, db_session=MagicMock()
        )
        await tool.ainvoke({"prompt": "a cat", "n": 1})

    assert captured.get("api_base") == "https://openrouter.ai/api/v1"
    assert captured["model"] == "openrouter/openai/gpt-image-1"
def test_image_gen_router_deployment_sets_api_base_when_config_empty():
    """Auto-mode router deployments also get a resolved ``api_base``.

    Per the original test notes the deployment dict is handed straight to
    ``litellm.Router``, so an OpenRouter config with an empty ``api_base``
    must come back with the public OpenRouter URL filled in.
    """
    from app.services.image_gen_router_service import ImageGenRouterService

    openrouter_cfg = {
        "model_name": "openai/gpt-image-1",
        "provider": "OPENROUTER",
        "api_key": "sk-or-test",
        "api_base": "",
    }
    deployment = ImageGenRouterService._config_to_deployment(openrouter_cfg)

    assert deployment is not None
    params = deployment["litellm_params"]
    assert params["api_base"] == "https://openrouter.ai/api/v1"
    assert params["model"] == "openrouter/openai/gpt-image-1"

View file

@ -265,6 +265,10 @@ def test_generate_image_gen_configs_filters_by_image_output():
assert c["billing_tier"] in {"free", "premium"}
assert c["provider"] == "OPENROUTER"
assert c[_OPENROUTER_DYNAMIC_MARKER] is True
# Defense-in-depth: emit the OpenRouter base URL at source so a
# downstream call site that forgets ``resolve_api_base`` still
# doesn't 404 against an inherited Azure endpoint.
assert c["api_base"] == "https://openrouter.ai/api/v1"
def test_generate_image_gen_configs_assigns_image_id_offset():
@ -342,6 +346,10 @@ def test_generate_vision_llm_configs_filters_by_image_input_text_output():
assert cfg["input_cost_per_token"] == pytest.approx(5e-6)
assert cfg["output_cost_per_token"] == pytest.approx(15e-6)
assert cfg[_OPENROUTER_DYNAMIC_MARKER] is True
# Defense-in-depth: emit the OpenRouter base URL at source so a
# downstream call site that forgets ``resolve_api_base`` still
# doesn't inherit an Azure endpoint.
assert cfg["api_base"] == "https://openrouter.ai/api/v1"
def test_generate_vision_llm_configs_drops_chat_only_filters():

View file

@ -0,0 +1,107 @@
"""Unit tests for the shared ``api_base`` resolver.
The cascade exists so vision and image-gen call sites can't silently
inherit ``litellm.api_base`` (commonly set by ``AZURE_OPENAI_ENDPOINT``)
when an OpenRouter / Groq / etc. config ships an empty string. See
``provider_api_base`` module docstring for the original repro
(OpenRouter image-gen 404-ing against an Azure endpoint).
"""
from __future__ import annotations
import pytest
from app.services.provider_api_base import (
PROVIDER_DEFAULT_API_BASE,
PROVIDER_KEY_DEFAULT_API_BASE,
resolve_api_base,
)
pytestmark = pytest.mark.unit
def test_config_value_wins_over_defaults():
    """A non-empty config value is returned verbatim.

    This holds even when the provider has a default of its own: the
    operator gets the last word.
    """
    mirror_url = resolve_api_base(
        provider="OPENROUTER",
        provider_prefix="openrouter",
        config_api_base="https://my-openrouter-mirror.example.com/v1",
    )
    assert mirror_url == "https://my-openrouter-mirror.example.com/v1"
def test_provider_key_default_when_config_missing():
    """The provider-key map takes precedence over the prefix map.

    ``DEEPSEEK`` shares the ``openai`` LiteLLM prefix but has its own base
    URL, so the lookup must resolve to the DeepSeek entry rather than
    anything keyed on the prefix.
    """
    resolved = resolve_api_base(
        provider="DEEPSEEK",
        provider_prefix="openai",
        config_api_base=None,
    )
    assert resolved == PROVIDER_KEY_DEFAULT_API_BASE["DEEPSEEK"]
def test_provider_prefix_default_when_no_key_default():
    """With no config value, OpenRouter resolves to the prefix-map default."""
    resolved = resolve_api_base(
        provider="OPENROUTER",
        provider_prefix="openrouter",
        config_api_base=None,
    )
    assert resolved == PROVIDER_DEFAULT_API_BASE["openrouter"]
def test_unknown_provider_returns_none():
    """An unmatched provider and prefix resolve to ``None``.

    Returning ``None`` lets the caller leave it to LiteLLM's own
    provider-integration default (Azure deployment URL, custom-provider
    URL, etc.).
    """
    resolved = resolve_api_base(
        provider="SOMETHING_NEW",
        provider_prefix="something_new",
        config_api_base=None,
    )
    assert resolved is None
def test_empty_string_config_treated_as_missing():
    """An empty-string ``api_base`` falls through to the provider default.

    This is the original bug shape: OpenRouter dynamic configs ship
    ``api_base=""`` and call sites guarded with ``if cfg.get("api_base"):``
    dropped it because empty strings are falsy. The cascade has to step
    in anyway.
    """
    resolved = resolve_api_base(
        provider="OPENROUTER",
        provider_prefix="openrouter",
        config_api_base="",
    )
    assert resolved == PROVIDER_DEFAULT_API_BASE["openrouter"]
def test_whitespace_only_config_treated_as_missing():
    """A whitespace-only ``api_base`` is treated as missing.

    Such a value is a configuration mistake; forwarding it to LiteLLM
    would almost certainly 404, so the provider default is used instead.
    """
    resolved = resolve_api_base(
        provider="OPENROUTER",
        provider_prefix="openrouter",
        config_api_base=" ",
    )
    assert resolved == PROVIDER_DEFAULT_API_BASE["openrouter"]
def test_provider_case_insensitive():
    """Upper- and lower-case provider names resolve to the same URL.

    Some call sites pass the provider lowercase (DB enum value), others
    uppercase (YAML key); both must hit the DeepSeek default.
    """
    from_upper = resolve_api_base(
        provider="DEEPSEEK", provider_prefix="openai", config_api_base=None
    )
    from_lower = resolve_api_base(
        provider="deepseek", provider_prefix="openai", config_api_base=None
    )
    assert from_upper == from_lower == PROVIDER_KEY_DEFAULT_API_BASE["DEEPSEEK"]
def test_all_inputs_none_returns_none():
    """With every input ``None`` the resolver returns ``None``."""
    resolved = resolve_api_base(
        provider=None, provider_prefix=None, config_api_base=None
    )
    assert resolved is None

View file

@ -0,0 +1,244 @@
"""Unit tests for the shared chat-image capability resolver.
Two resolvers, two intents:
- ``derive_supports_image_input``: best-effort True for the catalog and
selector. Default-allow on unknown / unmapped models. The streaming
task safety net never sees this value directly.
- ``is_known_text_only_chat_model``: strict opt-out for the safety net.
Returns True only when LiteLLM's model map *explicitly* sets
``supports_vision=False``. Anything else (missing key, exception,
True) returns False so the request flows through to the provider.
"""
from __future__ import annotations
import pytest
from app.services.provider_capabilities import (
derive_supports_image_input,
is_known_text_only_chat_model,
)
pytestmark = pytest.mark.unit
# ---------------------------------------------------------------------------
# derive_supports_image_input — OpenRouter modalities path (authoritative)
# ---------------------------------------------------------------------------
def test_or_modalities_with_image_returns_true():
    """OpenRouter modalities that include ``image`` resolve to True."""
    supported = derive_supports_image_input(
        provider="OPENROUTER",
        model_name="openai/gpt-4o",
        openrouter_input_modalities=["text", "image"],
    )
    assert supported is True
def test_or_modalities_text_only_returns_false():
    """OpenRouter modalities listing only ``text`` resolve to False."""
    supported = derive_supports_image_input(
        provider="OPENROUTER",
        model_name="deepseek/deepseek-v3.2-exp",
        openrouter_input_modalities=["text"],
    )
    assert supported is False
def test_or_modalities_empty_list_returns_false():
    """An explicitly empty modality list is a definitive False.

    OpenRouter publishing ``[]`` is read as "no inputs at all", so the
    resolver must answer False rather than fall back to LiteLLM.
    """
    supported = derive_supports_image_input(
        provider="OPENROUTER",
        model_name="weird/empty-modalities",
        openrouter_input_modalities=[],
    )
    assert supported is False
def test_or_modalities_none_falls_through_to_litellm():
    """``None`` modalities (missing key) are not a definitive signal.

    The resolver falls through to LiteLLM; the test uses ``gpt-4o`` under
    the ``OPENAI`` provider and expects True.
    """
    supported = derive_supports_image_input(
        provider="OPENAI",
        model_name="gpt-4o",
        openrouter_input_modalities=None,
    )
    assert supported is True
# ---------------------------------------------------------------------------
# derive_supports_image_input — LiteLLM model-map path
# ---------------------------------------------------------------------------
def test_litellm_known_vision_model_returns_true():
    """``gpt-4o`` under ``OPENAI`` resolves to True with no modalities given."""
    supported = derive_supports_image_input(
        provider="OPENAI",
        model_name="gpt-4o",
    )
    assert supported is True
def test_litellm_base_model_wins_over_model_name():
    """``base_model`` is consulted ahead of ``model_name``.

    Azure-style entries pass the deployment id as ``model_name`` and carry
    the canonical sku in ``base_model``. If the deployment id (unknown to
    LiteLLM) were looked up first it would shadow the real capability.
    """
    supported = derive_supports_image_input(
        provider="AZURE_OPENAI",
        model_name="my-azure-deployment-id",
        base_model="gpt-4o",
    )
    assert supported is True
def test_litellm_unknown_model_default_allows():
    """An unknown model defaults to True; the safety net is the real block."""
    supported = derive_supports_image_input(
        provider="CUSTOM",
        model_name="brand-new-model-x9-unmapped",
        custom_provider="brand_new_proxy",
    )
    assert supported is True
def test_litellm_known_text_only_returns_false():
    """The catalog resolver must not raise for a text-only provider/model.

    ``deepseek-chat`` is used as a model expected to be text-only in
    LiteLLM's map. Depending on the installed LiteLLM version the resolver
    may answer False (explicit no) or True (default-allow when the entry
    is not mapped), so only the return type is asserted here.
    """
    outcome = derive_supports_image_input(
        provider="DEEPSEEK",
        model_name="deepseek-chat",
    )
    # The behaviour-binding assertion for the explicit-False path lives in
    # ``test_is_known_text_only_returns_true_on_explicit_false``, which
    # stubs the LiteLLM lookup to stay stable across map updates.
    assert isinstance(outcome, bool)
# ---------------------------------------------------------------------------
# is_known_text_only_chat_model — strict opt-out semantics
# ---------------------------------------------------------------------------
def test_is_known_text_only_returns_false_for_vision_model():
    """``gpt-4o`` is not reported as known text-only."""
    text_only = is_known_text_only_chat_model(
        provider="OPENAI",
        model_name="gpt-4o",
    )
    assert text_only is False
def test_is_known_text_only_returns_false_for_unknown_model():
    """Strict opt-out: missing from the map does not mean text-only.

    The safety net must NOT fire for an unmapped model; that is the
    regression this suite guards against.
    """
    text_only = is_known_text_only_chat_model(
        provider="CUSTOM",
        model_name="brand-new-model-x9-unmapped",
        custom_provider="brand_new_proxy",
    )
    assert text_only is False
def test_is_known_text_only_returns_false_when_lookup_raises(monkeypatch):
    """A raising ``get_model_info`` must be swallowed and yield False.

    The lookup is stubbed to raise ``ValueError``; the helper is expected
    to return False so the safety net does not fire on a transient lookup
    failure.
    """
    import app.services.provider_capabilities as pc

    def _raise(**_kwargs):
        raise ValueError("intentional test failure")

    monkeypatch.setattr(pc.litellm, "get_model_info", _raise)

    text_only = is_known_text_only_chat_model(
        provider="OPENAI",
        model_name="gpt-4o",
    )
    assert text_only is False
def test_is_known_text_only_returns_true_on_explicit_false(monkeypatch):
    """An explicit ``supports_vision=False`` triggers the opt-out path.

    ``get_model_info`` is stubbed so the test exercises this branch
    deterministically and stays stable across LiteLLM map updates.
    """
    import app.services.provider_capabilities as pc

    monkeypatch.setattr(
        pc.litellm,
        "get_model_info",
        lambda **_kwargs: {"supports_vision": False, "max_input_tokens": 8192},
    )

    text_only = is_known_text_only_chat_model(
        provider="OPENAI",
        model_name="any-model",
    )
    assert text_only is True
def test_is_known_text_only_returns_false_on_supports_vision_true(monkeypatch):
    """A stubbed ``supports_vision=True`` entry is not text-only."""
    import app.services.provider_capabilities as pc

    monkeypatch.setattr(
        pc.litellm,
        "get_model_info",
        lambda **_kwargs: {"supports_vision": True},
    )

    text_only = is_known_text_only_chat_model(
        provider="OPENAI",
        model_name="any-model",
    )
    assert text_only is False
def test_is_known_text_only_returns_false_on_missing_key(monkeypatch):
    """A model entry with no ``supports_vision`` key counts as unknown.

    Under strict opt-out semantics, unknown means False.
    """
    import app.services.provider_capabilities as pc

    monkeypatch.setattr(
        pc.litellm,
        "get_model_info",
        lambda **_kwargs: {"max_input_tokens": 8192},  # no supports_vision
    )

    text_only = is_known_text_only_chat_model(
        provider="OPENAI",
        model_name="any-model",
    )
    assert text_only is False

View file

@ -0,0 +1,281 @@
"""Unit tests for the chat-catalog ``supports_image_input`` capability flag.
Capability is sourced from two places, in order of preference:
1. ``architecture.input_modalities`` for dynamic OpenRouter chat configs
(authoritative: OpenRouter publishes per-model modalities directly).
2. LiteLLM's authoritative model map (``litellm.supports_vision``) for
YAML / BYOK configs that don't carry an explicit operator override.
The catalog default is *True* (conservative-allow): an unknown / unmapped
model is not pre-judged. The streaming-task safety net
(``is_known_text_only_chat_model``) is the only place a False actually
blocks a request, and it requires LiteLLM to *explicitly* mark the model
as text-only.
"""
from __future__ import annotations
import pytest
from app.services.openrouter_integration_service import (
_OPENROUTER_DYNAMIC_MARKER,
_generate_configs,
_supports_image_input,
)
pytestmark = pytest.mark.unit
# Baseline OpenRouter integration settings shared by the tests in this
# module; a copy is handed to ``_generate_configs`` so the template itself
# is never mutated.
_SETTINGS_BASE: dict = {
"api_key": "sk-or-test",
"id_offset": -10_000,
"rpm": 200,
"tpm": 1_000_000,
"free_rpm": 20,
"free_tpm": 100_000,
"anonymous_enabled_paid": False,
"anonymous_enabled_free": True,
"quota_reserve_tokens": 4000,
}
# ---------------------------------------------------------------------------
# _supports_image_input helper (OpenRouter modalities)
# ---------------------------------------------------------------------------
def test_supports_image_input_true_for_multimodal():
    """A model whose input modalities include ``image`` is reported True."""
    multimodal_model = {
        "id": "openai/gpt-4o",
        "architecture": {
            "input_modalities": ["text", "image"],
            "output_modalities": ["text"],
        },
    }
    assert _supports_image_input(multimodal_model) is True
def test_supports_image_input_false_for_text_only():
    """A text-in/text-out model is reported False.

    This is the failure mode the safety net guards against: per the
    original test notes, DeepSeek V3 would 404 if forwarded an
    ``image_url``.
    """
    text_only_model = {
        "id": "deepseek/deepseek-v3.2-exp",
        "architecture": {
            "input_modalities": ["text"],
            "output_modalities": ["text"],
        },
    }
    assert _supports_image_input(text_only_model) is False
def test_supports_image_input_false_when_modalities_missing():
    """Missing architecture or modalities are text-only at the helper level.

    This holds for the OpenRouter helper only. The wider catalog resolver
    (``derive_supports_image_input``) treats an absent (``None``) modality
    list as "no signal" and falls back to LiteLLM instead.
    """
    payloads_without_modalities = (
        {"id": "weird/model"},
        {"id": "weird/model", "architecture": {}},
        {"id": "weird/model", "architecture": {"input_modalities": None}},
    )
    for payload in payloads_without_modalities:
        assert _supports_image_input(payload) is False
# ---------------------------------------------------------------------------
# _generate_configs threads the flag onto every emitted chat config
# ---------------------------------------------------------------------------
def test_generate_configs_emits_supports_image_input():
    """``_generate_configs`` stamps ``supports_image_input`` on each config.

    One multimodal and one text-only model go in; each emitted config must
    carry the matching flag plus the dynamic-OpenRouter marker.
    """

    def _raw_model(model_id, input_modalities, prompt_price):
        # Mirrors the shape of one OpenRouter model listing entry.
        return {
            "id": model_id,
            "architecture": {
                "input_modalities": input_modalities,
                "output_modalities": ["text"],
            },
            "supported_parameters": ["tools"],
            "context_length": 200_000,
            "pricing": {"prompt": prompt_price, "completion": "0.000015"},
        }

    raw = [
        _raw_model("openai/gpt-4o", ["text", "image"], "0.000005"),
        _raw_model("deepseek/deepseek-v3.2-exp", ["text"], "0.000003"),
    ]

    cfgs = _generate_configs(raw, dict(_SETTINGS_BASE))
    by_model = {cfg["model_name"]: cfg for cfg in cfgs}

    gpt = by_model["openai/gpt-4o"]
    assert gpt["supports_image_input"] is True
    assert gpt[_OPENROUTER_DYNAMIC_MARKER] is True

    deepseek = by_model["deepseek/deepseek-v3.2-exp"]
    assert deepseek["supports_image_input"] is False
    assert deepseek[_OPENROUTER_DYNAMIC_MARKER] is True
# ---------------------------------------------------------------------------
# YAML loader: defer to derive_supports_image_input on unannotated entries
# ---------------------------------------------------------------------------
def test_yaml_loader_resolves_unannotated_vision_model_to_true(tmp_path, monkeypatch):
    """An unannotated vision-capable YAML entry resolves to True.

    The regression case: an Azure GPT-4o entry with no
    ``supports_image_input`` override should resolve to True via LiteLLM's
    model map. Per the original test notes, this previously defaulted to
    False, blocking every image turn for vision-capable YAML configs.
    """
    yaml_dir = tmp_path / "app" / "config"
    yaml_dir.mkdir(parents=True)
    # The entry's fields must be nested under the list item. Without the
    # indentation every key after ``id`` parses as a top-level mapping key
    # and the config entry ends up with only its ``id``.
    (yaml_dir / "global_llm_config.yaml").write_text(
        """
global_llm_configs:
  - id: -2
    name: Azure GPT-4o
    provider: AZURE_OPENAI
    model_name: gpt-4o
    api_key: sk-test
""",
        encoding="utf-8",
    )

    from app import config as config_module

    # Point the loader at the temporary tree that holds the fixture file.
    monkeypatch.setattr(config_module, "BASE_DIR", tmp_path)
    configs = config_module.load_global_llm_configs()

    assert len(configs) == 1
    assert configs[0]["supports_image_input"] is True
def test_yaml_loader_respects_explicit_supports_image_input(tmp_path, monkeypatch):
    """An explicit ``supports_image_input: false`` in YAML is kept as-is.

    The operator override must win even though the model (``gpt-4o``)
    would otherwise resolve to True.
    """
    yaml_dir = tmp_path / "app" / "config"
    yaml_dir.mkdir(parents=True)
    # The entry's fields must be nested under the list item. Without the
    # indentation every key after ``id`` parses as a top-level mapping key
    # and the override never reaches the config entry.
    (yaml_dir / "global_llm_config.yaml").write_text(
        """
global_llm_configs:
  - id: -1
    name: GPT-4o
    provider: OPENAI
    model_name: gpt-4o
    api_key: sk-test
    supports_image_input: false
""",
        encoding="utf-8",
    )

    from app import config as config_module

    # Point the loader at the temporary tree that holds the fixture file.
    monkeypatch.setattr(config_module, "BASE_DIR", tmp_path)
    configs = config_module.load_global_llm_configs()

    assert len(configs) == 1
    # Operator override always wins, even against LiteLLM's True.
    assert configs[0]["supports_image_input"] is False
def test_yaml_loader_unknown_model_default_allows(tmp_path, monkeypatch):
    """An unknown / unmapped YAML model defaults to allowing image input.

    Per the original test notes, the streaming safety net (which needs an
    explicit False from LiteLLM) is the only place a real block happens,
    so a freshly added third-party entry the catalog cannot introspect
    must not lock the user out.
    """
    yaml_dir = tmp_path / "app" / "config"
    yaml_dir.mkdir(parents=True)
    # The entry's fields must be nested under the list item. Without the
    # indentation every key after ``id`` parses as a top-level mapping key
    # and the config entry ends up with only its ``id``.
    (yaml_dir / "global_llm_config.yaml").write_text(
        """
global_llm_configs:
  - id: -1
    name: Some Brand New Model
    provider: CUSTOM
    custom_provider: brand_new_proxy
    model_name: brand-new-model-x9
    api_key: sk-test
""",
        encoding="utf-8",
    )

    from app import config as config_module

    # Point the loader at the temporary tree that holds the fixture file.
    monkeypatch.setattr(config_module, "BASE_DIR", tmp_path)
    configs = config_module.load_global_llm_configs()

    assert len(configs) == 1
    assert configs[0]["supports_image_input"] is True
# ---------------------------------------------------------------------------
# AgentConfig threads the flag through both YAML and Auto / BYOK
# ---------------------------------------------------------------------------
def test_agent_config_from_yaml_explicit_overrides_resolver():
    """An explicit YAML ``supports_image_input`` reaches ``AgentConfig``.

    Both explicit values are covered: False on a model that would
    otherwise resolve as capable, and True.
    """
    from app.agents.new_chat.llm_config import AgentConfig

    text_only_yaml = {
        "id": -1,
        "name": "Text Only Override",
        "provider": "openai",
        "model_name": "gpt-4o",  # Capable per LiteLLM, but operator says no.
        "api_key": "sk-test",
        "supports_image_input": False,
    }
    explicit_vision_yaml = {
        "id": -2,
        "name": "GPT-4o",
        "provider": "openai",
        "model_name": "gpt-4o",
        "api_key": "sk-test",
        "supports_image_input": True,
    }

    cfg_text_only = AgentConfig.from_yaml_config(text_only_yaml)
    cfg_explicit_vision = AgentConfig.from_yaml_config(explicit_vision_yaml)

    assert cfg_text_only.supports_image_input is False
    assert cfg_explicit_vision.supports_image_input is True
def test_agent_config_from_yaml_unannotated_uses_resolver():
    """With no explicit YAML key, ``AgentConfig`` defers to the resolver.

    For ``gpt-4o`` the expected outcome is True (LiteLLM's map marks it
    as vision-capable, per the original test notes).
    """
    from app.agents.new_chat.llm_config import AgentConfig

    unannotated_yaml = {
        "id": -1,
        "name": "GPT-4o (no override)",
        "provider": "openai",
        "model_name": "gpt-4o",
        "api_key": "sk-test",
    }
    agent_cfg = AgentConfig.from_yaml_config(unannotated_yaml)

    assert agent_cfg.supports_image_input is True
def test_agent_config_auto_mode_supports_image_input():
    """Auto mode optimistically reports image input as supported.

    Auto routes across the whole pool, so users can stay on Auto when a
    vision-capable deployment exists somewhere in it. Per the original
    test notes, the router's ``allowed_fails`` fallback covers non-vision
    deployments.
    """
    from app.agents.new_chat.llm_config import AgentConfig

    auto_cfg = AgentConfig.from_auto_mode()

    assert auto_cfg.supports_image_input is True

View file

@ -0,0 +1,89 @@
"""Defense-in-depth: vision-LLM resolution must not leak ``api_base``
defaults from ``litellm.api_base`` either.
Vision shares the same shape as image-gen: global YAML / OpenRouter
dynamic configs ship ``api_base=""`` and the pre-fix ``get_vision_llm``
call sites would silently drop the empty string and inherit
``AZURE_OPENAI_ENDPOINT``. ``ChatLiteLLM(...)`` doesn't 404 on
construction so we test the kwargs we hand to it instead.
"""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
pytestmark = pytest.mark.unit
@pytest.mark.asyncio
async def test_get_vision_llm_global_openrouter_sets_api_base():
    """Global negative-id branch of ``get_vision_llm`` pins ``api_base``.

    An OpenRouter vision config with ``api_base=""`` must end up
    constructing ``SanitizedChatLiteLLM`` with
    ``api_base="https://openrouter.ai/api/v1"``: never an empty string,
    never silently absent.
    """
    from app.services import llm_service

    cfg = {
        "id": -30_001,
        "name": "GPT-4o Vision (OpenRouter)",
        "provider": "OPENROUTER",
        "model_name": "openai/gpt-4o",
        "api_key": "sk-or-test",
        "api_base": "",
        "api_version": None,
        "litellm_params": {},
        "billing_tier": "free",
    }

    search_space = MagicMock()
    search_space.id = 1
    search_space.user_id = "user-x"
    search_space.vision_llm_config_id = cfg["id"]

    # Fake async session: ``execute(...).scalars().first()`` yields the
    # search space above.
    scalars = MagicMock()
    scalars.first.return_value = search_space
    query_result = MagicMock()
    query_result.scalars.return_value = scalars
    session = AsyncMock()
    session.execute.return_value = query_result

    captured: dict = {}

    class FakeSanitized:
        """Stand-in that records the constructor kwargs."""

        def __init__(self, **kwargs):
            captured.update(kwargs)

    config_patch = patch(
        "app.services.vision_llm_router_service.get_global_vision_llm_config",
        return_value=cfg,
    )
    llm_class_patch = patch(
        "app.agents.new_chat.llm_config.SanitizedChatLiteLLM",
        new=FakeSanitized,
    )
    with config_patch, llm_class_patch:
        await llm_service.get_vision_llm(session=session, search_space_id=1)

    assert captured.get("api_base") == "https://openrouter.ai/api/v1"
    assert captured["model"] == "openrouter/openai/gpt-4o"
def test_vision_router_deployment_sets_api_base_when_config_empty():
    """Auto-mode vision router: deployments are handed to
    ``litellm.Router``, so the ``api_base`` resolver must already be
    applied when the deployment dict is built."""
    from app.services.vision_llm_router_service import VisionLLMRouterService

    openrouter_cfg = {
        "model_name": "openai/gpt-4o",
        "provider": "OPENROUTER",
        "api_key": "sk-or-test",
        "api_base": "",
    }
    deployment = VisionLLMRouterService._config_to_deployment(openrouter_cfg)

    assert deployment is not None
    params = deployment["litellm_params"]
    assert params["api_base"] == "https://openrouter.ai/api/v1"
    assert params["model"] == "openrouter/openai/gpt-4o"

View file

@ -0,0 +1,318 @@
"""Regression tests for ``run_async_celery_task``.
These tests pin down the production bug observed on 2026-05-02 where
the video-presentation Celery task hung at ``[billable_call] finalize``
because the shared ``app.db.engine`` had pooled asyncpg connections
bound to a *previous* task's now-closed event loop. Reusing such a
connection on a fresh loop crashes inside ``pool_pre_ping`` with::
AttributeError: 'NoneType' object has no attribute 'send'
(the proactor is None because the loop is gone) and can hang forever
inside the asyncpg ``Connection._cancel`` cleanup coroutine.
The fix is ``run_async_celery_task``: a small helper that runs every
async celery task body inside a fresh event loop and disposes the
shared engine pool both before (defends against a previous task that
crashed) and after (releases connections we opened on this loop).
Tests here exercise the helper with a stub engine that records
``dispose()`` calls and panics if a coroutine produced by one loop is
awaited on another — mirroring the real asyncpg behaviour.
"""
from __future__ import annotations
import asyncio
import gc
import sys
from collections.abc import Iterator
from contextlib import contextmanager
from unittest.mock import patch
import pytest
pytestmark = pytest.mark.unit
# ---------------------------------------------------------------------------
# Stub engine that emulates the asyncpg-on-stale-loop crash
# ---------------------------------------------------------------------------
class _StaleLoopEngine:
"""Tiny stand-in for ``app.db.engine`` that tracks dispose() calls.
``dispose()`` is async (matches ``AsyncEngine.dispose``) and records
the running event loop id so tests can assert it ran on *each*
fresh loop.
"""
def __init__(self) -> None:
self.dispose_loop_ids: list[int] = []
async def dispose(self) -> None:
loop = asyncio.get_running_loop()
self.dispose_loop_ids.append(id(loop))
@contextmanager
def _patch_shared_engine(stub: _StaleLoopEngine) -> Iterator[None]:
"""Patch ``from app.db import engine as shared_engine`` lookup.
The helper imports lazily inside the function body, so we have to
patch the attribute on the already-loaded ``app.db`` module.
"""
import app.db as app_db
original = getattr(app_db, "engine", None)
app_db.engine = stub # type: ignore[attr-defined]
try:
yield
finally:
if original is None:
with pytest.raises(AttributeError):
_ = app_db.engine
else:
app_db.engine = original # type: ignore[attr-defined]
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
def test_runner_returns_value_and_disposes_engine_around_call() -> None:
    """Happy path: the body's return value comes back to the caller and
    the shared engine is disposed once before and once after the body.
    """
    from app.tasks.celery_tasks import run_async_celery_task

    engine = _StaleLoopEngine()

    async def _body() -> str:
        # The pre-run dispose must already have happened by now.
        assert len(engine.dispose_loop_ids) == 1
        return "ok"

    with _patch_shared_engine(engine):
        outcome = run_async_celery_task(_body)

    assert outcome == "ok"
    loop_ids = engine.dispose_loop_ids
    # One dispose ahead of the body, one afterwards.
    assert len(loop_ids) == 2
    # Both disposes were recorded on the same event loop.
    assert loop_ids[0] == loop_ids[1]
def test_runner_creates_fresh_loop_per_invocation() -> None:
    """Each call must spin its own loop. Without this guarantee a
    previous task's loop would be reused and the asyncpg-stale-loop
    crash would never be avoided.

    Besides counting loop creations and closes, the test records their
    order so that "each loop is closed before the next one is created"
    is actually asserted rather than only described.
    """
    import app.tasks.celery_tasks as celery_tasks_pkg

    stub = _StaleLoopEngine()
    new_loop_calls = 0
    closed_loops: list[bool] = []
    # Ordered log of "new" / "close" events across all invocations.
    lifecycle: list[str] = []
    real_new_event_loop = asyncio.new_event_loop

    def _counting_new_loop() -> asyncio.AbstractEventLoop:
        nonlocal new_loop_calls
        new_loop_calls += 1
        lifecycle.append("new")
        loop = real_new_event_loop()
        # Hook close() so we can verify each loop was closed properly
        # before the next one was created.
        original_close = loop.close

        def _tracked_close() -> None:
            closed_loops.append(True)
            lifecycle.append("close")
            original_close()

        loop.close = _tracked_close  # type: ignore[method-assign]
        return loop

    async def _body() -> None:
        # Loop is alive and current at body execution time.
        running = asyncio.get_running_loop()
        assert not running.is_closed()

    with (
        _patch_shared_engine(stub),
        patch.object(asyncio, "new_event_loop", _counting_new_loop),
    ):
        for _ in range(3):
            celery_tasks_pkg.run_async_celery_task(_body)

    assert new_loop_calls == 3
    assert closed_loops == [True, True, True]
    # Strict alternation: no loop is created while another is still open.
    assert lifecycle == ["new", "close"] * 3
    # Each invocation disposed twice (before + after).
    assert len(stub.dispose_loop_ids) == 6
def test_runner_disposes_engine_even_when_body_raises() -> None:
    """The failure path must clean up as well: the body's exception
    propagates to the caller, yet both dispose calls are still made.
    """
    from app.tasks.celery_tasks import run_async_celery_task

    class _BoomError(RuntimeError):
        pass

    engine = _StaleLoopEngine()

    async def _body() -> None:
        raise _BoomError("kaboom")

    with _patch_shared_engine(engine):
        with pytest.raises(_BoomError):
            run_async_celery_task(_body)

    # Dispose before the body and dispose after it both happened.
    assert len(engine.dispose_loop_ids) == 2
def test_runner_swallows_dispose_errors() -> None:
    """An ``engine.dispose()`` that raises must not fail the task.

    The engine here raises on every dispose; the body is still expected
    to run and its result to be returned, with dispose attempted both
    before and after the body.
    """
    from app.tasks.celery_tasks import run_async_celery_task

    class _AngryEngine:
        """Engine stub whose dispose() counts the call and then raises."""

        def __init__(self) -> None:
            self.calls = 0

        async def dispose(self) -> None:
            self.calls += 1
            raise RuntimeError("dispose() blew up")

    async def _body() -> int:
        return 42

    angry = _AngryEngine()
    with _patch_shared_engine(angry):
        outcome = run_async_celery_task(_body)
        assert outcome == 42

    # Both the pre-run and the post-run dispose were attempted.
    assert angry.calls == 2
def test_runner_propagates_value_from_async_body() -> None:
    """Sanity check: a dict returned by the async body is handed back
    unchanged by the runner."""
    from app.tasks.celery_tasks import run_async_celery_task

    payload = {"status": "ready", "video_presentation_id": 19}

    async def _body() -> dict[str, object]:
        return {"status": "ready", "video_presentation_id": 19}

    with _patch_shared_engine(_StaleLoopEngine()):
        returned = run_async_celery_task(_body)

    assert returned == payload
def test_video_presentation_task_uses_runner_helper() -> None:
    """Source-level guard for ``video_presentation_tasks``.

    The module's source text must mention ``run_async_celery_task`` and
    must not contain a raw ``asyncio.new_event_loop`` call — inlining a
    manual loop again would bring the original hang back.
    """
    import inspect

    from app.tasks.celery_tasks import video_presentation_tasks

    module_source = inspect.getsource(video_presentation_tasks)

    uses_helper = "run_async_celery_task" in module_source
    assert uses_helper, (
        "video_presentation_tasks.py must use run_async_celery_task; "
        "manual asyncio.new_event_loop() in a celery task hangs on the "
        "shared SQLAlchemy pool when reused across tasks."
    )

    # A hand-rolled loop is exactly what the helper exists to centralise.
    has_raw_loop = "asyncio.new_event_loop" in module_source
    assert not has_raw_loop, (
        "video_presentation_tasks.py contains a raw asyncio.new_event_loop "
        "call — route every async task through run_async_celery_task to "
        "avoid the stale-pool hang."
    )
def test_podcast_task_uses_runner_helper() -> None:
    """Same source-level guard as for the video task, applied to
    ``podcast_tasks``: the helper name must appear in the module source
    and a raw ``asyncio.new_event_loop`` must not.
    """
    import inspect

    from app.tasks.celery_tasks import podcast_tasks

    module_source = inspect.getsource(podcast_tasks)

    assert "run_async_celery_task" in module_source
    assert "asyncio.new_event_loop" not in module_source
def test_runner_runs_shutdown_asyncgens_before_close() -> None:
    """Smoke test for async-generator cleanup.

    The body abandons an async generator half-way through — the case
    ``loop.shutdown_asyncgens()`` exists for. Nothing is asserted on the
    generator itself (its closing depends on GC ordering); the test
    passes when two consecutive runs finish without raising, i.e.
    without a leftover ``RuntimeError: Event loop is closed``.
    """
    from app.tasks.celery_tasks import run_async_celery_task

    engine = _StaleLoopEngine()

    async def _agen():
        try:
            yield 1
            yield 2
        finally:
            pass

    async def _body() -> None:
        # Consume only the first item and walk away, leaving the
        # generator suspended.
        async for item in _agen():
            if item == 1:
                break

    # The second pass is the real check: it only succeeds if the first
    # run left its loop cleaned up properly.
    for _ in range(2):
        with _patch_shared_engine(engine):
            run_async_celery_task(_body)

    # Force a collection so anything left dangling is finalised now
    # (e.g. 'coroutine was never awaited' warnings would surface here).
    gc.collect()
def test_runner_uses_proactor_loop_on_windows() -> None:
    """On Windows the celery worker preselects a Proactor policy so
    subprocess (ffmpeg) calls work. The helper must not silently fall
    back to a Selector loop and re-break video/podcast generation.

    The event-loop policy is process-wide state, so the previous policy
    is saved and restored in ``finally``; otherwise this test would
    change the loop type seen by every test that runs after it.
    """
    if not sys.platform.startswith("win"):
        pytest.skip("Windows-specific event-loop policy assertion")

    from app.tasks.celery_tasks import run_async_celery_task

    stub = _StaleLoopEngine()
    observed: list[str] = []

    async def _body() -> None:
        observed.append(type(asyncio.get_running_loop()).__name__)

    previous_policy = asyncio.get_event_loop_policy()
    # Mirror the policy set at the top of every Windows celery task.
    asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
    try:
        with _patch_shared_engine(stub):
            run_async_celery_task(_body)
    finally:
        # Always undo the global policy change, even if the run fails.
        asyncio.set_event_loop_policy(previous_policy)

    assert observed == ["ProactorEventLoop"]

View file

@ -113,6 +113,19 @@ async def _denying_billable_call(**kwargs):
yield SimpleNamespace() # pragma: no cover — for grammar only
@contextlib.asynccontextmanager
async def _settlement_failing_billable_call(**kwargs):
    """Fake ``billable_call`` whose settlement step fails.

    Logs the call kwargs, lets the managed block run, and then raises
    ``BillingSettlementError`` once control returns after the ``yield``.
    """
    from app.services.billable_calls import BillingSettlementError

    _CALL_LOG.append(kwargs)
    yield SimpleNamespace()

    # Reached only after the managed block finished without raising.
    settlement_error = BillingSettlementError(
        usage_type=kwargs.get("usage_type", "?"),
        user_id=kwargs["user_id"],
        cause=RuntimeError("finalize failed"),
    )
    raise settlement_error
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
@ -187,8 +200,15 @@ async def test_billable_call_invoked_with_correct_kwargs_for_free_config(monkeyp
call["quota_reserve_micros_override"]
== app_config.QUOTA_DEFAULT_PODCAST_RESERVE_MICROS
)
assert call["thread_id"] == 99
assert call["call_details"] == {"podcast_id": 7, "title": "Test Podcast"}
# Background artifact audit rows intentionally omit the TokenUsage.thread_id
# FK to avoid coupling Celery audit commits to an active chat transaction.
assert "thread_id" not in call
assert call["call_details"] == {
"podcast_id": 7,
"title": "Test Podcast",
"thread_id": 99,
}
assert callable(call["billable_session_factory"])
@pytest.mark.asyncio
@ -279,6 +299,49 @@ async def test_quota_insufficient_marks_podcast_failed_and_skips_graph(monkeypat
assert graph_invoked == [] # Graph never ran on denied reservation.
@pytest.mark.asyncio
async def test_billing_settlement_failure_marks_podcast_failed(monkeypatch):
    """When ``billable_call`` raises a settlement error on exit, the
    task returns a ``billing_settlement_failed`` result and the podcast
    row ends up in ``FAILED`` status."""
    from app.db import PodcastStatus
    from app.tasks.celery_tasks import podcast_tasks

    podcast = _make_podcast(podcast_id=10)
    fake_session = _FakeSession(podcast)

    async def _fake_resolver(sess, search_space_id, *, thread_id=None):
        return uuid4(), "premium", "gpt-5.4"

    async def _fake_graph_invoke(state, config):
        return {"podcast_transcript": [], "final_podcast_file_path": "x.wav"}

    # Swap out every collaborator the task body reaches for.
    monkeypatch.setattr(
        podcast_tasks,
        "get_celery_session_maker",
        lambda: _FakeSessionMaker(fake_session),
    )
    monkeypatch.setattr(
        podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver
    )
    monkeypatch.setattr(
        podcast_tasks, "billable_call", _settlement_failing_billable_call
    )
    monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke)

    outcome = await podcast_tasks._generate_content_podcast(
        podcast_id=10,
        source_content="hi",
        search_space_id=555,
        user_prompt=None,
    )

    expected = {
        "status": "failed",
        "podcast_id": 10,
        "reason": "billing_settlement_failed",
    }
    assert outcome == expected
    assert podcast.status == PodcastStatus.FAILED
@pytest.mark.asyncio
async def test_resolver_failure_marks_podcast_failed(monkeypatch):
"""If the resolver raises (e.g. search-space deleted), the task fails

View file

@ -0,0 +1,119 @@
"""Predicate-level test for the chat streaming safety net.
The safety net in ``stream_new_chat`` rejects an image turn early with
a friendly ``MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT`` SSE error when the
selected model is *known* to be text-only. The earlier round of this
work used a strict opt-in flag (``supports_image_input`` defaulting to
False on every YAML entry) which blocked vision-capable Azure GPT-5.x
deployments — this is the regression we're fixing.
The new predicate is :func:`is_known_text_only_chat_model`, which
returns True only when LiteLLM's authoritative model map *explicitly*
sets ``supports_vision=False``. Anything else (vision True, missing
key, exception) returns False so the request flows through to the
provider.
We exercise the predicate directly here rather than driving the full
``stream_new_chat`` generator — covering the gate in isolation keeps
the test focused on the regression while the generator's wider behavior
is exercised by the integration suite.
"""
from __future__ import annotations
import pytest
from app.services.provider_capabilities import is_known_text_only_chat_model
pytestmark = pytest.mark.unit
def test_safety_net_does_not_fire_for_azure_gpt_4o():
    """Regression guard: an Azure deployment whose ``base_model`` is
    ``gpt-4o`` is vision-capable per LiteLLM's model map, so the
    predicate must not classify it as text-only (the earlier
    blanket-False default did block it)."""
    verdict = is_known_text_only_chat_model(
        provider="AZURE_OPENAI",
        model_name="my-azure-deployment",
        base_model="gpt-4o",
    )

    assert verdict is False
def test_safety_net_does_not_fire_for_unknown_model():
    """Unknown models pass by default — the safety net blocks only on a
    definitive text-only answer, so a brand-new third-party model that
    LiteLLM has never heard of must flow through to the provider."""
    verdict = is_known_text_only_chat_model(
        provider="CUSTOM",
        custom_provider="brand_new_proxy",
        model_name="brand-new-model-x9",
    )

    assert verdict is False
def test_safety_net_does_not_fire_when_lookup_raises(monkeypatch):
    """A ``litellm.get_model_info`` that raises must not block the turn:
    the predicate is expected to treat the failure as 'unknown' and
    return ``False``."""
    import app.services.provider_capabilities as pc

    def _raise(**_kwargs):
        raise RuntimeError("intentional test failure")

    monkeypatch.setattr(pc.litellm, "get_model_info", _raise)

    verdict = is_known_text_only_chat_model(
        provider="OPENAI",
        model_name="gpt-4o",
    )

    assert verdict is False
def test_safety_net_fires_only_on_explicit_false(monkeypatch):
    """Stub ``litellm.get_model_info`` with three payloads and check the
    predicate's answer for each: an explicit ``supports_vision=False``
    yields ``True``; ``supports_vision=True`` and a payload without the
    key both yield ``False``."""
    import app.services.provider_capabilities as pc

    def _stub_returning(info):
        # Build a keyword-only stand-in for get_model_info that hands
        # back a fresh copy of ``info`` on every call.
        def _get_model_info(**_kwargs):
            return dict(info)

        return _get_model_info

    # (stubbed model info, model name passed in, expected verdict)
    scenarios = [
        ({"supports_vision": False, "max_input_tokens": 8192}, "text-only-stub", True),
        ({"supports_vision": True}, "vision-stub", False),
        ({"max_input_tokens": 8192}, "missing-key-stub", False),
    ]

    for info, model_name, expected in scenarios:
        monkeypatch.setattr(pc.litellm, "get_model_info", _stub_returning(info))
        verdict = is_known_text_only_chat_model(
            provider="OPENAI",
            model_name=model_name,
        )
        assert verdict is expected

View file

@ -105,6 +105,19 @@ async def _denying_billable_call(**kwargs):
yield SimpleNamespace() # pragma: no cover
@contextlib.asynccontextmanager
async def _settlement_failing_billable_call(**kwargs):
    """Fake ``billable_call`` whose settlement step fails.

    Logs the call kwargs, lets the managed block run, and then raises
    ``BillingSettlementError`` once control returns after the ``yield``.
    """
    from app.services.billable_calls import BillingSettlementError

    _CALL_LOG.append(kwargs)
    yield SimpleNamespace()

    # Reached only after the managed block finished without raising.
    settlement_error = BillingSettlementError(
        usage_type=kwargs.get("usage_type", "?"),
        user_id=kwargs["user_id"],
        cause=RuntimeError("finalize failed"),
    )
    raise settlement_error
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
@ -176,11 +189,15 @@ async def test_billable_call_invoked_with_correct_kwargs_for_free_config(monkeyp
call["quota_reserve_micros_override"]
== app_config.QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS
)
assert call["thread_id"] == 99
# Background artifact audit rows intentionally omit the TokenUsage.thread_id
# FK to avoid coupling Celery audit commits to an active chat transaction.
assert "thread_id" not in call
assert call["call_details"] == {
"video_presentation_id": 11,
"title": "Test Presentation",
"thread_id": 99,
}
assert callable(call["billable_session_factory"])
@pytest.mark.asyncio
@ -280,6 +297,57 @@ async def test_quota_insufficient_marks_video_failed_and_skips_graph(monkeypatch
assert graph_invoked == []
@pytest.mark.asyncio
async def test_billing_settlement_failure_marks_video_failed(monkeypatch):
    """When ``billable_call`` raises a settlement error on exit, the
    task returns a ``billing_settlement_failed`` result and the video
    presentation row ends up in ``FAILED`` status."""
    from app.db import VideoPresentationStatus
    from app.tasks.celery_tasks import video_presentation_tasks

    video = _make_video(video_id=14)
    fake_session = _FakeSession(video)

    async def _fake_resolver(sess, search_space_id, *, thread_id=None):
        return uuid4(), "premium", "gpt-5.4"

    async def _fake_graph_invoke(state, config):
        return {"slides": [], "slide_audio_results": [], "slide_scene_codes": []}

    # Swap out every collaborator the task body reaches for.
    monkeypatch.setattr(
        video_presentation_tasks,
        "get_celery_session_maker",
        lambda: _FakeSessionMaker(fake_session),
    )
    monkeypatch.setattr(
        video_presentation_tasks,
        "_resolve_agent_billing_for_search_space",
        _fake_resolver,
    )
    monkeypatch.setattr(
        video_presentation_tasks,
        "billable_call",
        _settlement_failing_billable_call,
    )
    monkeypatch.setattr(
        video_presentation_tasks.video_presentation_graph,
        "ainvoke",
        _fake_graph_invoke,
    )

    outcome = await video_presentation_tasks._generate_video_presentation(
        video_presentation_id=14,
        source_content="content",
        search_space_id=777,
        user_prompt=None,
    )

    expected = {
        "status": "failed",
        "video_presentation_id": 14,
        "reason": "billing_settlement_failed",
    }
    assert outcome == expected
    assert video.status == VideoPresentationStatus.FAILED
@pytest.mark.asyncio
async def test_resolver_failure_marks_video_failed(monkeypatch):
from app.db import VideoPresentationStatus

View file

@ -477,9 +477,7 @@ const MessageInfoDropdown: FC = () => {
</span>
<span className="text-xs text-muted-foreground">
{counts.total_tokens.toLocaleString()} tokens
{costMicros && costMicros > 0
? ` · ${formatTurnCost(costMicros)}`
: ""}
{costMicros && costMicros > 0 ? ` · ${formatTurnCost(costMicros)}` : ""}
</span>
</ActionBarMorePrimitive.Item>
);

View file

@ -19,6 +19,7 @@ import {
import type React from "react";
import { Fragment, useCallback, useEffect, useMemo, useRef, useState } from "react";
import { toast } from "sonner";
import { pendingUserImageDataUrlsAtom } from "@/atoms/chat/pending-user-images.atom";
import {
globalImageGenConfigsAtom,
imageGenConfigsAtom,
@ -461,6 +462,18 @@ export function ModelSelector({
const { data: visionUserConfigs, isLoading: visionUserLoading } =
useAtomValue(visionLLMConfigsAtom);
// Pending image attachments on the composer. Used to surface an
// amber "No image" hint on chat models the catalog reports as
// non-vision (`supports_image_input=false`) when the next message
// will carry an image. The hint is purely advisory: selection,
// focus, and click handling are unaffected. The backend's safety
// net (`is_known_text_only_chat_model`) is the actual block, and
// it only fires when LiteLLM *explicitly* marks a model as
// text-only — so a model that's secretly capable but hasn't been
// annotated will still flow through to the provider.
const pendingUserImageUrls = useAtomValue(pendingUserImageDataUrlsAtom);
const hasPendingImages = pendingUserImageUrls.length > 0;
const isLoading =
llmUserLoading ||
llmGlobalLoading ||
@ -984,6 +997,21 @@ export function ModelSelector({
const isSelected = getSelectedId() === config.id;
const isFocused = focusedIndex === index;
const hasCitations = "citations_enabled" in config && !!config.citations_enabled;
// Chat-tab only: surface an amber "No image" hint when the
// composer carries images and the catalog reports the model as
// non-vision. This is purely advisory — selection is *not*
// blocked. The backend's narrow safety net
// (`is_known_text_only_chat_model`) is the source of truth for
// rejecting image turns, and it only fires when LiteLLM
// explicitly marks the model as text-only. A model surfaced as
// `supports_image_input=false` here may still be capable in
// practice (unknown / unmapped LiteLLM entry), so we let the
// user pick it and the provider response decide.
const isImageIncompatibleChatModel =
activeTab === "llm" &&
hasPendingImages &&
"supports_image_input" in config &&
(config as Record<string, unknown>).supports_image_input === false;
return (
<div
@ -992,6 +1020,11 @@ export function ModelSelector({
role="option"
tabIndex={isMobile ? -1 : 0}
aria-selected={isSelected}
title={
isImageIncompatibleChatModel
? "This model is reported as text-only. You can still pick it; the provider may reject image turns."
: undefined
}
onClick={() => handleSelectItem(item)}
onKeyDown={
isMobile
@ -1005,9 +1038,8 @@ export function ModelSelector({
}
onMouseEnter={() => setFocusedIndex(index)}
className={cn(
"group flex items-center gap-2.5 px-3 py-2 rounded-xl cursor-pointer",
"transition-all duration-150 mx-2",
"hover:bg-accent/40",
"group flex items-center gap-2.5 px-3 py-2 rounded-xl",
"transition-all duration-150 mx-2 cursor-pointer hover:bg-accent/40",
isSelected && "bg-primary/6 dark:bg-primary/8",
isFocused && "bg-accent/50"
)}
@ -1053,6 +1085,14 @@ export function ModelSelector({
Free
</Badge>
) : null}
{isImageIncompatibleChatModel && (
<Badge
variant="secondary"
className="text-[9px] px-1 py-0 h-3.5 bg-amber-100 text-amber-700 dark:bg-amber-900/50 dark:text-amber-300 border-0"
>
No image
</Badge>
)}
</div>
<div className="flex items-center gap-1.5 mt-0.5">
<span className="text-xs text-muted-foreground truncate">

View file

@ -250,8 +250,8 @@ function PricingFAQ() {
Frequently Asked Questions
</h2>
<p className="mx-auto mt-4 max-w-2xl text-lg text-muted-foreground">
Everything you need to know about SurfSense pages, premium credit, and billing.
Can&apos;t find what you need? Reach out at{" "}
Everything you need to know about SurfSense pages, premium credit, and billing. Can&apos;t
find what you need? Reach out at{" "}
<a href="mailto:rohan@surfsense.com" className="text-blue-500 underline">
rohan@surfsense.com
</a>

View file

@ -22,6 +22,7 @@ import {
AlertDialogTitle,
} from "@/components/ui/alert-dialog";
import { Avatar, AvatarFallback, AvatarImage } from "@/components/ui/avatar";
import { Badge } from "@/components/ui/badge";
import { Button } from "@/components/ui/button";
import { Card, CardContent } from "@/components/ui/card";
import { Skeleton } from "@/components/ui/skeleton";
@ -190,8 +191,7 @@ export function ImageModelManager({ searchSpaceId }: ImageModelManagerProps) {
? "model"
: "models"}
</span>{" "}
available from your administrator.{" "}
{(() => {
available from your administrator. {(() => {
const nonAuto = globalConfigs.filter(
(g) => !("is_auto_mode" in g && g.is_auto_mode)
);
@ -214,6 +214,75 @@ export function ImageModelManager({ searchSpaceId }: ImageModelManagerProps) {
</Alert>
)}
{/* Global Image Models — read-only cards with per-model Free/Premium
badges. Mirrors the badge palette used by the chat role selector
(`llm-role-manager.tsx`) so the meaning is consistent across
every model-configuration surface (chat / image / vision). */}
{!isLoading &&
globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length > 0 && (
<div className="space-y-3">
<h3 className="text-xs md:text-sm font-semibold text-muted-foreground">
Global Image Models
</h3>
<div className="grid gap-3 grid-cols-1 sm:grid-cols-2 xl:grid-cols-3">
{globalConfigs
.filter((g) => !("is_auto_mode" in g && g.is_auto_mode))
.map((cfg) => {
const billingTier =
("billing_tier" in cfg &&
typeof (cfg as { billing_tier?: string }).billing_tier === "string" &&
(cfg as { billing_tier?: string }).billing_tier) ||
"free";
const isPremium = billingTier === "premium";
return (
<Card
key={cfg.id}
className="border-border/60 bg-muted/20 overflow-hidden h-full"
>
<CardContent className="p-4 flex flex-col gap-3 h-full">
<div className="flex items-center gap-2 min-w-0">
<div className="shrink-0">
{getProviderIcon(cfg.provider, { className: "size-4" })}
</div>
<div className="min-w-0 flex-1 flex items-center gap-1.5">
<h4 className="text-sm font-semibold tracking-tight truncate">
{cfg.name}
</h4>
{isPremium ? (
<Badge
variant="secondary"
className="text-[8px] md:text-[9px] shrink-0 bg-purple-100 text-purple-700 dark:bg-purple-900/50 dark:text-purple-300 border-0"
>
Premium
</Badge>
) : (
<Badge
variant="secondary"
className="text-[8px] md:text-[9px] shrink-0 bg-emerald-100 text-emerald-700 dark:bg-emerald-900/50 dark:text-emerald-300 border-0"
>
Free
</Badge>
)}
</div>
</div>
{cfg.description && (
<p className="text-[11px] text-muted-foreground/70 line-clamp-2">
{cfg.description}
</p>
)}
<div className="flex items-center pt-2 border-t border-border/40 mt-auto">
<span className="text-[11px] text-muted-foreground/60 truncate">
{cfg.model_name}
</span>
</div>
</CardContent>
</Card>
);
})}
</div>
</div>
)}
{/* Loading Skeleton */}
{isLoading && (
<div className="space-y-4 md:space-y-6">

View file

@ -70,9 +70,7 @@ export function MorePagesContent() {
<div className="w-full space-y-5">
<div className="text-center">
<h2 className="text-xl font-bold tracking-tight">Get Free Pages</h2>
<p className="mt-1 text-sm text-muted-foreground">
Earn bonus pages by completing tasks
</p>
<p className="mt-1 text-sm text-muted-foreground">Earn bonus pages by completing tasks</p>
</div>
<div className="space-y-2">

View file

@ -22,6 +22,7 @@ import {
AlertDialogTitle,
} from "@/components/ui/alert-dialog";
import { Avatar, AvatarFallback, AvatarImage } from "@/components/ui/avatar";
import { Badge } from "@/components/ui/badge";
import { Button } from "@/components/ui/button";
import { Card, CardContent } from "@/components/ui/card";
import { Skeleton } from "@/components/ui/skeleton";
@ -191,8 +192,7 @@ export function VisionModelManager({ searchSpaceId }: VisionModelManagerProps) {
? "model"
: "models"}
</span>{" "}
available from your administrator.{" "}
{(() => {
available from your administrator. {(() => {
const nonAuto = globalConfigs.filter(
(g) => !("is_auto_mode" in g && g.is_auto_mode)
);
@ -215,6 +215,75 @@ export function VisionModelManager({ searchSpaceId }: VisionModelManagerProps) {
</Alert>
)}
{/* Global Vision Models — read-only cards with per-model Free/Premium
badges. Mirrors the badge palette used by the chat role selector
(`llm-role-manager.tsx`) so the meaning is consistent across
every model-configuration surface (chat / image / vision). */}
{!isLoading &&
globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length > 0 && (
<div className="space-y-3">
<h3 className="text-xs md:text-sm font-semibold text-muted-foreground">
Global Vision Models
</h3>
<div className="grid gap-3 grid-cols-1 sm:grid-cols-2 xl:grid-cols-3">
{globalConfigs
.filter((g) => !("is_auto_mode" in g && g.is_auto_mode))
.map((cfg) => {
const billingTier =
("billing_tier" in cfg &&
typeof (cfg as { billing_tier?: string }).billing_tier === "string" &&
(cfg as { billing_tier?: string }).billing_tier) ||
"free";
const isPremium = billingTier === "premium";
return (
<Card
key={cfg.id}
className="border-border/60 bg-muted/20 overflow-hidden h-full"
>
<CardContent className="p-4 flex flex-col gap-3 h-full">
<div className="flex items-center gap-2 min-w-0">
<div className="shrink-0">
{getProviderIcon(cfg.provider, { className: "size-4" })}
</div>
<div className="min-w-0 flex-1 flex items-center gap-1.5">
<h4 className="text-sm font-semibold tracking-tight truncate">
{cfg.name}
</h4>
{isPremium ? (
<Badge
variant="secondary"
className="text-[8px] md:text-[9px] shrink-0 bg-purple-100 text-purple-700 dark:bg-purple-900/50 dark:text-purple-300 border-0"
>
Premium
</Badge>
) : (
<Badge
variant="secondary"
className="text-[8px] md:text-[9px] shrink-0 bg-emerald-100 text-emerald-700 dark:bg-emerald-900/50 dark:text-emerald-300 border-0"
>
Free
</Badge>
)}
</div>
</div>
{cfg.description && (
<p className="text-[11px] text-muted-foreground/70 line-clamp-2">
{cfg.description}
</p>
)}
<div className="flex items-center pt-2 border-t border-border/40 mt-auto">
<span className="text-[11px] text-muted-foreground/60 truncate">
{cfg.model_name}
</span>
</div>
</CardContent>
</Card>
);
})}
</div>
</div>
)}
{isLoading && (
<div className="space-y-4 md:space-y-6">
<div className="space-y-4">

View file

@ -416,9 +416,19 @@ export const GeneratePodcastToolUI = ({
return <PodcastErrorState title={title} error={result.error || "Generation failed"} />;
}
// Already generating - show simple warning, don't create another poller
// The FIRST tool call will display the podcast when ready
// (new: "generating", legacy: "already_generating")
// Pending/generating rows have a stable podcast_id, so the card can poll
// independently while the chat stream finishes.
if (
(result.status === "pending" ||
result.status === "generating" ||
result.status === "processing") &&
result.podcast_id
) {
return <PodcastStatusPoller podcastId={result.podcast_id} title={result.title || title} />;
}
// Legacy duplicate/no-ID result - show a simple warning, don't create
// another poller. The first tool call will display the podcast when ready.
if (result.status === "generating" || result.status === "already_generating") {
return (
<div className="my-4 max-w-lg overflow-hidden rounded-2xl border bg-muted/30 select-none">
@ -432,11 +442,6 @@ export const GeneratePodcastToolUI = ({
);
}
// Pending - poll for completion (new: "pending" with podcast_id)
if (result.status === "pending" && result.podcast_id) {
return <PodcastStatusPoller podcastId={result.podcast_id} title={result.title || title} />;
}
// Ready with podcast_id (new: "ready", legacy: "success")
if ((result.status === "ready" || result.status === "success") && result.podcast_id) {
return <PodcastPlayer podcastId={result.podcast_id} title={result.title || title} />;

View file

@ -44,8 +44,8 @@ export function LoginGateProvider({ children }: { children: ReactNode }) {
<DialogHeader>
<DialogTitle>Create a free account to {feature}</DialogTitle>
<DialogDescription>
Get $5 of premium credit, save chat history, upload documents, use all AI tools,
and connect 30+ integrations.
Get $5 of premium credit, save chat history, upload documents, use all AI tools, and
connect 30+ integrations.
</DialogDescription>
</DialogHeader>
<DialogFooter className="flex flex-col gap-2 sm:flex-row">

View file

@ -65,6 +65,13 @@ export const newLLMConfig = z.object({
created_at: z.string(),
search_space_id: z.number(),
user_id: z.string(),
// Capability flag — derived server-side at the route boundary from
// LiteLLM's authoritative model map. There is no DB column. Default
// `true` is the conservative-allow stance for unknown / unmapped
// BYOK rows; the streaming-task safety net is the only place a
// `false` actually blocks a request.
supports_image_input: z.boolean().default(true),
});
/**
@ -74,11 +81,16 @@ export const newLLMConfigPublic = newLLMConfig.omit({ api_key: true });
/**
* Create NewLLMConfig
*
* Request schema built from `newLLMConfig` with `id`, `created_at`,
* `user_id`, and `supports_image_input` removed.
*
* `supports_image_input` is omitted because it is derived server-side
* from LiteLLM's model map at read time; there is no DB column to
* persist a client-supplied value into.
*/
export const createNewLLMConfigRequest = newLLMConfig.omit({
id: true,
created_at: true,
user_id: true,
supports_image_input: true,
});
export const createNewLLMConfigResponse = newLLMConfig;
@ -114,6 +126,8 @@ export const updateNewLLMConfigRequest = z.object({
created_at: true,
search_space_id: true,
user_id: true,
// Derived server-side; not part of the writable surface.
supports_image_input: true,
})
.partial(),
});
@ -172,6 +186,16 @@ export const globalNewLLMConfig = z.object({
seo_title: z.string().nullable().optional(),
seo_description: z.string().nullable().optional(),
quota_reserve_tokens: z.number().nullable().optional(),
// Capability flag — true when the model can accept image inputs.
// Resolved server-side (OpenRouter dynamic configs use the OR
// `architecture.input_modalities` field; YAML / BYOK use LiteLLM's
// authoritative `supports_vision` map). The chat selector renders
// an amber "No image" hint when this is false and there are
// pending image attachments, but does not block selection — the
// backend safety net only rejects when LiteLLM *explicitly* marks
// the model as text-only, so unknown / new models still flow
// through. Default `true` matches that conservative-allow stance.
supports_image_input: z.boolean().default(true),
});
export const getGlobalNewLLMConfigsResponse = z.array(globalNewLLMConfig);
@ -259,6 +283,9 @@ export const globalImageGenConfig = z.object({
is_global: z.literal(true),
is_auto_mode: z.boolean().optional().default(false),
billing_tier: z.string().default("free"),
// Mirrors `globalNewLLMConfig.is_premium` so the new-chat selector's
// Free/Premium badge logic lights up automatically for image-gen too.
is_premium: z.boolean().default(false),
quota_reserve_micros: z.number().nullable().optional(),
});
@ -341,6 +368,9 @@ export const globalVisionLLMConfig = z.object({
is_global: z.literal(true),
is_auto_mode: z.boolean().optional().default(false),
billing_tier: z.string().default("free"),
// Mirrors `globalNewLLMConfig.is_premium` so the new-chat selector's
// Free/Premium badge logic lights up automatically for vision too.
is_premium: z.boolean().default(false),
quota_reserve_tokens: z.number().nullable().optional(),
input_cost_per_token: z.number().nullable().optional(),
output_cost_per_token: z.number().nullable().optional(),

View file

@ -18,6 +18,12 @@ const nextConfig: NextConfig = {
},
images: {
remotePatterns: [
{
protocol: "http",
hostname: "localhost",
port: "8000",
pathname: "/api/v1/image-generations/**",
},
{
protocol: "https",
hostname: "**",