mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-16 21:05:20 +02:00
feat: fixed vision/image provider specific errors and fixed podcast/video streaming
Some checks are pending
Build and Push Docker Images / tag_release (push) Waiting to run
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (backend, surfsense-backend) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (web, surfsense-web) (push) Blocked by required conditions
Some checks are pending
Build and Push Docker Images / tag_release (push) Waiting to run
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (backend, surfsense-backend) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (web, surfsense-web) (push) Blocked by required conditions
This commit is contained in:
parent
ae9d36d77f
commit
47b2994ec7
54 changed files with 4469 additions and 563 deletions
|
|
@ -90,41 +90,18 @@ class SanitizedChatLiteLLM(ChatLiteLLM):
|
|||
yield chunk
|
||||
|
||||
|
||||
# Provider mapping for LiteLLM model string construction
|
||||
PROVIDER_MAP = {
|
||||
"OPENAI": "openai",
|
||||
"ANTHROPIC": "anthropic",
|
||||
"GROQ": "groq",
|
||||
"COHERE": "cohere",
|
||||
"GOOGLE": "gemini",
|
||||
"OLLAMA": "ollama_chat",
|
||||
"MISTRAL": "mistral",
|
||||
"AZURE_OPENAI": "azure",
|
||||
"OPENROUTER": "openrouter",
|
||||
"XAI": "xai",
|
||||
"BEDROCK": "bedrock",
|
||||
"VERTEX_AI": "vertex_ai",
|
||||
"TOGETHER_AI": "together_ai",
|
||||
"FIREWORKS_AI": "fireworks_ai",
|
||||
"DEEPSEEK": "openai",
|
||||
"ALIBABA_QWEN": "openai",
|
||||
"MOONSHOT": "openai",
|
||||
"ZHIPU": "openai",
|
||||
"GITHUB_MODELS": "github",
|
||||
"REPLICATE": "replicate",
|
||||
"PERPLEXITY": "perplexity",
|
||||
"ANYSCALE": "anyscale",
|
||||
"DEEPINFRA": "deepinfra",
|
||||
"CEREBRAS": "cerebras",
|
||||
"SAMBANOVA": "sambanova",
|
||||
"AI21": "ai21",
|
||||
"CLOUDFLARE": "cloudflare",
|
||||
"DATABRICKS": "databricks",
|
||||
"COMETAPI": "cometapi",
|
||||
"HUGGINGFACE": "huggingface",
|
||||
"MINIMAX": "openai",
|
||||
"CUSTOM": "custom",
|
||||
}
|
||||
# Provider mapping for LiteLLM model string construction.
|
||||
#
|
||||
# Single source of truth lives in
|
||||
# :mod:`app.services.provider_capabilities` so the YAML loader (which
|
||||
# runs during ``app.config`` class-body init) can resolve provider
|
||||
# prefixes without dragging the agent / tools tree into module load
|
||||
# order. Re-exported here under the historical ``PROVIDER_MAP`` name
|
||||
# so existing callers (``llm_router_service``, ``image_gen_router_service``,
|
||||
# tests) keep working unchanged.
|
||||
from app.services.provider_capabilities import ( # noqa: E402
|
||||
_PROVIDER_PREFIX_MAP as PROVIDER_MAP,
|
||||
)
|
||||
|
||||
|
||||
def _attach_model_profile(llm: ChatLiteLLM, model_string: str) -> None:
|
||||
|
|
@ -178,6 +155,17 @@ class AgentConfig:
|
|||
anonymous_enabled: bool = False
|
||||
quota_reserve_tokens: int | None = None
|
||||
|
||||
# Capability flag: best-effort True for the chat selector / catalog.
|
||||
# Resolved via :func:`provider_capabilities.derive_supports_image_input`
|
||||
# which prefers OpenRouter's ``architecture.input_modalities`` and
|
||||
# otherwise consults LiteLLM's authoritative model map. Default True
|
||||
# is the conservative-allow stance — the streaming-task safety net
|
||||
# (``is_known_text_only_chat_model``) is the *only* place a False
|
||||
# actually blocks a request. Setting this to False here without an
|
||||
# authoritative source would silently hide vision-capable models
|
||||
# (the regression we're fixing).
|
||||
supports_image_input: bool = True
|
||||
|
||||
@classmethod
|
||||
def from_auto_mode(cls) -> "AgentConfig":
|
||||
"""
|
||||
|
|
@ -203,6 +191,12 @@ class AgentConfig:
|
|||
is_premium=False,
|
||||
anonymous_enabled=False,
|
||||
quota_reserve_tokens=None,
|
||||
# Auto routes across the configured pool, which usually
|
||||
# contains at least one vision-capable deployment; the router
|
||||
# will surface a 404 from a non-vision deployment as a normal
|
||||
# ``allowed_fails`` event and fail over rather than blocking
|
||||
# the request outright.
|
||||
supports_image_input=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
|
@ -216,10 +210,24 @@ class AgentConfig:
|
|||
Returns:
|
||||
AgentConfig instance
|
||||
"""
|
||||
return cls(
|
||||
provider=config.provider.value
|
||||
# Lazy import to avoid pulling provider_capabilities (and its
|
||||
# transitive litellm import) into module-init order.
|
||||
from app.services.provider_capabilities import derive_supports_image_input
|
||||
|
||||
provider_value = (
|
||||
config.provider.value
|
||||
if hasattr(config.provider, "value")
|
||||
else str(config.provider),
|
||||
else str(config.provider)
|
||||
)
|
||||
litellm_params = config.litellm_params or {}
|
||||
base_model = (
|
||||
litellm_params.get("base_model")
|
||||
if isinstance(litellm_params, dict)
|
||||
else None
|
||||
)
|
||||
|
||||
return cls(
|
||||
provider=provider_value,
|
||||
model_name=config.model_name,
|
||||
api_key=config.api_key,
|
||||
api_base=config.api_base,
|
||||
|
|
@ -235,6 +243,16 @@ class AgentConfig:
|
|||
is_premium=False,
|
||||
anonymous_enabled=False,
|
||||
quota_reserve_tokens=None,
|
||||
# BYOK rows have no operator-curated capability flag, so we
|
||||
# ask LiteLLM (default-allow on unknown). The streaming
|
||||
# safety net still blocks if the model is *explicitly*
|
||||
# marked text-only.
|
||||
supports_image_input=derive_supports_image_input(
|
||||
provider=provider_value,
|
||||
model_name=config.model_name,
|
||||
base_model=base_model,
|
||||
custom_provider=config.custom_provider,
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
|
@ -253,15 +271,46 @@ class AgentConfig:
|
|||
Returns:
|
||||
AgentConfig instance
|
||||
"""
|
||||
# Lazy import to avoid pulling provider_capabilities (and its
|
||||
# transitive litellm import) into module-init order.
|
||||
from app.services.provider_capabilities import derive_supports_image_input
|
||||
|
||||
# Get system instructions from YAML, default to empty string
|
||||
system_instructions = yaml_config.get("system_instructions", "")
|
||||
|
||||
provider = yaml_config.get("provider", "").upper()
|
||||
model_name = yaml_config.get("model_name", "")
|
||||
custom_provider = yaml_config.get("custom_provider")
|
||||
litellm_params = yaml_config.get("litellm_params") or {}
|
||||
base_model = (
|
||||
litellm_params.get("base_model")
|
||||
if isinstance(litellm_params, dict)
|
||||
else None
|
||||
)
|
||||
|
||||
# Explicit YAML override wins; otherwise derive from LiteLLM /
|
||||
# OpenRouter modalities. The YAML loader already populates this
|
||||
# field, but this method is also called from
|
||||
# ``load_global_llm_config_by_id``'s file fallback (hot reload),
|
||||
# so we re-derive here for safety. The bool() coercion preserves
|
||||
# the loader's behaviour for explicit ``true`` / ``false``
|
||||
# strings that PyYAML may surface.
|
||||
if "supports_image_input" in yaml_config:
|
||||
supports_image_input = bool(yaml_config.get("supports_image_input"))
|
||||
else:
|
||||
supports_image_input = derive_supports_image_input(
|
||||
provider=provider,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
custom_provider=custom_provider,
|
||||
)
|
||||
|
||||
return cls(
|
||||
provider=yaml_config.get("provider", "").upper(),
|
||||
model_name=yaml_config.get("model_name", ""),
|
||||
provider=provider,
|
||||
model_name=model_name,
|
||||
api_key=yaml_config.get("api_key", ""),
|
||||
api_base=yaml_config.get("api_base"),
|
||||
custom_provider=yaml_config.get("custom_provider"),
|
||||
custom_provider=custom_provider,
|
||||
litellm_params=yaml_config.get("litellm_params"),
|
||||
# Prompt configuration from YAML (with defaults for backwards compatibility)
|
||||
system_instructions=system_instructions if system_instructions else None,
|
||||
|
|
@ -276,6 +325,7 @@ class AgentConfig:
|
|||
is_premium=yaml_config.get("billing_tier", "free") == "premium",
|
||||
anonymous_enabled=yaml_config.get("anonymous_enabled", False),
|
||||
quota_reserve_tokens=yaml_config.get("quota_reserve_tokens"),
|
||||
supports_image_input=supports_image_input,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ from app.services.image_gen_router_service import (
|
|||
ImageGenRouterService,
|
||||
is_image_gen_auto_mode,
|
||||
)
|
||||
from app.services.provider_api_base import resolve_api_base
|
||||
from app.utils.signed_image_urls import generate_image_token
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -49,12 +50,16 @@ _PROVIDER_MAP = {
|
|||
}
|
||||
|
||||
|
||||
def _resolve_provider_prefix(provider: str, custom_provider: str | None) -> str:
|
||||
if custom_provider:
|
||||
return custom_provider
|
||||
return _PROVIDER_MAP.get(provider.upper(), provider.lower())
|
||||
|
||||
|
||||
def _build_model_string(
|
||||
provider: str, model_name: str, custom_provider: str | None
|
||||
) -> str:
|
||||
if custom_provider:
|
||||
return f"{custom_provider}/{model_name}"
|
||||
prefix = _PROVIDER_MAP.get(provider.upper(), provider.lower())
|
||||
prefix = _resolve_provider_prefix(provider, custom_provider)
|
||||
return f"{prefix}/{model_name}"
|
||||
|
||||
|
||||
|
|
@ -146,14 +151,18 @@ def create_generate_image_tool(
|
|||
"error": f"Image generation config {config_id} not found"
|
||||
}
|
||||
|
||||
model_string = _build_model_string(
|
||||
cfg.get("provider", ""),
|
||||
cfg["model_name"],
|
||||
cfg.get("custom_provider"),
|
||||
provider_prefix = _resolve_provider_prefix(
|
||||
cfg.get("provider", ""), cfg.get("custom_provider")
|
||||
)
|
||||
model_string = f"{provider_prefix}/{cfg['model_name']}"
|
||||
gen_kwargs["api_key"] = cfg.get("api_key")
|
||||
if cfg.get("api_base"):
|
||||
gen_kwargs["api_base"] = cfg["api_base"]
|
||||
api_base = resolve_api_base(
|
||||
provider=cfg.get("provider"),
|
||||
provider_prefix=provider_prefix,
|
||||
config_api_base=cfg.get("api_base"),
|
||||
)
|
||||
if api_base:
|
||||
gen_kwargs["api_base"] = api_base
|
||||
if cfg.get("api_version"):
|
||||
gen_kwargs["api_version"] = cfg["api_version"]
|
||||
if cfg.get("litellm_params"):
|
||||
|
|
@ -175,14 +184,18 @@ def create_generate_image_tool(
|
|||
"error": f"Image generation config {config_id} not found"
|
||||
}
|
||||
|
||||
model_string = _build_model_string(
|
||||
db_cfg.provider.value,
|
||||
db_cfg.model_name,
|
||||
db_cfg.custom_provider,
|
||||
provider_prefix = _resolve_provider_prefix(
|
||||
db_cfg.provider.value, db_cfg.custom_provider
|
||||
)
|
||||
model_string = f"{provider_prefix}/{db_cfg.model_name}"
|
||||
gen_kwargs["api_key"] = db_cfg.api_key
|
||||
if db_cfg.api_base:
|
||||
gen_kwargs["api_base"] = db_cfg.api_base
|
||||
api_base = resolve_api_base(
|
||||
provider=db_cfg.provider.value,
|
||||
provider_prefix=provider_prefix,
|
||||
config_api_base=db_cfg.api_base,
|
||||
)
|
||||
if api_base:
|
||||
gen_kwargs["api_base"] = api_base
|
||||
if db_cfg.api_version:
|
||||
gen_kwargs["api_version"] = db_cfg.api_version
|
||||
if db_cfg.litellm_params:
|
||||
|
|
|
|||
|
|
@ -47,11 +47,37 @@ def load_global_llm_configs():
|
|||
data = yaml.safe_load(f)
|
||||
configs = data.get("global_llm_configs", [])
|
||||
|
||||
# Lazy import keeps the `app.config` -> `app.services` edge one-way
|
||||
# and matches the `provider_api_base` pattern used elsewhere.
|
||||
from app.services.provider_capabilities import derive_supports_image_input
|
||||
|
||||
seen_slugs: dict[str, int] = {}
|
||||
for cfg in configs:
|
||||
cfg.setdefault("billing_tier", "free")
|
||||
cfg.setdefault("anonymous_enabled", False)
|
||||
cfg.setdefault("seo_enabled", False)
|
||||
# Capability flag: explicit YAML override always wins. When the
|
||||
# operator has not annotated the model, defer to LiteLLM's
|
||||
# authoritative model map (`supports_vision`) which already
|
||||
# knows GPT-5.x / GPT-4o / Claude 3.x / Gemini 2.x are
|
||||
# vision-capable. Unknown / unmapped models default-allow so
|
||||
# we don't lock the user out of a freshly added third-party
|
||||
# entry; the streaming-task safety net (driven by
|
||||
# `is_known_text_only_chat_model`) is the only place a False
|
||||
# actually blocks a request.
|
||||
if "supports_image_input" not in cfg:
|
||||
litellm_params = cfg.get("litellm_params") or {}
|
||||
base_model = (
|
||||
litellm_params.get("base_model")
|
||||
if isinstance(litellm_params, dict)
|
||||
else None
|
||||
)
|
||||
cfg["supports_image_input"] = derive_supports_image_input(
|
||||
provider=cfg.get("provider"),
|
||||
model_name=cfg.get("model_name"),
|
||||
base_model=base_model,
|
||||
custom_provider=cfg.get("custom_provider"),
|
||||
)
|
||||
|
||||
if cfg.get("seo_enabled") and cfg.get("seo_slug"):
|
||||
slug = cfg["seo_slug"]
|
||||
|
|
|
|||
|
|
@ -46,6 +46,7 @@ from app.services.image_gen_router_service import (
|
|||
ImageGenRouterService,
|
||||
is_image_gen_auto_mode,
|
||||
)
|
||||
from app.services.provider_api_base import resolve_api_base
|
||||
from app.users import current_active_user
|
||||
from app.utils.rbac import check_permission
|
||||
from app.utils.signed_image_urls import verify_image_token
|
||||
|
|
@ -87,14 +88,18 @@ def _get_global_image_gen_config(config_id: int) -> dict | None:
|
|||
return None
|
||||
|
||||
|
||||
def _resolve_provider_prefix(provider: str, custom_provider: str | None) -> str:
|
||||
"""Resolve the LiteLLM provider prefix used in model strings."""
|
||||
if custom_provider:
|
||||
return custom_provider
|
||||
return _PROVIDER_MAP.get(provider.upper(), provider.lower())
|
||||
|
||||
|
||||
def _build_model_string(
|
||||
provider: str, model_name: str, custom_provider: str | None
|
||||
) -> str:
|
||||
"""Build a litellm model string from provider + model_name."""
|
||||
if custom_provider:
|
||||
return f"{custom_provider}/{model_name}"
|
||||
prefix = _PROVIDER_MAP.get(provider.upper(), provider.lower())
|
||||
return f"{prefix}/{model_name}"
|
||||
return f"{_resolve_provider_prefix(provider, custom_provider)}/{model_name}"
|
||||
|
||||
|
||||
async def _resolve_billing_for_image_gen(
|
||||
|
|
@ -187,12 +192,18 @@ async def _execute_image_generation(
|
|||
if not cfg:
|
||||
raise ValueError(f"Global image generation config {config_id} not found")
|
||||
|
||||
model_string = _build_model_string(
|
||||
cfg.get("provider", ""), cfg["model_name"], cfg.get("custom_provider")
|
||||
provider_prefix = _resolve_provider_prefix(
|
||||
cfg.get("provider", ""), cfg.get("custom_provider")
|
||||
)
|
||||
model_string = f"{provider_prefix}/{cfg['model_name']}"
|
||||
gen_kwargs["api_key"] = cfg.get("api_key")
|
||||
if cfg.get("api_base"):
|
||||
gen_kwargs["api_base"] = cfg["api_base"]
|
||||
api_base = resolve_api_base(
|
||||
provider=cfg.get("provider"),
|
||||
provider_prefix=provider_prefix,
|
||||
config_api_base=cfg.get("api_base"),
|
||||
)
|
||||
if api_base:
|
||||
gen_kwargs["api_base"] = api_base
|
||||
if cfg.get("api_version"):
|
||||
gen_kwargs["api_version"] = cfg["api_version"]
|
||||
if cfg.get("litellm_params"):
|
||||
|
|
@ -214,12 +225,18 @@ async def _execute_image_generation(
|
|||
if not db_cfg:
|
||||
raise ValueError(f"Image generation config {config_id} not found")
|
||||
|
||||
model_string = _build_model_string(
|
||||
db_cfg.provider.value, db_cfg.model_name, db_cfg.custom_provider
|
||||
provider_prefix = _resolve_provider_prefix(
|
||||
db_cfg.provider.value, db_cfg.custom_provider
|
||||
)
|
||||
model_string = f"{provider_prefix}/{db_cfg.model_name}"
|
||||
gen_kwargs["api_key"] = db_cfg.api_key
|
||||
if db_cfg.api_base:
|
||||
gen_kwargs["api_base"] = db_cfg.api_base
|
||||
api_base = resolve_api_base(
|
||||
provider=db_cfg.provider.value,
|
||||
provider_prefix=provider_prefix,
|
||||
config_api_base=db_cfg.api_base,
|
||||
)
|
||||
if api_base:
|
||||
gen_kwargs["api_base"] = api_base
|
||||
if db_cfg.api_version:
|
||||
gen_kwargs["api_version"] = db_cfg.api_version
|
||||
if db_cfg.litellm_params:
|
||||
|
|
@ -277,10 +294,12 @@ async def get_global_image_gen_configs(
|
|||
# Auto mode currently treated as free until per-deployment
|
||||
# billing-tier surfacing lands (see _resolve_billing_for_image_gen).
|
||||
"billing_tier": "free",
|
||||
"is_premium": False,
|
||||
}
|
||||
)
|
||||
|
||||
for cfg in global_configs:
|
||||
billing_tier = str(cfg.get("billing_tier", "free")).lower()
|
||||
safe_configs.append(
|
||||
{
|
||||
"id": cfg.get("id"),
|
||||
|
|
@ -293,7 +312,11 @@ async def get_global_image_gen_configs(
|
|||
"api_version": cfg.get("api_version") or None,
|
||||
"litellm_params": cfg.get("litellm_params", {}),
|
||||
"is_global": True,
|
||||
"billing_tier": cfg.get("billing_tier", "free"),
|
||||
"billing_tier": billing_tier,
|
||||
# Mirror chat (``new_llm_config_routes``) so the new-chat
|
||||
# selector's premium badge logic keys off the same
|
||||
# field across chat / image / vision tabs.
|
||||
"is_premium": billing_tier == "premium",
|
||||
"quota_reserve_micros": cfg.get("quota_reserve_micros"),
|
||||
}
|
||||
)
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ from app.schemas import (
|
|||
NewLLMConfigUpdate,
|
||||
)
|
||||
from app.services.llm_service import validate_llm_config
|
||||
from app.services.provider_capabilities import derive_supports_image_input
|
||||
from app.users import current_active_user
|
||||
from app.utils.rbac import check_permission
|
||||
|
||||
|
|
@ -36,6 +37,39 @@ router = APIRouter()
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _serialize_byok_config(config: NewLLMConfig) -> NewLLMConfigRead:
|
||||
"""Augment a BYOK chat config row with the derived ``supports_image_input``.
|
||||
|
||||
There is no DB column for ``supports_image_input`` — the value is
|
||||
resolved at the API boundary from LiteLLM's authoritative model map
|
||||
(default-allow on unknown). Returning ``NewLLMConfigRead`` here keeps
|
||||
the response shape consistent across list / detail / create / update
|
||||
endpoints without having to remember to set the field at every call
|
||||
site.
|
||||
"""
|
||||
provider_value = (
|
||||
config.provider.value
|
||||
if hasattr(config.provider, "value")
|
||||
else str(config.provider)
|
||||
)
|
||||
litellm_params = config.litellm_params or {}
|
||||
base_model = (
|
||||
litellm_params.get("base_model") if isinstance(litellm_params, dict) else None
|
||||
)
|
||||
supports_image_input = derive_supports_image_input(
|
||||
provider=provider_value,
|
||||
model_name=config.model_name,
|
||||
base_model=base_model,
|
||||
custom_provider=config.custom_provider,
|
||||
)
|
||||
# ``model_validate`` runs the Pydantic conversion using the ORM
|
||||
# attribute access path enabled by ``ConfigDict(from_attributes=True)``,
|
||||
# then we layer the derived field on. ``model_copy(update=...)`` keeps
|
||||
# the surface immutable from the caller's perspective.
|
||||
base_read = NewLLMConfigRead.model_validate(config)
|
||||
return base_read.model_copy(update={"supports_image_input": supports_image_input})
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Global Configs Routes
|
||||
# =============================================================================
|
||||
|
|
@ -84,11 +118,41 @@ async def get_global_new_llm_configs(
|
|||
"seo_title": None,
|
||||
"seo_description": None,
|
||||
"quota_reserve_tokens": None,
|
||||
# Auto routes across the configured pool, which usually
|
||||
# includes at least one vision-capable deployment, so
|
||||
# treat Auto as image-capable. The router itself will
|
||||
# still pick a vision-capable deployment for messages
|
||||
# carrying image_url blocks (LiteLLM Router falls back
|
||||
# on ``404`` per its ``allowed_fails`` policy).
|
||||
"supports_image_input": True,
|
||||
}
|
||||
)
|
||||
|
||||
# Add individual global configs
|
||||
for cfg in global_configs:
|
||||
# Capability resolution: explicit value (YAML override or OR
|
||||
# `_supports_image_input(model)` payload baked in by the
|
||||
# OpenRouter integration service) wins. Fall back to the
|
||||
# LiteLLM-driven helper which default-allows on unknown so
|
||||
# we don't hide vision-capable models that happen to lack a
|
||||
# YAML annotation. The streaming task safety net is the
|
||||
# only place a False ever blocks.
|
||||
if "supports_image_input" in cfg:
|
||||
supports_image_input = bool(cfg.get("supports_image_input"))
|
||||
else:
|
||||
cfg_litellm_params = cfg.get("litellm_params") or {}
|
||||
cfg_base_model = (
|
||||
cfg_litellm_params.get("base_model")
|
||||
if isinstance(cfg_litellm_params, dict)
|
||||
else None
|
||||
)
|
||||
supports_image_input = derive_supports_image_input(
|
||||
provider=cfg.get("provider"),
|
||||
model_name=cfg.get("model_name"),
|
||||
base_model=cfg_base_model,
|
||||
custom_provider=cfg.get("custom_provider"),
|
||||
)
|
||||
|
||||
safe_config = {
|
||||
"id": cfg.get("id"),
|
||||
"name": cfg.get("name"),
|
||||
|
|
@ -113,6 +177,7 @@ async def get_global_new_llm_configs(
|
|||
"seo_title": cfg.get("seo_title"),
|
||||
"seo_description": cfg.get("seo_description"),
|
||||
"quota_reserve_tokens": cfg.get("quota_reserve_tokens"),
|
||||
"supports_image_input": supports_image_input,
|
||||
}
|
||||
safe_configs.append(safe_config)
|
||||
|
||||
|
|
@ -171,7 +236,7 @@ async def create_new_llm_config(
|
|||
await session.commit()
|
||||
await session.refresh(db_config)
|
||||
|
||||
return db_config
|
||||
return _serialize_byok_config(db_config)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
|
|
@ -213,7 +278,7 @@ async def list_new_llm_configs(
|
|||
.limit(limit)
|
||||
)
|
||||
|
||||
return result.scalars().all()
|
||||
return [_serialize_byok_config(cfg) for cfg in result.scalars().all()]
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
|
|
@ -268,7 +333,7 @@ async def get_new_llm_config(
|
|||
"You don't have permission to view LLM configurations in this search space",
|
||||
)
|
||||
|
||||
return config
|
||||
return _serialize_byok_config(config)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
|
|
@ -360,7 +425,7 @@ async def update_new_llm_config(
|
|||
await session.commit()
|
||||
await session.refresh(config)
|
||||
|
||||
return config
|
||||
return _serialize_byok_config(config)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
|
|
|
|||
|
|
@ -85,10 +85,12 @@ async def get_global_vision_llm_configs(
|
|||
# Auto mode treated as free until per-deployment billing-tier
|
||||
# surfacing lands; see ``get_vision_llm`` for parity.
|
||||
"billing_tier": "free",
|
||||
"is_premium": False,
|
||||
}
|
||||
)
|
||||
|
||||
for cfg in global_configs:
|
||||
billing_tier = str(cfg.get("billing_tier", "free")).lower()
|
||||
safe_configs.append(
|
||||
{
|
||||
"id": cfg.get("id"),
|
||||
|
|
@ -101,7 +103,11 @@ async def get_global_vision_llm_configs(
|
|||
"api_version": cfg.get("api_version") or None,
|
||||
"litellm_params": cfg.get("litellm_params", {}),
|
||||
"is_global": True,
|
||||
"billing_tier": cfg.get("billing_tier", "free"),
|
||||
"billing_tier": billing_tier,
|
||||
# Mirror chat (``new_llm_config_routes``) so the new-chat
|
||||
# selector's premium badge logic keys off the same
|
||||
# field across chat / image / vision tabs.
|
||||
"is_premium": billing_tier == "premium",
|
||||
"quota_reserve_tokens": cfg.get("quota_reserve_tokens"),
|
||||
"input_cost_per_token": cfg.get("input_cost_per_token"),
|
||||
"output_cost_per_token": cfg.get("output_cost_per_token"),
|
||||
|
|
|
|||
|
|
@ -241,6 +241,15 @@ class GlobalImageGenConfigRead(BaseModel):
|
|||
default="free",
|
||||
description="'free' or 'premium'. Premium debits the user's premium credit pool (USD-cost-based).",
|
||||
)
|
||||
is_premium: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"Convenience boolean derived server-side from "
|
||||
"``billing_tier == 'premium'``. The new-chat model selector "
|
||||
"keys its Free/Premium badge off this field for parity with "
|
||||
"chat (`GlobalLLMConfigRead.is_premium`)."
|
||||
),
|
||||
)
|
||||
quota_reserve_micros: int | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
|
|
|
|||
|
|
@ -92,6 +92,20 @@ class NewLLMConfigRead(NewLLMConfigBase):
|
|||
created_at: datetime
|
||||
search_space_id: int
|
||||
user_id: uuid.UUID
|
||||
# Capability flag derived at the API boundary (no DB column). Default
|
||||
# True matches the conservative-allow stance — a BYOK row that the
|
||||
# route forgot to augment is not pre-judged. The streaming-task
|
||||
# safety net is the only place a False actually blocks a request.
|
||||
supports_image_input: bool = Field(
|
||||
default=True,
|
||||
description=(
|
||||
"Whether the BYOK chat config can accept image inputs. Derived "
|
||||
"at the route boundary from LiteLLM's authoritative model map "
|
||||
"(``litellm.supports_vision``) — there is no DB column. "
|
||||
"Default True is the conservative-allow stance for unknown / "
|
||||
"unmapped models."
|
||||
),
|
||||
)
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
|
@ -121,6 +135,15 @@ class NewLLMConfigPublic(BaseModel):
|
|||
created_at: datetime
|
||||
search_space_id: int
|
||||
user_id: uuid.UUID
|
||||
# Capability flag derived at the API boundary (see NewLLMConfigRead).
|
||||
supports_image_input: bool = Field(
|
||||
default=True,
|
||||
description=(
|
||||
"Whether the BYOK chat config can accept image inputs. Derived "
|
||||
"at the route boundary from LiteLLM's authoritative model map. "
|
||||
"Default True is the conservative-allow stance."
|
||||
),
|
||||
)
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
|
@ -172,6 +195,19 @@ class GlobalNewLLMConfigRead(BaseModel):
|
|||
seo_title: str | None = None
|
||||
seo_description: str | None = None
|
||||
quota_reserve_tokens: int | None = None
|
||||
supports_image_input: bool = Field(
|
||||
default=True,
|
||||
description=(
|
||||
"Whether the model accepts image inputs (multimodal vision). "
|
||||
"Derived server-side: OpenRouter dynamic configs use "
|
||||
"``architecture.input_modalities``; YAML / BYOK use LiteLLM's "
|
||||
"authoritative model map (``litellm.supports_vision``). The "
|
||||
"new-chat selector hints with a 'No image' badge when this is "
|
||||
"False and there are pending image attachments. The streaming "
|
||||
"task fails fast only when LiteLLM *explicitly* marks a model "
|
||||
"as text-only — unknown / unmapped models default-allow."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
|
|
|
|||
|
|
@ -86,6 +86,15 @@ class GlobalVisionLLMConfigRead(BaseModel):
|
|||
default="free",
|
||||
description="'free' or 'premium'. Premium debits the user's premium credit pool (USD-cost-based).",
|
||||
)
|
||||
is_premium: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"Convenience boolean derived server-side from "
|
||||
"``billing_tier == 'premium'``. The new-chat model selector "
|
||||
"keys its Free/Premium badge off this field for parity with "
|
||||
"chat (`GlobalLLMConfigRead.is_premium`)."
|
||||
),
|
||||
)
|
||||
quota_reserve_tokens: int | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
|
|
|
|||
|
|
@ -163,13 +163,47 @@ def clear_healthy(config_id: int | None = None) -> None:
|
|||
_healthy_until.pop(int(config_id), None)
|
||||
|
||||
|
||||
def _global_candidates() -> list[dict]:
|
||||
def _cfg_supports_image_input(cfg: dict) -> bool:
|
||||
"""True if the global cfg can accept image inputs.
|
||||
|
||||
Prefers the explicit ``supports_image_input`` flag (set by the YAML
|
||||
loader / OpenRouter integration). Falls back to a LiteLLM lookup so
|
||||
a YAML entry whose flag was somehow stripped doesn't get wrongly
|
||||
excluded. Default-allows on unknown — the streaming-task safety net
|
||||
is the actual block, not this filter.
|
||||
"""
|
||||
if "supports_image_input" in cfg:
|
||||
return bool(cfg.get("supports_image_input"))
|
||||
# Lazy import: provider_capabilities -> llm_config -> services chain;
|
||||
# importing at module load would create an init-order cycle through
|
||||
# ``app.config``.
|
||||
from app.services.provider_capabilities import derive_supports_image_input
|
||||
|
||||
cfg_litellm_params = cfg.get("litellm_params") or {}
|
||||
base_model = (
|
||||
cfg_litellm_params.get("base_model")
|
||||
if isinstance(cfg_litellm_params, dict)
|
||||
else None
|
||||
)
|
||||
return derive_supports_image_input(
|
||||
provider=cfg.get("provider"),
|
||||
model_name=cfg.get("model_name"),
|
||||
base_model=base_model,
|
||||
custom_provider=cfg.get("custom_provider"),
|
||||
)
|
||||
|
||||
|
||||
def _global_candidates(*, requires_image_input: bool = False) -> list[dict]:
|
||||
"""Return Auto-eligible global cfgs.
|
||||
|
||||
Drops cfgs flagged ``health_gated`` (best non-null OpenRouter uptime
|
||||
below ``_HEALTH_GATE_UPTIME_PCT``) so chronically broken providers
|
||||
can't be picked as the thread's pin. Also excludes configs currently
|
||||
in runtime cooldown (e.g. temporary 429 bursts).
|
||||
|
||||
When ``requires_image_input`` is True (image turn), additionally
|
||||
filters out configs whose ``supports_image_input`` resolves to False
|
||||
so a text-only deployment can't be pinned for an image request.
|
||||
"""
|
||||
candidates = [
|
||||
cfg
|
||||
|
|
@ -177,6 +211,7 @@ def _global_candidates() -> list[dict]:
|
|||
if _is_usable_global_config(cfg)
|
||||
and not cfg.get("health_gated")
|
||||
and not _is_runtime_cooled_down(int(cfg.get("id", 0)))
|
||||
and (not requires_image_input or _cfg_supports_image_input(cfg))
|
||||
]
|
||||
return sorted(candidates, key=lambda c: int(c.get("id", 0)))
|
||||
|
||||
|
|
@ -237,11 +272,20 @@ async def resolve_or_get_pinned_llm_config_id(
|
|||
selected_llm_config_id: int,
|
||||
force_repin_free: bool = False,
|
||||
exclude_config_ids: set[int] | None = None,
|
||||
requires_image_input: bool = False,
|
||||
) -> AutoPinResolution:
|
||||
"""Resolve Auto (Fastest) to one concrete config id and persist the pin.
|
||||
|
||||
For non-auto selections, this function clears any existing pin and returns
|
||||
the selected id as-is.
|
||||
|
||||
When ``requires_image_input`` is True (the current turn carries an
|
||||
``image_url`` block), the candidate pool is filtered to vision-capable
|
||||
cfgs and any existing pin that can't accept image input is treated as
|
||||
invalid (force re-pin). If no vision-capable cfg is available the
|
||||
function raises ``ValueError`` so the streaming task surfaces the same
|
||||
friendly ``MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT`` error instead of
|
||||
silently routing the image to a text-only deployment.
|
||||
"""
|
||||
thread = (
|
||||
(
|
||||
|
|
@ -274,14 +318,24 @@ async def resolve_or_get_pinned_llm_config_id(
|
|||
|
||||
excluded_ids = {int(cid) for cid in (exclude_config_ids or set())}
|
||||
candidates = [
|
||||
c for c in _global_candidates() if int(c.get("id", 0)) not in excluded_ids
|
||||
c
|
||||
for c in _global_candidates(requires_image_input=requires_image_input)
|
||||
if int(c.get("id", 0)) not in excluded_ids
|
||||
]
|
||||
if not candidates:
|
||||
if requires_image_input:
|
||||
# Distinguish the "no vision-capable cfg" case from generic
|
||||
# "no usable cfg" so the streaming task can map this to the
|
||||
# MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT SSE error.
|
||||
raise ValueError(
|
||||
"No vision-capable global LLM configs are available for Auto mode"
|
||||
)
|
||||
raise ValueError("No usable global LLM configs are available for Auto mode")
|
||||
candidate_by_id = {int(c["id"]): c for c in candidates}
|
||||
|
||||
# Reuse an existing valid pin without re-checking current quota (no silent
|
||||
# tier switch), unless the caller explicitly requests a forced repin to free.
|
||||
# tier switch), unless the caller explicitly requests a forced repin to free
|
||||
# *or* the turn requires image input but the pin can't handle it.
|
||||
pinned_id = thread.pinned_llm_config_id
|
||||
if (
|
||||
not force_repin_free
|
||||
|
|
@ -311,6 +365,29 @@ async def resolve_or_get_pinned_llm_config_id(
|
|||
from_existing_pin=True,
|
||||
)
|
||||
if pinned_id is not None:
|
||||
# If the pin is *only* invalid because it can't handle the image
|
||||
# turn (it's still a healthy, usable config in the broader pool),
|
||||
# log that explicitly so operators can correlate the re-pin with
|
||||
# the user's image attachment instead of suspecting a cooldown.
|
||||
if requires_image_input:
|
||||
try:
|
||||
pinned_global = next(
|
||||
c
|
||||
for c in config.GLOBAL_LLM_CONFIGS
|
||||
if int(c.get("id", 0)) == int(pinned_id)
|
||||
)
|
||||
except StopIteration:
|
||||
pinned_global = None
|
||||
if pinned_global is not None and not _cfg_supports_image_input(
|
||||
pinned_global
|
||||
):
|
||||
logger.info(
|
||||
"auto_pin_repinned_for_image thread_id=%s search_space_id=%s "
|
||||
"previous_config_id=%s",
|
||||
thread_id,
|
||||
search_space_id,
|
||||
pinned_id,
|
||||
)
|
||||
logger.info(
|
||||
"auto_pin_invalid thread_id=%s search_space_id=%s pinned_config_id=%s",
|
||||
thread_id,
|
||||
|
|
@ -327,6 +404,10 @@ async def resolve_or_get_pinned_llm_config_id(
|
|||
eligible = [c for c in candidates if _tier_of(c) != "premium"]
|
||||
|
||||
if not eligible:
|
||||
if requires_image_input:
|
||||
raise ValueError(
|
||||
"Auto mode could not find a vision-capable LLM config for this user and quota state"
|
||||
)
|
||||
raise ValueError(
|
||||
"Auto mode could not find an eligible LLM config for this user and quota state"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -10,12 +10,14 @@ vision-LLM wrapper used during indexing) don't have to re-implement it.
|
|||
|
||||
KEY DESIGN POINTS (issue A, B):
|
||||
|
||||
1. **Session isolation.** ``billable_call`` takes *no* ``db_session``
|
||||
argument. All ``TokenQuotaService.premium_*`` calls and the audit-row
|
||||
insert each run inside their own ``shielded_async_session()``. This
|
||||
guarantees that a quota commit/rollback can never accidentally flush or
|
||||
roll back rows the caller has staged in the request's main session
|
||||
(e.g. a freshly-created ``ImageGeneration`` row).
|
||||
1. **Session isolation.** ``billable_call`` takes no caller transaction.
|
||||
All ``TokenQuotaService.premium_*`` calls and the audit-row insert run
|
||||
inside their own session context. Route callers use
|
||||
``shielded_async_session()`` by default; Celery callers can provide a
|
||||
worker-loop-safe session factory. This guarantees that quota
|
||||
commit/rollback can never accidentally flush or roll back rows the caller
|
||||
has staged in its main session (e.g. a freshly-created
|
||||
``ImageGeneration`` row).
|
||||
|
||||
2. **ContextVar safety.** The accumulator is scoped via
|
||||
:func:`scoped_turn` (which uses ``ContextVar.reset(token)``), so a
|
||||
|
|
@ -36,9 +38,10 @@ KEY DESIGN POINTS (issue A, B):
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections.abc import AsyncIterator
|
||||
from contextlib import asynccontextmanager
|
||||
from collections.abc import AsyncIterator, Callable
|
||||
from contextlib import AbstractAsyncContextManager, asynccontextmanager, suppress
|
||||
from typing import Any
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
|
|
@ -58,6 +61,12 @@ from app.services.token_tracking_service import (
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
AUDIT_TIMEOUT_SECONDS = 10.0
|
||||
BACKGROUND_ARTIFACT_USAGE_TYPES = frozenset(
|
||||
{"video_presentation_generation", "podcast_generation"}
|
||||
)
|
||||
BillableSessionFactory = Callable[[], AbstractAsyncContextManager[AsyncSession]]
|
||||
|
||||
|
||||
class QuotaInsufficientError(Exception):
|
||||
"""Raised when ``TokenQuotaService.premium_reserve`` denies a billable
|
||||
|
|
@ -88,6 +97,124 @@ class QuotaInsufficientError(Exception):
|
|||
)
|
||||
|
||||
|
||||
class BillingSettlementError(Exception):
|
||||
"""Raised when a premium call completed but credit settlement failed."""
|
||||
|
||||
def __init__(self, *, usage_type: str, user_id: UUID, cause: Exception) -> None:
|
||||
self.usage_type = usage_type
|
||||
self.user_id = user_id
|
||||
super().__init__(
|
||||
f"Failed to settle premium credit for {usage_type} user={user_id}: {cause}"
|
||||
)
|
||||
|
||||
|
||||
async def _rollback_safely(session: AsyncSession) -> None:
|
||||
rollback = getattr(session, "rollback", None)
|
||||
if rollback is not None:
|
||||
with suppress(Exception):
|
||||
await rollback()
|
||||
|
||||
|
||||
async def _record_audit_best_effort(
|
||||
*,
|
||||
session_factory: BillableSessionFactory,
|
||||
usage_type: str,
|
||||
search_space_id: int,
|
||||
user_id: UUID,
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
total_tokens: int,
|
||||
cost_micros: int,
|
||||
model_breakdown: dict[str, Any],
|
||||
call_details: dict[str, Any] | None,
|
||||
thread_id: int | None,
|
||||
message_id: int | None,
|
||||
audit_label: str,
|
||||
timeout_seconds: float = AUDIT_TIMEOUT_SECONDS,
|
||||
) -> None:
|
||||
"""Persist a TokenUsage row without letting audit failure block callers.
|
||||
|
||||
Premium settlement is mandatory, but TokenUsage is an audit trail. If the
|
||||
audit insert or commit hangs, user-facing artifacts such as videos and
|
||||
podcasts must still be able to transition to READY after settlement.
|
||||
"""
|
||||
audit_thread_id = (
|
||||
None if usage_type in BACKGROUND_ARTIFACT_USAGE_TYPES else thread_id
|
||||
)
|
||||
|
||||
async def _persist() -> None:
|
||||
logger.info(
|
||||
"[billable_call] audit start label=%s usage_type=%s user=%s thread=%s "
|
||||
"total_tokens=%d cost_micros=%d",
|
||||
audit_label,
|
||||
usage_type,
|
||||
user_id,
|
||||
audit_thread_id,
|
||||
total_tokens,
|
||||
cost_micros,
|
||||
)
|
||||
async with session_factory() as audit_session:
|
||||
try:
|
||||
await record_token_usage(
|
||||
audit_session,
|
||||
usage_type=usage_type,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
cost_micros=cost_micros,
|
||||
model_breakdown=model_breakdown,
|
||||
call_details=call_details,
|
||||
thread_id=audit_thread_id,
|
||||
message_id=message_id,
|
||||
)
|
||||
logger.info(
|
||||
"[billable_call] audit row staged label=%s usage_type=%s user=%s thread=%s",
|
||||
audit_label,
|
||||
usage_type,
|
||||
user_id,
|
||||
audit_thread_id,
|
||||
)
|
||||
await audit_session.commit()
|
||||
logger.info(
|
||||
"[billable_call] audit commit OK label=%s usage_type=%s user=%s thread=%s",
|
||||
audit_label,
|
||||
usage_type,
|
||||
user_id,
|
||||
audit_thread_id,
|
||||
)
|
||||
except BaseException:
|
||||
await _rollback_safely(audit_session)
|
||||
raise
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(_persist(), timeout=timeout_seconds)
|
||||
except TimeoutError:
|
||||
logger.warning(
|
||||
"[billable_call] audit timed out label=%s usage_type=%s user=%s thread=%s "
|
||||
"timeout=%.1fs total_tokens=%d cost_micros=%d",
|
||||
audit_label,
|
||||
usage_type,
|
||||
user_id,
|
||||
audit_thread_id,
|
||||
timeout_seconds,
|
||||
total_tokens,
|
||||
cost_micros,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"[billable_call] audit failed label=%s usage_type=%s user=%s thread=%s "
|
||||
"total_tokens=%d cost_micros=%d",
|
||||
audit_label,
|
||||
usage_type,
|
||||
user_id,
|
||||
audit_thread_id,
|
||||
total_tokens,
|
||||
cost_micros,
|
||||
)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def billable_call(
|
||||
*,
|
||||
|
|
@ -101,6 +228,8 @@ async def billable_call(
|
|||
thread_id: int | None = None,
|
||||
message_id: int | None = None,
|
||||
call_details: dict[str, Any] | None = None,
|
||||
billable_session_factory: BillableSessionFactory | None = None,
|
||||
audit_timeout_seconds: float = AUDIT_TIMEOUT_SECONDS,
|
||||
) -> AsyncIterator[TurnTokenAccumulator]:
|
||||
"""Wrap a single billable LLM/image call.
|
||||
|
||||
|
|
@ -124,6 +253,13 @@ async def billable_call(
|
|||
thread_id, message_id: Optional FK columns on ``TokenUsage``.
|
||||
call_details: Optional per-call metadata (model name, parameters)
|
||||
forwarded to ``record_token_usage``.
|
||||
billable_session_factory: Optional async context factory used for
|
||||
reserve/finalize/release/audit sessions. Defaults to
|
||||
``shielded_async_session`` for route callers; Celery callers pass
|
||||
a worker-loop-safe session factory.
|
||||
audit_timeout_seconds: Upper bound for TokenUsage audit persistence.
|
||||
Audit failure is best-effort and does not undo successful
|
||||
settlement.
|
||||
|
||||
Yields:
|
||||
The ``TurnTokenAccumulator`` scoped to this call. The caller invokes
|
||||
|
|
@ -134,6 +270,7 @@ async def billable_call(
|
|||
QuotaInsufficientError: when premium and ``premium_reserve`` denies.
|
||||
"""
|
||||
is_premium = billing_tier == "premium"
|
||||
session_factory = billable_session_factory or shielded_async_session
|
||||
|
||||
async with scoped_turn() as acc:
|
||||
# ---------- Free path: just audit -------------------------------
|
||||
|
|
@ -143,30 +280,22 @@ async def billable_call(
|
|||
finally:
|
||||
# Always audit, even on exception, so we capture cost when
|
||||
# provider returns successfully but the caller raises later.
|
||||
try:
|
||||
async with shielded_async_session() as audit_session:
|
||||
await record_token_usage(
|
||||
audit_session,
|
||||
usage_type=usage_type,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
prompt_tokens=acc.total_prompt_tokens,
|
||||
completion_tokens=acc.total_completion_tokens,
|
||||
total_tokens=acc.grand_total,
|
||||
cost_micros=acc.total_cost_micros,
|
||||
model_breakdown=acc.per_message_summary(),
|
||||
call_details=call_details,
|
||||
thread_id=thread_id,
|
||||
message_id=message_id,
|
||||
)
|
||||
await audit_session.commit()
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"[billable_call] free-path audit insert failed for "
|
||||
"usage_type=%s user_id=%s",
|
||||
usage_type,
|
||||
user_id,
|
||||
)
|
||||
await _record_audit_best_effort(
|
||||
session_factory=session_factory,
|
||||
usage_type=usage_type,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
prompt_tokens=acc.total_prompt_tokens,
|
||||
completion_tokens=acc.total_completion_tokens,
|
||||
total_tokens=acc.grand_total,
|
||||
cost_micros=acc.total_cost_micros,
|
||||
model_breakdown=acc.per_message_summary(),
|
||||
call_details=call_details,
|
||||
thread_id=thread_id,
|
||||
message_id=message_id,
|
||||
audit_label="free",
|
||||
timeout_seconds=audit_timeout_seconds,
|
||||
)
|
||||
return
|
||||
|
||||
# ---------- Premium path: reserve → execute → finalize ----------
|
||||
|
|
@ -180,7 +309,7 @@ async def billable_call(
|
|||
|
||||
request_id = str(uuid4())
|
||||
|
||||
async with shielded_async_session() as quota_session:
|
||||
async with session_factory() as quota_session:
|
||||
reserve_result = await TokenQuotaService.premium_reserve(
|
||||
db_session=quota_session,
|
||||
user_id=user_id,
|
||||
|
|
@ -222,7 +351,7 @@ async def billable_call(
|
|||
# from a downstream call, asyncio cancellation, etc.). We use
|
||||
# BaseException so cancellation also releases.
|
||||
try:
|
||||
async with shielded_async_session() as quota_session:
|
||||
async with session_factory() as quota_session:
|
||||
await TokenQuotaService.premium_release(
|
||||
db_session=quota_session,
|
||||
user_id=user_id,
|
||||
|
|
@ -241,7 +370,16 @@ async def billable_call(
|
|||
# ---------- Success: finalize + audit ----------------------------
|
||||
actual_micros = acc.total_cost_micros
|
||||
try:
|
||||
async with shielded_async_session() as quota_session:
|
||||
logger.info(
|
||||
"[billable_call] finalize start user=%s usage_type=%s actual=%d "
|
||||
"reserved=%d thread=%s",
|
||||
user_id,
|
||||
usage_type,
|
||||
actual_micros,
|
||||
reserve_micros,
|
||||
thread_id,
|
||||
)
|
||||
async with session_factory() as quota_session:
|
||||
final_result = await TokenQuotaService.premium_finalize(
|
||||
db_session=quota_session,
|
||||
user_id=user_id,
|
||||
|
|
@ -260,7 +398,7 @@ async def billable_call(
|
|||
final_result.limit,
|
||||
final_result.remaining,
|
||||
)
|
||||
except Exception:
|
||||
except Exception as finalize_exc:
|
||||
# Last-ditch: if finalize itself fails, we must at least release
|
||||
# so the reservation doesn't leak.
|
||||
logger.exception(
|
||||
|
|
@ -269,7 +407,7 @@ async def billable_call(
|
|||
user_id,
|
||||
)
|
||||
try:
|
||||
async with shielded_async_session() as quota_session:
|
||||
async with session_factory() as quota_session:
|
||||
await TokenQuotaService.premium_release(
|
||||
db_session=quota_session,
|
||||
user_id=user_id,
|
||||
|
|
@ -281,31 +419,28 @@ async def billable_call(
|
|||
"for user=%s",
|
||||
user_id,
|
||||
)
|
||||
raise BillingSettlementError(
|
||||
usage_type=usage_type,
|
||||
user_id=user_id,
|
||||
cause=finalize_exc,
|
||||
) from finalize_exc
|
||||
|
||||
try:
|
||||
async with shielded_async_session() as audit_session:
|
||||
await record_token_usage(
|
||||
audit_session,
|
||||
usage_type=usage_type,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
prompt_tokens=acc.total_prompt_tokens,
|
||||
completion_tokens=acc.total_completion_tokens,
|
||||
total_tokens=acc.grand_total,
|
||||
cost_micros=actual_micros,
|
||||
model_breakdown=acc.per_message_summary(),
|
||||
call_details=call_details,
|
||||
thread_id=thread_id,
|
||||
message_id=message_id,
|
||||
)
|
||||
await audit_session.commit()
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"[billable_call] premium-path audit insert failed for "
|
||||
"usage_type=%s user_id=%s (debit was applied)",
|
||||
usage_type,
|
||||
user_id,
|
||||
)
|
||||
await _record_audit_best_effort(
|
||||
session_factory=session_factory,
|
||||
usage_type=usage_type,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
prompt_tokens=acc.total_prompt_tokens,
|
||||
completion_tokens=acc.total_completion_tokens,
|
||||
total_tokens=acc.grand_total,
|
||||
cost_micros=actual_micros,
|
||||
model_breakdown=acc.per_message_summary(),
|
||||
call_details=call_details,
|
||||
thread_id=thread_id,
|
||||
message_id=message_id,
|
||||
audit_label="premium",
|
||||
timeout_seconds=audit_timeout_seconds,
|
||||
)
|
||||
|
||||
|
||||
async def _resolve_agent_billing_for_search_space(
|
||||
|
|
@ -419,6 +554,7 @@ async def _resolve_agent_billing_for_search_space(
|
|||
|
||||
|
||||
__all__ = [
|
||||
"BillingSettlementError",
|
||||
"QuotaInsufficientError",
|
||||
"_resolve_agent_billing_for_search_space",
|
||||
"billable_call",
|
||||
|
|
|
|||
|
|
@ -20,6 +20,8 @@ from typing import Any
|
|||
from litellm import Router
|
||||
from litellm.utils import ImageResponse
|
||||
|
||||
from app.services.provider_api_base import resolve_api_base
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Special ID for Auto mode - uses router for load balancing
|
||||
|
|
@ -152,12 +154,12 @@ class ImageGenRouterService:
|
|||
return None
|
||||
|
||||
# Build model string
|
||||
provider = config.get("provider", "").upper()
|
||||
if config.get("custom_provider"):
|
||||
model_string = f"{config['custom_provider']}/{config['model_name']}"
|
||||
provider_prefix = config["custom_provider"]
|
||||
else:
|
||||
provider = config.get("provider", "").upper()
|
||||
provider_prefix = IMAGE_GEN_PROVIDER_MAP.get(provider, provider.lower())
|
||||
model_string = f"{provider_prefix}/{config['model_name']}"
|
||||
model_string = f"{provider_prefix}/{config['model_name']}"
|
||||
|
||||
# Build litellm params
|
||||
litellm_params: dict[str, Any] = {
|
||||
|
|
@ -165,9 +167,16 @@ class ImageGenRouterService:
|
|||
"api_key": config.get("api_key"),
|
||||
}
|
||||
|
||||
# Add optional api_base
|
||||
if config.get("api_base"):
|
||||
litellm_params["api_base"] = config["api_base"]
|
||||
# Resolve ``api_base`` so deployments don't silently inherit
|
||||
# ``AZURE_OPENAI_ENDPOINT`` / ``OPENAI_API_BASE`` and 404 against
|
||||
# the wrong provider (see ``provider_api_base`` docstring).
|
||||
api_base = resolve_api_base(
|
||||
provider=provider,
|
||||
provider_prefix=provider_prefix,
|
||||
config_api_base=config.get("api_base"),
|
||||
)
|
||||
if api_base:
|
||||
litellm_params["api_base"] = api_base
|
||||
|
||||
# Add api_version (required for Azure)
|
||||
if config.get("api_version"):
|
||||
|
|
|
|||
|
|
@ -140,8 +140,6 @@ PROVIDER_MAP = {
|
|||
# 404-ing against an inherited Azure endpoint). Re-exported here for
|
||||
# backward compatibility with any external import.
|
||||
from app.services.provider_api_base import ( # noqa: E402
|
||||
PROVIDER_DEFAULT_API_BASE,
|
||||
PROVIDER_KEY_DEFAULT_API_BASE,
|
||||
resolve_api_base,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from app.services.llm_router_service import (
|
|||
get_auto_mode_llm,
|
||||
is_auto_mode,
|
||||
)
|
||||
from app.services.provider_api_base import resolve_api_base
|
||||
from app.services.token_tracking_service import token_tracker
|
||||
|
||||
# Configure litellm to automatically drop unsupported parameters
|
||||
|
|
@ -556,22 +557,26 @@ async def get_vision_llm(
|
|||
return None
|
||||
|
||||
if global_cfg.get("custom_provider"):
|
||||
model_string = (
|
||||
f"{global_cfg['custom_provider']}/{global_cfg['model_name']}"
|
||||
)
|
||||
provider_prefix = global_cfg["custom_provider"]
|
||||
model_string = f"{provider_prefix}/{global_cfg['model_name']}"
|
||||
else:
|
||||
prefix = VISION_PROVIDER_MAP.get(
|
||||
provider_prefix = VISION_PROVIDER_MAP.get(
|
||||
global_cfg["provider"].upper(),
|
||||
global_cfg["provider"].lower(),
|
||||
)
|
||||
model_string = f"{prefix}/{global_cfg['model_name']}"
|
||||
model_string = f"{provider_prefix}/{global_cfg['model_name']}"
|
||||
|
||||
litellm_kwargs = {
|
||||
"model": model_string,
|
||||
"api_key": global_cfg["api_key"],
|
||||
}
|
||||
if global_cfg.get("api_base"):
|
||||
litellm_kwargs["api_base"] = global_cfg["api_base"]
|
||||
api_base = resolve_api_base(
|
||||
provider=global_cfg.get("provider"),
|
||||
provider_prefix=provider_prefix,
|
||||
config_api_base=global_cfg.get("api_base"),
|
||||
)
|
||||
if api_base:
|
||||
litellm_kwargs["api_base"] = api_base
|
||||
if global_cfg.get("litellm_params"):
|
||||
litellm_kwargs.update(global_cfg["litellm_params"])
|
||||
|
||||
|
|
@ -606,20 +611,26 @@ async def get_vision_llm(
|
|||
return None
|
||||
|
||||
if vision_cfg.custom_provider:
|
||||
model_string = f"{vision_cfg.custom_provider}/{vision_cfg.model_name}"
|
||||
provider_prefix = vision_cfg.custom_provider
|
||||
model_string = f"{provider_prefix}/{vision_cfg.model_name}"
|
||||
else:
|
||||
prefix = VISION_PROVIDER_MAP.get(
|
||||
provider_prefix = VISION_PROVIDER_MAP.get(
|
||||
vision_cfg.provider.value.upper(),
|
||||
vision_cfg.provider.value.lower(),
|
||||
)
|
||||
model_string = f"{prefix}/{vision_cfg.model_name}"
|
||||
model_string = f"{provider_prefix}/{vision_cfg.model_name}"
|
||||
|
||||
litellm_kwargs = {
|
||||
"model": model_string,
|
||||
"api_key": vision_cfg.api_key,
|
||||
}
|
||||
if vision_cfg.api_base:
|
||||
litellm_kwargs["api_base"] = vision_cfg.api_base
|
||||
api_base = resolve_api_base(
|
||||
provider=vision_cfg.provider.value,
|
||||
provider_prefix=provider_prefix,
|
||||
config_api_base=vision_cfg.api_base,
|
||||
)
|
||||
if api_base:
|
||||
litellm_kwargs["api_base"] = api_base
|
||||
if vision_cfg.litellm_params:
|
||||
litellm_kwargs.update(vision_cfg.litellm_params)
|
||||
|
||||
|
|
|
|||
|
|
@ -122,6 +122,24 @@ def _is_vision_input_model(model: dict) -> bool:
|
|||
return "image" in input_mods and "text" in output_mods
|
||||
|
||||
|
||||
def _supports_image_input(model: dict) -> bool:
|
||||
"""Return True if the model accepts ``image`` in its input modalities.
|
||||
|
||||
Differs from :func:`_is_vision_input_model` in that it does NOT
|
||||
require text output — chat-tab models always emit text already (the
|
||||
chat catalog filters by ``_is_text_output_model``), so the only
|
||||
extra capability we need to track per chat config is whether the
|
||||
model can ingest user-attached images. The chat selector and the
|
||||
streaming task both key off this flag to prevent hitting an
|
||||
OpenRouter 404 ``"No endpoints found that support image input"``
|
||||
when the user uploads an image and selects a text-only model
|
||||
(DeepSeek V3, Llama 3.x base, etc.).
|
||||
"""
|
||||
arch = model.get("architecture", {}) or {}
|
||||
input_mods = arch.get("input_modalities", []) or []
|
||||
return "image" in input_mods
|
||||
|
||||
|
||||
def _supports_tool_calling(model: dict) -> bool:
|
||||
"""Return True if the model supports function/tool calling."""
|
||||
supported = model.get("supported_parameters") or []
|
||||
|
|
@ -321,6 +339,13 @@ def _generate_configs(
|
|||
# account-wide quota, so per-deployment routing can't spread load
|
||||
# there — it just drains the shared bucket faster.
|
||||
"router_pool_eligible": tier == "premium",
|
||||
# Capability flag derived from ``architecture.input_modalities``.
|
||||
# Read by the new-chat selector to dim image-incompatible models
|
||||
# when the user has pending image attachments, and by
|
||||
# ``stream_new_chat`` as a fail-fast safety net before the
|
||||
# OpenRouter request would otherwise 404 with
|
||||
# ``"No endpoints found that support image input"``.
|
||||
"supports_image_input": _supports_image_input(model),
|
||||
_OPENROUTER_DYNAMIC_MARKER: True,
|
||||
# Auto (Fastest) ranking metadata. ``quality_score`` is initialised
|
||||
# to the static score and gets re-blended with health on the next
|
||||
|
|
@ -398,7 +423,12 @@ def _generate_image_gen_configs(
|
|||
"provider": "OPENROUTER",
|
||||
"model_name": model_id,
|
||||
"api_key": api_key,
|
||||
"api_base": "",
|
||||
# Pin to OpenRouter's public base URL so a downstream call site
|
||||
# that forgets ``resolve_api_base`` still doesn't inherit
|
||||
# ``AZURE_OPENAI_ENDPOINT`` and 404 on
|
||||
# ``image_generation/transformation`` (defense-in-depth, see
|
||||
# ``provider_api_base`` docstring).
|
||||
"api_base": "https://openrouter.ai/api/v1",
|
||||
"api_version": None,
|
||||
"rpm": free_rpm if tier == "free" else rpm,
|
||||
"litellm_params": dict(litellm_params),
|
||||
|
|
@ -477,7 +507,11 @@ def _generate_vision_llm_configs(
|
|||
"provider": "OPENROUTER",
|
||||
"model_name": model_id,
|
||||
"api_key": api_key,
|
||||
"api_base": "",
|
||||
# Pin to OpenRouter's public base URL so a downstream call site
|
||||
# that forgets ``resolve_api_base`` still doesn't inherit
|
||||
# ``AZURE_OPENAI_ENDPOINT`` (defense-in-depth, see
|
||||
# ``provider_api_base`` docstring).
|
||||
"api_base": "https://openrouter.ai/api/v1",
|
||||
"api_version": None,
|
||||
"rpm": free_rpm if tier == "free" else rpm,
|
||||
"tpm": free_tpm if tier == "free" else tpm,
|
||||
|
|
|
|||
|
|
@ -17,7 +17,6 @@ source of truth without an inter-service circular import.
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
PROVIDER_DEFAULT_API_BASE: dict[str, str] = {
|
||||
"openrouter": "https://openrouter.ai/api/v1",
|
||||
"groq": "https://api.groq.com/openai/v1",
|
||||
|
|
|
|||
280
surfsense_backend/app/services/provider_capabilities.py
Normal file
280
surfsense_backend/app/services/provider_capabilities.py
Normal file
|
|
@ -0,0 +1,280 @@
|
|||
"""Capability resolution shared by chat / image / vision call sites.
|
||||
|
||||
Why this exists
|
||||
---------------
|
||||
The chat catalog (YAML + dynamic OpenRouter + BYOK DB rows + Auto) needs a
|
||||
single, authoritative answer to one question: *can this chat config accept
|
||||
``image_url`` content blocks?* Without it, the new-chat selector can't badge
|
||||
incompatible models and the streaming task can't fail fast with a friendly
|
||||
error before sending an image to a text-only provider.
|
||||
|
||||
Two functions, two intents:
|
||||
|
||||
- :func:`derive_supports_image_input` — best-effort *True* for catalog and
|
||||
UI surfacing. Default-allow: an unknown / unmapped model is treated as
|
||||
capable so we never lock the user out of a freshly added or
|
||||
third-party-hosted vision model.
|
||||
|
||||
- :func:`is_known_text_only_chat_model` — strict opt-out for the streaming
|
||||
task's safety net. Returns True only when LiteLLM's model map *explicitly*
|
||||
sets ``supports_vision=False`` (or its bare-name variant does). Anything
|
||||
else — missing key, lookup exception, ``supports_vision=True`` — returns
|
||||
False so the request flows through to the provider.
|
||||
|
||||
Implementation rule: only public LiteLLM symbols
|
||||
------------------------------------------------
|
||||
``litellm.supports_vision`` and ``litellm.get_model_info`` are part of the
|
||||
typed module surface (see ``litellm.__init__`` lazy stubs) and are stable
|
||||
across releases. The private ``_is_explicitly_disabled_factory`` and
|
||||
``_get_model_info_helper`` are intentionally avoided so a LiteLLM upgrade
|
||||
can't silently break us.
|
||||
|
||||
Why the previous round's strict YAML opt-in flag failed
|
||||
-------------------------------------------------------
|
||||
``supports_image_input: false`` was the YAML loader's setdefault. Operators
|
||||
maintaining ``global_llm_config.yaml`` never set it, so every Azure / OpenAI
|
||||
YAML chat model — including vision-capable GPT-5.x and GPT-4o — resolved to
|
||||
False and the streaming gate rejected every image turn. Sourcing capability
|
||||
from LiteLLM's authoritative model map (which already says
|
||||
``azure/gpt-5.4 -> supports_vision=true``) removes that operator toil.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Iterable
|
||||
|
||||
import litellm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Provider-name → LiteLLM model-prefix map.
|
||||
#
|
||||
# Owned here because ``app.services.provider_capabilities`` is the
|
||||
# only edge that's safe to call from ``app.config``'s YAML loader at
|
||||
# class-body init time. ``app.agents.new_chat.llm_config`` re-exports
|
||||
# this constant under the historical ``PROVIDER_MAP`` name; placing the
|
||||
# map there directly would re-introduce the
|
||||
# ``app.config -> ... -> app.agents.new_chat.tools.generate_image ->
|
||||
# app.config`` cycle that prompted the move.
|
||||
_PROVIDER_PREFIX_MAP: dict[str, str] = {
|
||||
"OPENAI": "openai",
|
||||
"ANTHROPIC": "anthropic",
|
||||
"GROQ": "groq",
|
||||
"COHERE": "cohere",
|
||||
"GOOGLE": "gemini",
|
||||
"OLLAMA": "ollama_chat",
|
||||
"MISTRAL": "mistral",
|
||||
"AZURE_OPENAI": "azure",
|
||||
"OPENROUTER": "openrouter",
|
||||
"XAI": "xai",
|
||||
"BEDROCK": "bedrock",
|
||||
"VERTEX_AI": "vertex_ai",
|
||||
"TOGETHER_AI": "together_ai",
|
||||
"FIREWORKS_AI": "fireworks_ai",
|
||||
"DEEPSEEK": "openai",
|
||||
"ALIBABA_QWEN": "openai",
|
||||
"MOONSHOT": "openai",
|
||||
"ZHIPU": "openai",
|
||||
"GITHUB_MODELS": "github",
|
||||
"REPLICATE": "replicate",
|
||||
"PERPLEXITY": "perplexity",
|
||||
"ANYSCALE": "anyscale",
|
||||
"DEEPINFRA": "deepinfra",
|
||||
"CEREBRAS": "cerebras",
|
||||
"SAMBANOVA": "sambanova",
|
||||
"AI21": "ai21",
|
||||
"CLOUDFLARE": "cloudflare",
|
||||
"DATABRICKS": "databricks",
|
||||
"COMETAPI": "cometapi",
|
||||
"HUGGINGFACE": "huggingface",
|
||||
"MINIMAX": "openai",
|
||||
"CUSTOM": "custom",
|
||||
}
|
||||
|
||||
|
||||
def _candidate_model_strings(
|
||||
*,
|
||||
provider: str | None,
|
||||
model_name: str | None,
|
||||
base_model: str | None,
|
||||
custom_provider: str | None,
|
||||
) -> list[tuple[str, str | None]]:
|
||||
"""Return ``[(model_string, custom_llm_provider), ...]`` lookup candidates.
|
||||
|
||||
LiteLLM's capability lookup is keyed by ``model`` + (optional)
|
||||
``custom_llm_provider``. Different config sources give us different
|
||||
levels of detail, so we try the most-specific keys first and fall back
|
||||
to bare model names so unannotated entries (e.g. an Azure deployment
|
||||
pointing at ``gpt-5.4`` via ``litellm_params.base_model``) still hit the
|
||||
map. Order matters — the first lookup that returns a definitive answer
|
||||
wins for both helpers.
|
||||
"""
|
||||
candidates: list[tuple[str, str | None]] = []
|
||||
seen: set[tuple[str, str | None]] = set()
|
||||
|
||||
def _add(model: str | None, llm_provider: str | None) -> None:
|
||||
if not model:
|
||||
return
|
||||
key = (model, llm_provider)
|
||||
if key in seen:
|
||||
return
|
||||
seen.add(key)
|
||||
candidates.append(key)
|
||||
|
||||
provider_prefix: str | None = None
|
||||
if provider:
|
||||
provider_prefix = _PROVIDER_PREFIX_MAP.get(provider.upper(), provider.lower())
|
||||
if custom_provider:
|
||||
# ``custom_provider`` overrides everything for CUSTOM/proxy setups.
|
||||
provider_prefix = custom_provider
|
||||
|
||||
primary_model = base_model or model_name
|
||||
bare_model = model_name
|
||||
|
||||
# Most-specific first: provider-prefixed identifier with explicit
|
||||
# custom_llm_provider so LiteLLM won't have to guess the provider via
|
||||
# ``get_llm_provider``.
|
||||
if primary_model and provider_prefix:
|
||||
# e.g. "azure/gpt-5.4" + custom_llm_provider="azure"
|
||||
if "/" in primary_model:
|
||||
_add(primary_model, provider_prefix)
|
||||
else:
|
||||
_add(f"{provider_prefix}/{primary_model}", provider_prefix)
|
||||
|
||||
# Bare base_model (or model_name) with provider hint — handles entries
|
||||
# the upstream map keys without a provider prefix (most ``gpt-*`` and
|
||||
# ``claude-*`` entries do this).
|
||||
if primary_model:
|
||||
_add(primary_model, provider_prefix)
|
||||
|
||||
# Fallback to model_name when base_model differs (e.g. an Azure
|
||||
# deployment whose model_name is the deployment id but base_model is the
|
||||
# canonical OpenAI sku).
|
||||
if bare_model and bare_model != primary_model:
|
||||
if provider_prefix and "/" not in bare_model:
|
||||
_add(f"{provider_prefix}/{bare_model}", provider_prefix)
|
||||
_add(bare_model, provider_prefix)
|
||||
_add(bare_model, None)
|
||||
|
||||
return candidates
|
||||
|
||||
|
||||
def derive_supports_image_input(
|
||||
*,
|
||||
provider: str | None = None,
|
||||
model_name: str | None = None,
|
||||
base_model: str | None = None,
|
||||
custom_provider: str | None = None,
|
||||
openrouter_input_modalities: Iterable[str] | None = None,
|
||||
) -> bool:
|
||||
"""Best-effort capability flag for the new-chat selector and catalog.
|
||||
|
||||
Resolution order (first definitive answer wins):
|
||||
|
||||
1. ``openrouter_input_modalities`` (when provided as a non-empty
|
||||
iterable). OpenRouter exposes ``architecture.input_modalities`` per
|
||||
model and that's the authoritative source for OR dynamic configs.
|
||||
2. ``litellm.supports_vision`` against each candidate identifier from
|
||||
:func:`_candidate_model_strings`. Returns True as soon as any
|
||||
candidate confirms vision support.
|
||||
3. Default ``True`` — the conservative-allow stance. An unknown /
|
||||
newly-added / third-party-hosted model is *not* pre-judged. The
|
||||
streaming safety net (:func:`is_known_text_only_chat_model`) is the
|
||||
only place a False ever blocks; everywhere else, a False here would
|
||||
just hide a usable model from the user.
|
||||
|
||||
Returns:
|
||||
True if the model can plausibly accept image input, False only when
|
||||
OpenRouter explicitly says it can't.
|
||||
"""
|
||||
if openrouter_input_modalities is not None:
|
||||
modalities = list(openrouter_input_modalities)
|
||||
if modalities:
|
||||
return "image" in modalities
|
||||
# Empty list explicitly published by OR — treat as "no image".
|
||||
return False
|
||||
|
||||
for model_string, custom_llm_provider in _candidate_model_strings(
|
||||
provider=provider,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
custom_provider=custom_provider,
|
||||
):
|
||||
try:
|
||||
if litellm.supports_vision(
|
||||
model=model_string, custom_llm_provider=custom_llm_provider
|
||||
):
|
||||
return True
|
||||
except Exception as exc:
|
||||
logger.debug(
|
||||
"litellm.supports_vision raised for model=%s provider=%s: %s",
|
||||
model_string,
|
||||
custom_llm_provider,
|
||||
exc,
|
||||
)
|
||||
continue
|
||||
|
||||
# Default-allow. ``is_known_text_only_chat_model`` is the strict gate.
|
||||
return True
|
||||
|
||||
|
||||
def is_known_text_only_chat_model(
|
||||
*,
|
||||
provider: str | None = None,
|
||||
model_name: str | None = None,
|
||||
base_model: str | None = None,
|
||||
custom_provider: str | None = None,
|
||||
) -> bool:
|
||||
"""Strict opt-out probe for the streaming-task safety net.
|
||||
|
||||
Returns True only when LiteLLM's model map *explicitly* sets
|
||||
``supports_vision=False`` for at least one candidate identifier. Missing
|
||||
key, lookup exception, or ``supports_vision=True`` all return False so
|
||||
the streaming task lets the request through. This is the inverse-default
|
||||
of :func:`derive_supports_image_input`.
|
||||
|
||||
Why two functions
|
||||
-----------------
|
||||
The selector wants "show me everything that's plausibly capable" —
|
||||
default-allow. The safety net wants "block only when I'm certain it
|
||||
can't" — default-pass. Mixing the two intents in a single function
|
||||
leads to the regression we're fixing here.
|
||||
"""
|
||||
for model_string, custom_llm_provider in _candidate_model_strings(
|
||||
provider=provider,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
custom_provider=custom_provider,
|
||||
):
|
||||
try:
|
||||
info = litellm.get_model_info(
|
||||
model=model_string, custom_llm_provider=custom_llm_provider
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.debug(
|
||||
"litellm.get_model_info raised for model=%s provider=%s: %s",
|
||||
model_string,
|
||||
custom_llm_provider,
|
||||
exc,
|
||||
)
|
||||
continue
|
||||
|
||||
# ``ModelInfo`` is a TypedDict (dict at runtime). ``supports_vision``
|
||||
# may be missing, None, True, or False. We only fire on explicit
|
||||
# False — None / missing / True all mean "don't block".
|
||||
try:
|
||||
value = info.get("supports_vision") # type: ignore[union-attr]
|
||||
except AttributeError:
|
||||
value = None
|
||||
if value is False:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
__all__ = [
|
||||
"derive_supports_image_input",
|
||||
"is_known_text_only_chat_model",
|
||||
]
|
||||
|
|
@ -1,10 +1,25 @@
|
|||
"""Celery tasks package."""
|
||||
"""Celery tasks package.
|
||||
|
||||
Also hosts the small helpers every async celery task should use to
|
||||
spin up its event loop. See :func:`run_async_celery_task` for the
|
||||
canonical pattern.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import TypeVar
|
||||
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import NullPool
|
||||
|
||||
from app.config import config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_celery_engine = None
|
||||
_celery_session_maker = None
|
||||
|
||||
|
|
@ -26,3 +41,86 @@ def get_celery_session_maker() -> async_sessionmaker:
|
|||
_celery_engine, expire_on_commit=False
|
||||
)
|
||||
return _celery_session_maker
|
||||
|
||||
|
||||
def _dispose_shared_db_engine(loop: asyncio.AbstractEventLoop) -> None:
|
||||
"""Drop the shared ``app.db.engine`` connection pool synchronously.
|
||||
|
||||
The shared engine (used by ``shielded_async_session`` and most
|
||||
routes / services) is a module-level singleton with a real pool.
|
||||
Each celery task creates a fresh ``asyncio`` event loop; asyncpg
|
||||
connections cache a reference to whichever loop opened them. When
|
||||
a subsequent task's loop pulls a stale connection from the pool,
|
||||
SQLAlchemy's ``pool_pre_ping`` checkout crashes with::
|
||||
|
||||
AttributeError: 'NoneType' object has no attribute 'send'
|
||||
File ".../asyncio/proactor_events.py", line 402, in _loop_writing
|
||||
self._write_fut = self._loop._proactor.send(self._sock, data)
|
||||
|
||||
or hangs forever inside the asyncpg ``Connection._cancel`` cleanup
|
||||
coroutine that can never run because its loop is gone.
|
||||
|
||||
Disposing the engine forces the pool to drop every cached
|
||||
connection so the next checkout opens a fresh one on the current
|
||||
loop. Safe to call from a task's finally block; failure is logged
|
||||
but never propagated.
|
||||
"""
|
||||
try:
|
||||
from app.db import engine as shared_engine
|
||||
|
||||
loop.run_until_complete(shared_engine.dispose())
|
||||
except Exception:
|
||||
logger.warning("Shared DB engine dispose() failed", exc_info=True)
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def run_async_celery_task[T](coro_factory: Callable[[], Awaitable[T]]) -> T:
|
||||
"""Run an async coroutine inside a fresh event loop with proper
|
||||
DB-engine cleanup.
|
||||
|
||||
This is the canonical entry point for every async celery task.
|
||||
It performs three responsibilities that were previously copy-pasted
|
||||
(incorrectly) across each task module:
|
||||
|
||||
1. Create a fresh ``asyncio`` loop and install it on the current
|
||||
thread (celery's ``--pool=solo`` runs every task on the main
|
||||
thread, but other pool types don't).
|
||||
2. Dispose the shared ``app.db.engine`` BEFORE the task runs so
|
||||
any stale connections left over from a previous task's loop
|
||||
are dropped — defends against tasks that crashed without
|
||||
cleaning up.
|
||||
3. Dispose the shared engine AFTER the task runs so the
|
||||
connections we opened on this loop are released before the
|
||||
loop closes (avoids ``coroutine 'Connection._cancel' was
|
||||
never awaited`` warnings and the next-task hang).
|
||||
|
||||
Use as::
|
||||
|
||||
@celery_app.task(name="my_task", bind=True)
|
||||
def my_task(self, *args):
|
||||
return run_async_celery_task(lambda: _my_task_impl(*args))
|
||||
"""
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
# Defense-in-depth: prior task may have crashed before
|
||||
# disposing. Idempotent — no-op if pool is already empty.
|
||||
_dispose_shared_db_engine(loop)
|
||||
return loop.run_until_complete(coro_factory())
|
||||
finally:
|
||||
# Drop any connections this task opened so they don't leak
|
||||
# into the next task's loop.
|
||||
_dispose_shared_db_engine(loop)
|
||||
with contextlib.suppress(Exception):
|
||||
loop.run_until_complete(loop.shutdown_asyncgens())
|
||||
with contextlib.suppress(Exception):
|
||||
asyncio.set_event_loop(None)
|
||||
loop.close()
|
||||
|
||||
|
||||
__all__ = [
|
||||
"get_celery_session_maker",
|
||||
"run_async_celery_task",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import logging
|
|||
import traceback
|
||||
|
||||
from app.celery_app import celery_app
|
||||
from app.tasks.celery_tasks import get_celery_session_maker
|
||||
from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -49,22 +49,15 @@ def index_notion_pages_task(
|
|||
end_date: str,
|
||||
):
|
||||
"""Celery task to index Notion pages."""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_notion_pages(
|
||||
return run_async_celery_task(
|
||||
lambda: _index_notion_pages(
|
||||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
_handle_greenlet_error(e, "index_notion_pages", connector_id)
|
||||
raise
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
async def _index_notion_pages(
|
||||
|
|
@ -95,19 +88,11 @@ def index_github_repos_task(
|
|||
end_date: str,
|
||||
):
|
||||
"""Celery task to index GitHub repositories."""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_github_repos(
|
||||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
return run_async_celery_task(
|
||||
lambda: _index_github_repos(
|
||||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
)
|
||||
|
||||
|
||||
async def _index_github_repos(
|
||||
|
|
@ -138,19 +123,11 @@ def index_confluence_pages_task(
|
|||
end_date: str,
|
||||
):
|
||||
"""Celery task to index Confluence pages."""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_confluence_pages(
|
||||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
return run_async_celery_task(
|
||||
lambda: _index_confluence_pages(
|
||||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
)
|
||||
|
||||
|
||||
async def _index_confluence_pages(
|
||||
|
|
@ -181,22 +158,15 @@ def index_google_calendar_events_task(
|
|||
end_date: str,
|
||||
):
|
||||
"""Celery task to index Google Calendar events."""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_google_calendar_events(
|
||||
return run_async_celery_task(
|
||||
lambda: _index_google_calendar_events(
|
||||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
_handle_greenlet_error(e, "index_google_calendar_events", connector_id)
|
||||
raise
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
async def _index_google_calendar_events(
|
||||
|
|
@ -227,19 +197,11 @@ def index_google_gmail_messages_task(
|
|||
end_date: str,
|
||||
):
|
||||
"""Celery task to index Google Gmail messages."""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_google_gmail_messages(
|
||||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
return run_async_celery_task(
|
||||
lambda: _index_google_gmail_messages(
|
||||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
)
|
||||
|
||||
|
||||
async def _index_google_gmail_messages(
|
||||
|
|
@ -269,22 +231,14 @@ def index_google_drive_files_task(
|
|||
items_dict: dict, # Dictionary with 'folders', 'files', and 'indexing_options'
|
||||
):
|
||||
"""Celery task to index Google Drive folders and files."""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_google_drive_files(
|
||||
connector_id,
|
||||
search_space_id,
|
||||
user_id,
|
||||
items_dict,
|
||||
)
|
||||
return run_async_celery_task(
|
||||
lambda: _index_google_drive_files(
|
||||
connector_id,
|
||||
search_space_id,
|
||||
user_id,
|
||||
items_dict,
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
)
|
||||
|
||||
|
||||
async def _index_google_drive_files(
|
||||
|
|
@ -317,22 +271,14 @@ def index_onedrive_files_task(
|
|||
items_dict: dict,
|
||||
):
|
||||
"""Celery task to index OneDrive folders and files."""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_onedrive_files(
|
||||
connector_id,
|
||||
search_space_id,
|
||||
user_id,
|
||||
items_dict,
|
||||
)
|
||||
return run_async_celery_task(
|
||||
lambda: _index_onedrive_files(
|
||||
connector_id,
|
||||
search_space_id,
|
||||
user_id,
|
||||
items_dict,
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
)
|
||||
|
||||
|
||||
async def _index_onedrive_files(
|
||||
|
|
@ -365,22 +311,14 @@ def index_dropbox_files_task(
|
|||
items_dict: dict,
|
||||
):
|
||||
"""Celery task to index Dropbox folders and files."""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_dropbox_files(
|
||||
connector_id,
|
||||
search_space_id,
|
||||
user_id,
|
||||
items_dict,
|
||||
)
|
||||
return run_async_celery_task(
|
||||
lambda: _index_dropbox_files(
|
||||
connector_id,
|
||||
search_space_id,
|
||||
user_id,
|
||||
items_dict,
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
)
|
||||
|
||||
|
||||
async def _index_dropbox_files(
|
||||
|
|
@ -414,19 +352,11 @@ def index_elasticsearch_documents_task(
|
|||
end_date: str,
|
||||
):
|
||||
"""Celery task to index Elasticsearch documents."""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_elasticsearch_documents(
|
||||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
return run_async_celery_task(
|
||||
lambda: _index_elasticsearch_documents(
|
||||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
)
|
||||
|
||||
|
||||
async def _index_elasticsearch_documents(
|
||||
|
|
@ -457,22 +387,15 @@ def index_crawled_urls_task(
|
|||
end_date: str,
|
||||
):
|
||||
"""Celery task to index Web page Urls."""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_crawled_urls(
|
||||
return run_async_celery_task(
|
||||
lambda: _index_crawled_urls(
|
||||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
_handle_greenlet_error(e, "index_crawled_urls", connector_id)
|
||||
raise
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
async def _index_crawled_urls(
|
||||
|
|
@ -503,19 +426,11 @@ def index_bookstack_pages_task(
|
|||
end_date: str,
|
||||
):
|
||||
"""Celery task to index BookStack pages."""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_bookstack_pages(
|
||||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
return run_async_celery_task(
|
||||
lambda: _index_bookstack_pages(
|
||||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
)
|
||||
|
||||
|
||||
async def _index_bookstack_pages(
|
||||
|
|
@ -546,19 +461,11 @@ def index_composio_connector_task(
|
|||
end_date: str | None,
|
||||
):
|
||||
"""Celery task to index Composio connector content (Google Drive, Gmail, Calendar via Composio)."""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_composio_connector(
|
||||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
return run_async_celery_task(
|
||||
lambda: _index_composio_connector(
|
||||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
)
|
||||
|
||||
|
||||
async def _index_composio_connector(
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from app.db import Document
|
|||
from app.indexing_pipeline.adapters.file_upload_adapter import UploadDocumentAdapter
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.tasks.celery_tasks import get_celery_session_maker
|
||||
from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -25,15 +25,7 @@ def reindex_document_task(self, document_id: int, user_id: str):
|
|||
document_id: ID of document to reindex
|
||||
user_id: ID of user who edited the document
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(_reindex_document(document_id, user_id))
|
||||
finally:
|
||||
loop.close()
|
||||
return run_async_celery_task(lambda: _reindex_document(document_id, user_id))
|
||||
|
||||
|
||||
async def _reindex_document(document_id: int, user_id: str):
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from app.celery_app import celery_app
|
|||
from app.config import config
|
||||
from app.services.notification_service import NotificationService
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.tasks.celery_tasks import get_celery_session_maker
|
||||
from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
|
||||
from app.tasks.connector_indexers.local_folder_indexer import (
|
||||
index_local_folder,
|
||||
index_uploaded_files,
|
||||
|
|
@ -105,12 +105,7 @@ async def _run_heartbeat_loop(notification_id: int):
|
|||
)
|
||||
def delete_document_task(self, document_id: int):
|
||||
"""Celery task to delete a document and its chunks in batches."""
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(_delete_document_background(document_id))
|
||||
finally:
|
||||
loop.close()
|
||||
return run_async_celery_task(lambda: _delete_document_background(document_id))
|
||||
|
||||
|
||||
async def _delete_document_background(document_id: int) -> None:
|
||||
|
|
@ -153,14 +148,9 @@ def delete_folder_documents_task(
|
|||
folder_subtree_ids: list[int] | None = None,
|
||||
):
|
||||
"""Celery task to delete documents first, then the folder rows."""
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_delete_folder_documents(document_ids, folder_subtree_ids)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
return run_async_celery_task(
|
||||
lambda: _delete_folder_documents(document_ids, folder_subtree_ids)
|
||||
)
|
||||
|
||||
|
||||
async def _delete_folder_documents(
|
||||
|
|
@ -209,12 +199,9 @@ async def _delete_folder_documents(
|
|||
)
|
||||
def delete_search_space_task(self, search_space_id: int):
|
||||
"""Celery task to delete a search space and heavy child rows in batches."""
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(_delete_search_space_background(search_space_id))
|
||||
finally:
|
||||
loop.close()
|
||||
return run_async_celery_task(
|
||||
lambda: _delete_search_space_background(search_space_id)
|
||||
)
|
||||
|
||||
|
||||
async def _delete_search_space_background(search_space_id: int) -> None:
|
||||
|
|
@ -269,18 +256,11 @@ def process_extension_document_task(
|
|||
search_space_id: ID of the search space
|
||||
user_id: ID of the user
|
||||
"""
|
||||
# Create a new event loop for this task
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_process_extension_document(
|
||||
individual_document_dict, search_space_id, user_id
|
||||
)
|
||||
return run_async_celery_task(
|
||||
lambda: _process_extension_document(
|
||||
individual_document_dict, search_space_id, user_id
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
)
|
||||
|
||||
|
||||
async def _process_extension_document(
|
||||
|
|
@ -419,13 +399,9 @@ def process_youtube_video_task(self, url: str, search_space_id: int, user_id: st
|
|||
search_space_id: ID of the search space
|
||||
user_id: ID of the user
|
||||
"""
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(_process_youtube_video(url, search_space_id, user_id))
|
||||
finally:
|
||||
loop.close()
|
||||
return run_async_celery_task(
|
||||
lambda: _process_youtube_video(url, search_space_id, user_id)
|
||||
)
|
||||
|
||||
|
||||
async def _process_youtube_video(url: str, search_space_id: int, user_id: str):
|
||||
|
|
@ -573,12 +549,9 @@ def process_file_upload_task(
|
|||
except Exception as e:
|
||||
logger.warning(f"[process_file_upload] Could not get file size: {e}")
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_process_file_upload(file_path, filename, search_space_id, user_id)
|
||||
run_async_celery_task(
|
||||
lambda: _process_file_upload(file_path, filename, search_space_id, user_id)
|
||||
)
|
||||
logger.info(
|
||||
f"[process_file_upload] Task completed successfully for: {filename}"
|
||||
|
|
@ -589,8 +562,6 @@ def process_file_upload_task(
|
|||
f"Traceback:\n{traceback.format_exc()}"
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
async def _process_file_upload(
|
||||
|
|
@ -811,25 +782,17 @@ def process_file_upload_with_document_task(
|
|||
"File may have been removed before syncing could start."
|
||||
)
|
||||
# Mark document as failed since file is missing
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_mark_document_failed(
|
||||
document_id,
|
||||
"File not found. Please re-upload the file.",
|
||||
)
|
||||
run_async_celery_task(
|
||||
lambda: _mark_document_failed(
|
||||
document_id,
|
||||
"File not found. Please re-upload the file.",
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
)
|
||||
return
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_process_file_with_document(
|
||||
run_async_celery_task(
|
||||
lambda: _process_file_with_document(
|
||||
document_id,
|
||||
temp_path,
|
||||
filename,
|
||||
|
|
@ -849,8 +812,6 @@ def process_file_upload_with_document_task(
|
|||
f"Traceback:\n{traceback.format_exc()}"
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
async def _mark_document_failed(document_id: int, reason: str):
|
||||
|
|
@ -1119,22 +1080,16 @@ def process_circleback_meeting_task(
|
|||
search_space_id: ID of the search space
|
||||
connector_id: ID of the Circleback connector (for deletion support)
|
||||
"""
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_process_circleback_meeting(
|
||||
meeting_id,
|
||||
meeting_name,
|
||||
markdown_content,
|
||||
metadata,
|
||||
search_space_id,
|
||||
connector_id,
|
||||
)
|
||||
return run_async_celery_task(
|
||||
lambda: _process_circleback_meeting(
|
||||
meeting_id,
|
||||
meeting_name,
|
||||
markdown_content,
|
||||
metadata,
|
||||
search_space_id,
|
||||
connector_id,
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
)
|
||||
|
||||
|
||||
async def _process_circleback_meeting(
|
||||
|
|
@ -1291,25 +1246,19 @@ def index_local_folder_task(
|
|||
target_file_paths: list[str] | None = None,
|
||||
):
|
||||
"""Celery task to index a local folder. Config is passed directly — no connector row."""
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_local_folder_async(
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
folder_path=folder_path,
|
||||
folder_name=folder_name,
|
||||
exclude_patterns=exclude_patterns,
|
||||
file_extensions=file_extensions,
|
||||
root_folder_id=root_folder_id,
|
||||
enable_summary=enable_summary,
|
||||
target_file_paths=target_file_paths,
|
||||
)
|
||||
return run_async_celery_task(
|
||||
lambda: _index_local_folder_async(
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
folder_path=folder_path,
|
||||
folder_name=folder_name,
|
||||
exclude_patterns=exclude_patterns,
|
||||
file_extensions=file_extensions,
|
||||
root_folder_id=root_folder_id,
|
||||
enable_summary=enable_summary,
|
||||
target_file_paths=target_file_paths,
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
)
|
||||
|
||||
|
||||
async def _index_local_folder_async(
|
||||
|
|
@ -1441,23 +1390,18 @@ def index_uploaded_folder_files_task(
|
|||
processing_mode: str = "basic",
|
||||
):
|
||||
"""Celery task to index files uploaded from the desktop app."""
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_uploaded_folder_files_async(
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
folder_name=folder_name,
|
||||
root_folder_id=root_folder_id,
|
||||
enable_summary=enable_summary,
|
||||
file_mappings=file_mappings,
|
||||
use_vision_llm=use_vision_llm,
|
||||
processing_mode=processing_mode,
|
||||
)
|
||||
return run_async_celery_task(
|
||||
lambda: _index_uploaded_folder_files_async(
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
folder_name=folder_name,
|
||||
root_folder_id=root_folder_id,
|
||||
enable_summary=enable_summary,
|
||||
file_mappings=file_mappings,
|
||||
use_vision_llm=use_vision_llm,
|
||||
processing_mode=processing_mode,
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
)
|
||||
|
||||
|
||||
async def _index_uploaded_folder_files_async(
|
||||
|
|
@ -1584,12 +1528,9 @@ def _ai_sort_lock_key(search_space_id: int) -> str:
|
|||
@celery_app.task(name="ai_sort_search_space", bind=True, max_retries=1)
|
||||
def ai_sort_search_space_task(self, search_space_id: int, user_id: str):
|
||||
"""Full AI sort for all documents in a search space."""
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(_ai_sort_search_space_async(search_space_id, user_id))
|
||||
finally:
|
||||
loop.close()
|
||||
return run_async_celery_task(
|
||||
lambda: _ai_sort_search_space_async(search_space_id, user_id)
|
||||
)
|
||||
|
||||
|
||||
async def _ai_sort_search_space_async(search_space_id: int, user_id: str):
|
||||
|
|
@ -1639,14 +1580,9 @@ async def _ai_sort_search_space_async(search_space_id: int, user_id: str):
|
|||
)
|
||||
def ai_sort_document_task(self, search_space_id: int, user_id: str, document_id: int):
|
||||
"""Incremental AI sort for a single document after indexing."""
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_ai_sort_document_async(search_space_id, user_id, document_id)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
return run_async_celery_task(
|
||||
lambda: _ai_sort_document_async(search_space_id, user_id, document_id)
|
||||
)
|
||||
|
||||
|
||||
async def _ai_sort_document_async(search_space_id: int, user_id: str, document_id: int):
|
||||
|
|
|
|||
|
|
@ -2,14 +2,13 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from app.celery_app import celery_app
|
||||
from app.db import SearchSourceConnector
|
||||
from app.schemas.obsidian_plugin import NotePayload
|
||||
from app.services.obsidian_plugin_indexer import upsert_note
|
||||
from app.tasks.celery_tasks import get_celery_session_maker
|
||||
from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -22,18 +21,13 @@ def index_obsidian_attachment_task(
|
|||
user_id: str,
|
||||
) -> None:
|
||||
"""Process one Obsidian non-markdown attachment asynchronously."""
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_obsidian_attachment(
|
||||
connector_id=connector_id,
|
||||
payload_data=payload_data,
|
||||
user_id=user_id,
|
||||
)
|
||||
return run_async_celery_task(
|
||||
lambda: _index_obsidian_attachment(
|
||||
connector_id=connector_id,
|
||||
payload_data=payload_data,
|
||||
user_id=user_id,
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
)
|
||||
|
||||
|
||||
async def _index_obsidian_attachment(
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
|
|
@ -12,11 +13,12 @@ from app.celery_app import celery_app
|
|||
from app.config import config as app_config
|
||||
from app.db import Podcast, PodcastStatus
|
||||
from app.services.billable_calls import (
|
||||
BillingSettlementError,
|
||||
QuotaInsufficientError,
|
||||
_resolve_agent_billing_for_search_space,
|
||||
billable_call,
|
||||
)
|
||||
from app.tasks.celery_tasks import get_celery_session_maker
|
||||
from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -34,6 +36,13 @@ if sys.platform.startswith("win"):
|
|||
# =============================================================================
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def _celery_billable_session():
|
||||
"""Session factory used by billable_call inside the Celery worker loop."""
|
||||
async with get_celery_session_maker()() as session:
|
||||
yield session
|
||||
|
||||
|
||||
@celery_app.task(name="generate_content_podcast", bind=True)
|
||||
def generate_content_podcast_task(
|
||||
self,
|
||||
|
|
@ -46,27 +55,22 @@ def generate_content_podcast_task(
|
|||
Celery task to generate podcast from source content.
|
||||
Updates existing podcast record created by the tool.
|
||||
"""
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
result = loop.run_until_complete(
|
||||
_generate_content_podcast(
|
||||
return run_async_celery_task(
|
||||
lambda: _generate_content_podcast(
|
||||
podcast_id,
|
||||
source_content,
|
||||
search_space_id,
|
||||
user_prompt,
|
||||
)
|
||||
)
|
||||
loop.run_until_complete(loop.shutdown_asyncgens())
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating content podcast: {e!s}")
|
||||
loop.run_until_complete(_mark_podcast_failed(podcast_id))
|
||||
try:
|
||||
run_async_celery_task(lambda: _mark_podcast_failed(podcast_id))
|
||||
except Exception:
|
||||
logger.exception("Failed to mark podcast %s as failed", podcast_id)
|
||||
return {"status": "failed", "podcast_id": podcast_id}
|
||||
finally:
|
||||
asyncio.set_event_loop(None)
|
||||
loop.close()
|
||||
|
||||
|
||||
async def _mark_podcast_failed(podcast_id: int) -> None:
|
||||
|
|
@ -148,11 +152,12 @@ async def _generate_content_podcast(
|
|||
base_model=base_model,
|
||||
quota_reserve_micros_override=app_config.QUOTA_DEFAULT_PODCAST_RESERVE_MICROS,
|
||||
usage_type="podcast_generation",
|
||||
thread_id=podcast.thread_id,
|
||||
call_details={
|
||||
"podcast_id": podcast.id,
|
||||
"title": podcast.title,
|
||||
"thread_id": podcast.thread_id,
|
||||
},
|
||||
billable_session_factory=_celery_billable_session,
|
||||
):
|
||||
graph_result = await podcaster_graph.ainvoke(
|
||||
initial_state, config=graph_config
|
||||
|
|
@ -173,6 +178,18 @@ async def _generate_content_podcast(
|
|||
"podcast_id": podcast.id,
|
||||
"reason": "premium_quota_exhausted",
|
||||
}
|
||||
except BillingSettlementError:
|
||||
logger.exception(
|
||||
"Podcast %s: premium billing settlement failed",
|
||||
podcast.id,
|
||||
)
|
||||
podcast.status = PodcastStatus.FAILED
|
||||
await session.commit()
|
||||
return {
|
||||
"status": "failed",
|
||||
"podcast_id": podcast.id,
|
||||
"reason": "billing_settlement_failed",
|
||||
}
|
||||
|
||||
podcast_transcript = graph_result.get("podcast_transcript", [])
|
||||
file_path = graph_result.get("final_podcast_file_path", "")
|
||||
|
|
@ -194,7 +211,14 @@ async def _generate_content_podcast(
|
|||
podcast.podcast_transcript = serializable_transcript
|
||||
podcast.file_location = file_path
|
||||
podcast.status = PodcastStatus.READY
|
||||
logger.info(
|
||||
"Podcast %s: committing READY transcript_entries=%d file=%s",
|
||||
podcast.id,
|
||||
len(serializable_transcript),
|
||||
file_path,
|
||||
)
|
||||
await session.commit()
|
||||
logger.info("Podcast %s: READY commit complete", podcast.id)
|
||||
|
||||
logger.info(f"Successfully generated podcast: {podcast.id}")
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from sqlalchemy.future import select
|
|||
|
||||
from app.celery_app import celery_app
|
||||
from app.db import Notification, SearchSourceConnector, SearchSourceConnectorType
|
||||
from app.tasks.celery_tasks import get_celery_session_maker
|
||||
from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
|
||||
from app.utils.indexing_locks import is_connector_indexing_locked
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -20,15 +20,7 @@ def check_periodic_schedules_task():
|
|||
This task runs every minute and triggers indexing for any connector
|
||||
whose next_scheduled_at time has passed.
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(_check_and_trigger_schedules())
|
||||
finally:
|
||||
loop.close()
|
||||
return run_async_celery_task(_check_and_trigger_schedules)
|
||||
|
||||
|
||||
async def _check_and_trigger_schedules():
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ from sqlalchemy.future import select
|
|||
from app.celery_app import celery_app
|
||||
from app.config import config
|
||||
from app.db import Document, DocumentStatus, Notification
|
||||
from app.tasks.celery_tasks import get_celery_session_maker
|
||||
from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -69,16 +69,12 @@ def cleanup_stale_indexing_notifications_task():
|
|||
Detection: Redis heartbeat key with 2-min TTL. Missing key = stale task.
|
||||
Also marks associated pending/processing documents as failed.
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
async def _both() -> None:
|
||||
await _cleanup_stale_notifications()
|
||||
await _cleanup_stale_document_processing_notifications()
|
||||
|
||||
try:
|
||||
loop.run_until_complete(_cleanup_stale_notifications())
|
||||
loop.run_until_complete(_cleanup_stale_document_processing_notifications())
|
||||
finally:
|
||||
loop.close()
|
||||
return run_async_celery_task(_both)
|
||||
|
||||
|
||||
async def _cleanup_stale_notifications():
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
|
|
@ -18,7 +17,7 @@ from app.db import (
|
|||
PremiumTokenPurchaseStatus,
|
||||
)
|
||||
from app.routes import stripe_routes
|
||||
from app.tasks.celery_tasks import get_celery_session_maker
|
||||
from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -36,13 +35,7 @@ def get_stripe_client() -> StripeClient | None:
|
|||
@celery_app.task(name="reconcile_pending_stripe_page_purchases")
|
||||
def reconcile_pending_stripe_page_purchases_task():
|
||||
"""Recover paid purchases that were left pending due to missed webhook handling."""
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(_reconcile_pending_page_purchases())
|
||||
finally:
|
||||
loop.close()
|
||||
return run_async_celery_task(_reconcile_pending_page_purchases)
|
||||
|
||||
|
||||
async def _reconcile_pending_page_purchases() -> None:
|
||||
|
|
@ -141,13 +134,7 @@ async def _reconcile_pending_page_purchases() -> None:
|
|||
@celery_app.task(name="reconcile_pending_stripe_token_purchases")
|
||||
def reconcile_pending_stripe_token_purchases_task():
|
||||
"""Recover paid token purchases that were left pending due to missed webhook handling."""
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(_reconcile_pending_token_purchases())
|
||||
finally:
|
||||
loop.close()
|
||||
return run_async_celery_task(_reconcile_pending_token_purchases)
|
||||
|
||||
|
||||
async def _reconcile_pending_token_purchases() -> None:
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
|
|
@ -12,11 +13,12 @@ from app.celery_app import celery_app
|
|||
from app.config import config as app_config
|
||||
from app.db import VideoPresentation, VideoPresentationStatus
|
||||
from app.services.billable_calls import (
|
||||
BillingSettlementError,
|
||||
QuotaInsufficientError,
|
||||
_resolve_agent_billing_for_search_space,
|
||||
billable_call,
|
||||
)
|
||||
from app.tasks.celery_tasks import get_celery_session_maker
|
||||
from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -29,6 +31,13 @@ if sys.platform.startswith("win"):
|
|||
)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def _celery_billable_session():
|
||||
"""Session factory used by billable_call inside the Celery worker loop."""
|
||||
async with get_celery_session_maker()() as session:
|
||||
yield session
|
||||
|
||||
|
||||
@celery_app.task(name="generate_video_presentation", bind=True)
|
||||
def generate_video_presentation_task(
|
||||
self,
|
||||
|
|
@ -41,27 +50,30 @@ def generate_video_presentation_task(
|
|||
Celery task to generate video presentation from source content.
|
||||
Updates existing video presentation record created by the tool.
|
||||
"""
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
result = loop.run_until_complete(
|
||||
_generate_video_presentation(
|
||||
return run_async_celery_task(
|
||||
lambda: _generate_video_presentation(
|
||||
video_presentation_id,
|
||||
source_content,
|
||||
search_space_id,
|
||||
user_prompt,
|
||||
)
|
||||
)
|
||||
loop.run_until_complete(loop.shutdown_asyncgens())
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating video presentation: {e!s}")
|
||||
loop.run_until_complete(_mark_video_presentation_failed(video_presentation_id))
|
||||
# Mark FAILED in a fresh loop — the previous loop is closed.
|
||||
# Swallow secondary failures; the row will simply stay in
|
||||
# GENERATING and be flushed by the periodic stale cleanup.
|
||||
try:
|
||||
run_async_celery_task(
|
||||
lambda: _mark_video_presentation_failed(video_presentation_id)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to mark video presentation %s as failed",
|
||||
video_presentation_id,
|
||||
)
|
||||
return {"status": "failed", "video_presentation_id": video_presentation_id}
|
||||
finally:
|
||||
asyncio.set_event_loop(None)
|
||||
loop.close()
|
||||
|
||||
|
||||
async def _mark_video_presentation_failed(video_presentation_id: int) -> None:
|
||||
|
|
@ -150,11 +162,12 @@ async def _generate_video_presentation(
|
|||
base_model=base_model,
|
||||
quota_reserve_micros_override=app_config.QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS,
|
||||
usage_type="video_presentation_generation",
|
||||
thread_id=video_pres.thread_id,
|
||||
call_details={
|
||||
"video_presentation_id": video_pres.id,
|
||||
"title": video_pres.title,
|
||||
"thread_id": video_pres.thread_id,
|
||||
},
|
||||
billable_session_factory=_celery_billable_session,
|
||||
):
|
||||
graph_result = await video_presentation_graph.ainvoke(
|
||||
initial_state, config=graph_config
|
||||
|
|
@ -175,6 +188,18 @@ async def _generate_video_presentation(
|
|||
"video_presentation_id": video_pres.id,
|
||||
"reason": "premium_quota_exhausted",
|
||||
}
|
||||
except BillingSettlementError:
|
||||
logger.exception(
|
||||
"VideoPresentation %s: premium billing settlement failed",
|
||||
video_pres.id,
|
||||
)
|
||||
video_pres.status = VideoPresentationStatus.FAILED
|
||||
await session.commit()
|
||||
return {
|
||||
"status": "failed",
|
||||
"video_presentation_id": video_pres.id,
|
||||
"reason": "billing_settlement_failed",
|
||||
}
|
||||
|
||||
# Serialize slides (parsed content + audio info merged)
|
||||
slides_raw = graph_result.get("slides", [])
|
||||
|
|
@ -205,7 +230,14 @@ async def _generate_video_presentation(
|
|||
video_pres.slides = serializable_slides
|
||||
video_pres.scene_codes = serializable_scene_codes
|
||||
video_pres.status = VideoPresentationStatus.READY
|
||||
logger.info(
|
||||
"VideoPresentation %s: committing READY slides=%d scene_codes=%d",
|
||||
video_pres.id,
|
||||
len(serializable_slides),
|
||||
len(serializable_scene_codes),
|
||||
)
|
||||
await session.commit()
|
||||
logger.info("VideoPresentation %s: READY commit complete", video_pres.id)
|
||||
|
||||
logger.info(f"Successfully generated video presentation: {video_pres.id}")
|
||||
|
||||
|
|
|
|||
|
|
@ -1506,10 +1506,10 @@ async def _stream_agent_events(
|
|||
if isinstance(tool_output, dict)
|
||||
else "Podcast"
|
||||
)
|
||||
if podcast_status == "processing":
|
||||
if podcast_status in ("pending", "generating", "processing"):
|
||||
completed_items = [
|
||||
f"Title: {podcast_title}",
|
||||
"Audio generation started",
|
||||
"Podcast generation started",
|
||||
"Processing in background...",
|
||||
]
|
||||
elif podcast_status == "already_generating":
|
||||
|
|
@ -1518,7 +1518,7 @@ async def _stream_agent_events(
|
|||
"Podcast already in progress",
|
||||
"Please wait for it to complete",
|
||||
]
|
||||
elif podcast_status == "error":
|
||||
elif podcast_status in ("failed", "error"):
|
||||
error_msg = (
|
||||
tool_output.get("error", "Unknown error")
|
||||
if isinstance(tool_output, dict)
|
||||
|
|
@ -1528,6 +1528,11 @@ async def _stream_agent_events(
|
|||
f"Title: {podcast_title}",
|
||||
f"Error: {error_msg[:50]}",
|
||||
]
|
||||
elif podcast_status in ("ready", "success"):
|
||||
completed_items = [
|
||||
f"Title: {podcast_title}",
|
||||
"Podcast ready",
|
||||
]
|
||||
else:
|
||||
completed_items = last_active_step_items
|
||||
yield streaming_service.format_thinking_step(
|
||||
|
|
@ -1710,20 +1715,28 @@ async def _stream_agent_events(
|
|||
if isinstance(tool_output, dict)
|
||||
else {"result": tool_output},
|
||||
)
|
||||
if (
|
||||
isinstance(tool_output, dict)
|
||||
and tool_output.get("status") == "success"
|
||||
if isinstance(tool_output, dict) and tool_output.get("status") in (
|
||||
"pending",
|
||||
"generating",
|
||||
"processing",
|
||||
):
|
||||
yield streaming_service.format_terminal_info(
|
||||
f"Podcast queued: {tool_output.get('title', 'Podcast')}",
|
||||
"success",
|
||||
)
|
||||
elif isinstance(tool_output, dict) and tool_output.get("status") in (
|
||||
"ready",
|
||||
"success",
|
||||
):
|
||||
yield streaming_service.format_terminal_info(
|
||||
f"Podcast generated successfully: {tool_output.get('title', 'Podcast')}",
|
||||
"success",
|
||||
)
|
||||
else:
|
||||
error_msg = (
|
||||
tool_output.get("error", "Unknown error")
|
||||
if isinstance(tool_output, dict)
|
||||
else "Unknown error"
|
||||
)
|
||||
elif isinstance(tool_output, dict) and tool_output.get("status") in (
|
||||
"failed",
|
||||
"error",
|
||||
):
|
||||
error_msg = tool_output.get("error", "Unknown error")
|
||||
yield streaming_service.format_terminal_info(
|
||||
f"Podcast generation failed: {error_msg}",
|
||||
"error",
|
||||
|
|
@ -2292,6 +2305,11 @@ async def stream_new_chat(
|
|||
)
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
# Image-bearing turns force the Auto-pin resolver to filter the
|
||||
# candidate pool to vision-capable cfgs (and force-repin a
|
||||
# text-only existing pin). For explicit selections this flag is
|
||||
# a no-op — the resolver returns the user's chosen id unchanged.
|
||||
_requires_image_input = bool(user_image_data_urls)
|
||||
try:
|
||||
llm_config_id = (
|
||||
await resolve_or_get_pinned_llm_config_id(
|
||||
|
|
@ -2300,13 +2318,29 @@ async def stream_new_chat(
|
|||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
selected_llm_config_id=llm_config_id,
|
||||
requires_image_input=_requires_image_input,
|
||||
)
|
||||
).resolved_llm_config_id
|
||||
except ValueError as pin_error:
|
||||
# Auto-pin's "no vision-capable cfg" path raises a ValueError
|
||||
# whose message we map to the friendly image-input SSE error
|
||||
# so the user sees the same message regardless of whether
|
||||
# the gate fired in Auto-mode or in the agent_config check
|
||||
# below.
|
||||
error_code = (
|
||||
"MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT"
|
||||
if _requires_image_input and "vision-capable" in str(pin_error)
|
||||
else "SERVER_ERROR"
|
||||
)
|
||||
error_kind = (
|
||||
"user_error"
|
||||
if error_code == "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT"
|
||||
else "server_error"
|
||||
)
|
||||
yield _emit_stream_error(
|
||||
message=str(pin_error),
|
||||
error_kind="server_error",
|
||||
error_code="SERVER_ERROR",
|
||||
error_kind=error_kind,
|
||||
error_code=error_code,
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
|
|
@ -2326,6 +2360,50 @@ async def stream_new_chat(
|
|||
llm_config_id,
|
||||
)
|
||||
|
||||
# Capability safety net: a turn carrying user-uploaded images
|
||||
# cannot be routed to a chat config that LiteLLM's authoritative
|
||||
# model map *explicitly* marks as text-only (``supports_vision``
|
||||
# set to False). The check is intentionally narrow — it only
|
||||
# fires when LiteLLM is *certain* the model can't accept image
|
||||
# input. Unknown / unmapped / vision-capable models pass
|
||||
# through. Without this guard a known-text-only model would 404
|
||||
# at the provider with ``"No endpoints found that support image
|
||||
# input"``, surfacing as an opaque ``SERVER_ERROR`` SSE chunk;
|
||||
# failing here lets us return a friendly message that tells the
|
||||
# user what to change.
|
||||
if user_image_data_urls and agent_config is not None:
|
||||
from app.services.provider_capabilities import (
|
||||
is_known_text_only_chat_model,
|
||||
)
|
||||
|
||||
agent_litellm_params = agent_config.litellm_params or {}
|
||||
agent_base_model = (
|
||||
agent_litellm_params.get("base_model")
|
||||
if isinstance(agent_litellm_params, dict)
|
||||
else None
|
||||
)
|
||||
if is_known_text_only_chat_model(
|
||||
provider=agent_config.provider,
|
||||
model_name=agent_config.model_name,
|
||||
base_model=agent_base_model,
|
||||
custom_provider=agent_config.custom_provider,
|
||||
):
|
||||
model_label = (
|
||||
agent_config.config_name or agent_config.model_name or "model"
|
||||
)
|
||||
yield _emit_stream_error(
|
||||
message=(
|
||||
f"The selected model ({model_label}) does not support "
|
||||
"image input. Switch to a vision-capable model "
|
||||
"(e.g. GPT-4o, Claude, Gemini) or remove the image "
|
||||
"attachment and try again."
|
||||
),
|
||||
error_kind="user_error",
|
||||
error_code="MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT",
|
||||
)
|
||||
yield streaming_service.format_done()
|
||||
return
|
||||
|
||||
# Premium quota reservation for pinned premium model only.
|
||||
_needs_premium_quota = (
|
||||
agent_config is not None and user_id and agent_config.is_premium
|
||||
|
|
@ -2366,6 +2444,7 @@ async def stream_new_chat(
|
|||
user_id=user_id,
|
||||
selected_llm_config_id=0,
|
||||
force_repin_free=True,
|
||||
requires_image_input=_requires_image_input,
|
||||
)
|
||||
).resolved_llm_config_id
|
||||
except ValueError as pin_error:
|
||||
|
|
@ -2470,6 +2549,7 @@ async def stream_new_chat(
|
|||
user_id=user_id,
|
||||
selected_llm_config_id=0,
|
||||
exclude_config_ids={previous_config_id},
|
||||
requires_image_input=_requires_image_input,
|
||||
)
|
||||
).resolved_llm_config_id
|
||||
except ValueError as pin_error:
|
||||
|
|
@ -2804,6 +2884,7 @@ async def stream_new_chat(
|
|||
from litellm import acompletion
|
||||
|
||||
from app.services.llm_router_service import LLMRouterService
|
||||
from app.services.provider_api_base import resolve_api_base
|
||||
from app.services.token_tracking_service import _turn_accumulator
|
||||
|
||||
_turn_accumulator.set(None)
|
||||
|
|
@ -2824,11 +2905,32 @@ async def stream_new_chat(
|
|||
model="auto", messages=messages
|
||||
)
|
||||
else:
|
||||
# Apply the same ``api_base`` cascade chat / vision /
|
||||
# image-gen call sites use so we never inherit
|
||||
# ``litellm.api_base`` (commonly set by
|
||||
# ``AZURE_OPENAI_ENDPOINT``) when the chat config
|
||||
# itself ships an empty ``api_base``. Without this
|
||||
# the title-gen on an OpenRouter chat config would
|
||||
# 404 against the inherited Azure endpoint — see
|
||||
# ``provider_api_base`` docstring for the same
|
||||
# bug repro on the image-gen / vision paths.
|
||||
raw_model = getattr(llm, "model", "") or ""
|
||||
provider_prefix = (
|
||||
raw_model.split("/", 1)[0] if "/" in raw_model else None
|
||||
)
|
||||
provider_value = (
|
||||
agent_config.provider if agent_config is not None else None
|
||||
)
|
||||
title_api_base = resolve_api_base(
|
||||
provider=provider_value,
|
||||
provider_prefix=provider_prefix,
|
||||
config_api_base=getattr(llm, "api_base", None),
|
||||
)
|
||||
response = await acompletion(
|
||||
model=llm.model,
|
||||
model=raw_model,
|
||||
messages=messages,
|
||||
api_key=getattr(llm, "api_key", None),
|
||||
api_base=getattr(llm, "api_base", None),
|
||||
api_base=title_api_base,
|
||||
)
|
||||
|
||||
usage_info = None
|
||||
|
|
@ -2953,6 +3055,7 @@ async def stream_new_chat(
|
|||
user_id=user_id,
|
||||
selected_llm_config_id=0,
|
||||
exclude_config_ids={previous_config_id},
|
||||
requires_image_input=_requires_image_input,
|
||||
)
|
||||
).resolved_llm_config_id
|
||||
|
||||
|
|
|
|||
558
surfsense_backend/scripts/verify_chat_image_capability.py
Normal file
558
surfsense_backend/scripts/verify_chat_image_capability.py
Normal file
|
|
@ -0,0 +1,558 @@
|
|||
"""End-to-end smoke test for vision / image config wiring.
|
||||
|
||||
Loads the live ``global_llm_config.yaml`` (no mocking, no fixtures) and
|
||||
exercises every chat / vision / image-generation config + the OpenRouter
|
||||
dynamic catalog. For each config the script:
|
||||
|
||||
1. Reports the resolver classification (catalog-allow vs strict-block).
|
||||
2. Optionally fires a tiny live API call against the provider:
|
||||
- Chat configs: ``litellm.acompletion`` with a 1x1 PNG and the prompt
|
||||
``"reply with one word: ok"``.
|
||||
- Vision configs: same, against the dedicated vision router pool.
|
||||
- Image-gen configs: ``litellm.aimage_generation`` with a single tiny
|
||||
prompt and ``n=1``.
|
||||
- OpenRouter integration: samples one chat, one vision, one image-gen
|
||||
model from the dynamically fetched catalog.
|
||||
|
||||
Usage::
|
||||
|
||||
python -m scripts.verify_chat_image_capability # capability + connectivity
|
||||
python -m scripts.verify_chat_image_capability --no-live # capability resolver only
|
||||
|
||||
The script is meant to be runnable from the repository root or from
|
||||
``surfsense_backend/`` and prints a short PASS/FAIL/SKIP summary at the
|
||||
end so it's usable as a CI smoke check too.
|
||||
|
||||
Live-mode caveat: each successful call costs a small amount of provider
|
||||
credit (a few tokens or one tiny generated image per config). The
|
||||
default size for image generation is ``1024x1024`` because Azure
|
||||
GPT-image deployments reject smaller sizes; OpenRouter image-gen models
|
||||
generally accept the same size.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
# Bootstrap the surfsense_backend package on sys.path so the script runs
|
||||
# from the repo root or from `surfsense_backend/` interchangeably.
|
||||
_HERE = os.path.dirname(os.path.abspath(__file__))
|
||||
_BACKEND_ROOT = os.path.dirname(_HERE)
|
||||
if _BACKEND_ROOT not in sys.path:
|
||||
sys.path.insert(0, _BACKEND_ROOT)
|
||||
|
||||
import litellm # noqa: E402
|
||||
|
||||
from app.config import config # noqa: E402
|
||||
from app.services.openrouter_integration_service import ( # noqa: E402
|
||||
_OPENROUTER_DYNAMIC_MARKER,
|
||||
OpenRouterIntegrationService,
|
||||
)
|
||||
from app.services.provider_api_base import resolve_api_base # noqa: E402
|
||||
from app.services.provider_capabilities import ( # noqa: E402
|
||||
derive_supports_image_input,
|
||||
is_known_text_only_chat_model,
|
||||
)
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.WARNING,
|
||||
format="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s",
|
||||
)
|
||||
# Quiet down LiteLLM's verbose router/cost logs so the script output is
|
||||
# scannable.
|
||||
logging.getLogger("LiteLLM").setLevel(logging.ERROR)
|
||||
logging.getLogger("litellm").setLevel(logging.ERROR)
|
||||
logging.getLogger("httpx").setLevel(logging.ERROR)
|
||||
|
||||
# 1x1 transparent PNG — used as the cheapest possible vision payload.
|
||||
_TINY_PNG_B64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
|
||||
_TINY_PNG_DATA_URL = f"data:image/png;base64,{_TINY_PNG_B64}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Result accounting
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProbeResult:
|
||||
label: str
|
||||
surface: str
|
||||
config_id: int | str
|
||||
capability_ok: bool | None = None
|
||||
capability_note: str = ""
|
||||
live_ok: bool | None = None
|
||||
live_note: str = ""
|
||||
duration_s: float = 0.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class Report:
|
||||
results: list[ProbeResult] = field(default_factory=list)
|
||||
|
||||
def add(self, r: ProbeResult) -> None:
|
||||
self.results.append(r)
|
||||
|
||||
def render(self) -> int:
|
||||
passed = failed = skipped = 0
|
||||
print()
|
||||
print("=" * 92)
|
||||
print(
|
||||
f"{'Surface':<14}{'ID':>8} {'Cap':>5} {'Live':>5} {'Time':>6} Label / notes"
|
||||
)
|
||||
print("-" * 92)
|
||||
for r in self.results:
|
||||
|
||||
def _flag(value: bool | None) -> str:
|
||||
if value is None:
|
||||
return "skip"
|
||||
return "ok" if value else "fail"
|
||||
|
||||
cap = _flag(r.capability_ok)
|
||||
live = _flag(r.live_ok)
|
||||
if r.capability_ok is False or r.live_ok is False:
|
||||
failed += 1
|
||||
elif r.capability_ok is None and r.live_ok is None:
|
||||
skipped += 1
|
||||
else:
|
||||
passed += 1
|
||||
print(
|
||||
f"{r.surface:<14}{r.config_id!s:>8} {cap:>5} {live:>5} "
|
||||
f"{r.duration_s:>5.2f}s {r.label}"
|
||||
)
|
||||
if r.capability_note:
|
||||
print(f" cap: {r.capability_note}")
|
||||
if r.live_note:
|
||||
print(f" live: {r.live_note}")
|
||||
print("-" * 92)
|
||||
print(
|
||||
f"Total: {passed} ok / {failed} fail / {skipped} skip "
|
||||
f"(of {len(self.results)} probes)"
|
||||
)
|
||||
print("=" * 92)
|
||||
return failed
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Capability probes (no network)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _probe_chat_capability(cfg: dict) -> tuple[bool, str]:
|
||||
"""For chat configs the catalog flag is *expected* True (vision-capable
|
||||
pool). The probe reports both the resolver value and the strict
|
||||
safety-net value to surface any drift between them."""
|
||||
litellm_params = cfg.get("litellm_params") or {}
|
||||
base_model = (
|
||||
litellm_params.get("base_model") if isinstance(litellm_params, dict) else None
|
||||
)
|
||||
cap = derive_supports_image_input(
|
||||
provider=cfg.get("provider"),
|
||||
model_name=cfg.get("model_name"),
|
||||
base_model=base_model,
|
||||
custom_provider=cfg.get("custom_provider"),
|
||||
)
|
||||
block = is_known_text_only_chat_model(
|
||||
provider=cfg.get("provider"),
|
||||
model_name=cfg.get("model_name"),
|
||||
base_model=base_model,
|
||||
custom_provider=cfg.get("custom_provider"),
|
||||
)
|
||||
note = f"derive={cap} strict_block={block}"
|
||||
if not cap and not block:
|
||||
# Resolver said False but strict gate is also False — that means
|
||||
# OR modalities published [text] explicitly. Surface it.
|
||||
note += " (OR modality says text-only)"
|
||||
# We accept a True derive *or* (False derive AND False block) as
|
||||
# 'capability ok' — either way, the streaming task will flow through.
|
||||
ok = cap or not block
|
||||
return ok, note
|
||||
|
||||
|
||||
def _build_chat_model_string(cfg: dict) -> str:
|
||||
if cfg.get("custom_provider"):
|
||||
return f"{cfg['custom_provider']}/{cfg['model_name']}"
|
||||
from app.services.provider_capabilities import _PROVIDER_PREFIX_MAP
|
||||
|
||||
prefix = _PROVIDER_PREFIX_MAP.get(
|
||||
(cfg.get("provider") or "").upper(), (cfg.get("provider") or "").lower()
|
||||
)
|
||||
return f"{prefix}/{cfg['model_name']}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Live probes (network calls)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _live_chat_image_call(cfg: dict) -> tuple[bool, str]:
|
||||
"""Send a 1x1 PNG + `reply with one word: ok` to the chat config."""
|
||||
model_string = _build_chat_model_string(cfg)
|
||||
api_base = resolve_api_base(
|
||||
provider=cfg.get("provider"),
|
||||
provider_prefix=model_string.split("/", 1)[0],
|
||||
config_api_base=cfg.get("api_base") or None,
|
||||
)
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": model_string,
|
||||
"api_key": cfg.get("api_key"),
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "reply with one word: ok"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": _TINY_PNG_DATA_URL},
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
"max_tokens": 16,
|
||||
"timeout": 60,
|
||||
}
|
||||
if api_base:
|
||||
kwargs["api_base"] = api_base
|
||||
if cfg.get("litellm_params"):
|
||||
# Strip pricing keys — they're tracking-only and confuse some
|
||||
# provider validators (e.g. azure/openai reject unknown kwargs
|
||||
# in strict mode).
|
||||
merged = {
|
||||
k: v
|
||||
for k, v in dict(cfg["litellm_params"]).items()
|
||||
if k
|
||||
not in {
|
||||
"input_cost_per_token",
|
||||
"output_cost_per_token",
|
||||
"input_cost_per_pixel",
|
||||
"output_cost_per_pixel",
|
||||
}
|
||||
}
|
||||
kwargs.update(merged)
|
||||
try:
|
||||
resp = await litellm.acompletion(**kwargs)
|
||||
except Exception as exc:
|
||||
return False, f"{type(exc).__name__}: {exc}"
|
||||
text = resp.choices[0].message.content if resp.choices else ""
|
||||
return True, f"got reply ({(text or '').strip()[:40]!r})"
|
||||
|
||||
|
||||
# Gemini image models occasionally return zero-length ``data`` for the
|
||||
# minimal "red dot on white" prompt (provider-side safety / empty-output
|
||||
# quirk reproducible against ``google/gemini-2.5-flash-image`` even when
|
||||
# the request itself succeeds). Use a more naturalistic prompt and
|
||||
# retry once with a different one before giving up.
|
||||
_IMAGE_GEN_PROMPTS: tuple[str, ...] = (
|
||||
"A simple icon of a coffee cup, flat illustration",
|
||||
"A small green leaf on a white background",
|
||||
)
|
||||
|
||||
|
||||
async def _live_image_gen_call(cfg: dict) -> tuple[bool, str]:
|
||||
"""Generate one tiny image to verify the deployment is reachable."""
|
||||
from app.services.provider_capabilities import _PROVIDER_PREFIX_MAP
|
||||
|
||||
if cfg.get("custom_provider"):
|
||||
prefix = cfg["custom_provider"]
|
||||
else:
|
||||
prefix = _PROVIDER_PREFIX_MAP.get(
|
||||
(cfg.get("provider") or "").upper(), (cfg.get("provider") or "").lower()
|
||||
)
|
||||
model_string = f"{prefix}/{cfg['model_name']}"
|
||||
api_base = resolve_api_base(
|
||||
provider=cfg.get("provider"),
|
||||
provider_prefix=prefix,
|
||||
config_api_base=cfg.get("api_base") or None,
|
||||
)
|
||||
base_kwargs: dict[str, Any] = {
|
||||
"model": model_string,
|
||||
"api_key": cfg.get("api_key"),
|
||||
"n": 1,
|
||||
"size": "1024x1024",
|
||||
"timeout": 120,
|
||||
}
|
||||
if api_base:
|
||||
base_kwargs["api_base"] = api_base
|
||||
if cfg.get("api_version"):
|
||||
base_kwargs["api_version"] = cfg["api_version"]
|
||||
if cfg.get("litellm_params"):
|
||||
base_kwargs.update(
|
||||
{
|
||||
k: v
|
||||
for k, v in dict(cfg["litellm_params"]).items()
|
||||
if k
|
||||
not in {
|
||||
"input_cost_per_token",
|
||||
"output_cost_per_token",
|
||||
"input_cost_per_pixel",
|
||||
"output_cost_per_pixel",
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
last_note = ""
|
||||
for attempt, prompt in enumerate(_IMAGE_GEN_PROMPTS, start=1):
|
||||
try:
|
||||
resp = await litellm.aimage_generation(prompt=prompt, **base_kwargs)
|
||||
except Exception as exc:
|
||||
last_note = f"{type(exc).__name__}: {exc}"
|
||||
continue
|
||||
data_count = len(getattr(resp, "data", None) or [])
|
||||
if data_count > 0:
|
||||
return True, (
|
||||
f"received {data_count} image(s) on attempt {attempt} "
|
||||
f"(prompt={prompt!r})"
|
||||
)
|
||||
last_note = (
|
||||
f"call ok but received 0 images on attempt {attempt} (prompt={prompt!r})"
|
||||
)
|
||||
return False, last_note
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Probe drivers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _is_or_dynamic(cfg: dict) -> bool:
|
||||
return bool(cfg.get(_OPENROUTER_DYNAMIC_MARKER))
|
||||
|
||||
|
||||
async def probe_chat_configs(report: Report, *, live: bool) -> None:
|
||||
print("\n[chat configs from global_llm_configs (YAML-static)]")
|
||||
for cfg in config.GLOBAL_LLM_CONFIGS:
|
||||
# Skip OR dynamic entries here — handled in the OR section so
|
||||
# the YAML / OR split stays clear in the report.
|
||||
if _is_or_dynamic(cfg):
|
||||
continue
|
||||
result = ProbeResult(
|
||||
label=str(cfg.get("name") or cfg.get("model_name")),
|
||||
surface="chat-yaml",
|
||||
config_id=cfg.get("id"),
|
||||
)
|
||||
cap_ok, cap_note = _probe_chat_capability(cfg)
|
||||
result.capability_ok = cap_ok
|
||||
result.capability_note = cap_note
|
||||
if live:
|
||||
t0 = time.perf_counter()
|
||||
ok, note = await _live_chat_image_call(cfg)
|
||||
result.live_ok = ok
|
||||
result.live_note = note
|
||||
result.duration_s = time.perf_counter() - t0
|
||||
report.add(result)
|
||||
|
||||
|
||||
async def probe_vision_configs(report: Report, *, live: bool) -> None:
|
||||
print("\n[vision configs from global_vision_llm_configs (YAML-static)]")
|
||||
for cfg in config.GLOBAL_VISION_LLM_CONFIGS:
|
||||
if _is_or_dynamic(cfg):
|
||||
continue
|
||||
result = ProbeResult(
|
||||
label=str(cfg.get("name") or cfg.get("model_name")),
|
||||
surface="vision",
|
||||
config_id=cfg.get("id"),
|
||||
)
|
||||
# For vision configs, capability is implied — they're in the
|
||||
# dedicated vision pool. Run the same resolver to flag any
|
||||
# surprise disagreement.
|
||||
cap_ok, cap_note = _probe_chat_capability(cfg)
|
||||
result.capability_ok = cap_ok
|
||||
result.capability_note = cap_note
|
||||
if live:
|
||||
t0 = time.perf_counter()
|
||||
ok, note = await _live_chat_image_call(cfg)
|
||||
result.live_ok = ok
|
||||
result.live_note = note
|
||||
result.duration_s = time.perf_counter() - t0
|
||||
report.add(result)
|
||||
|
||||
|
||||
async def probe_image_gen_configs(report: Report, *, live: bool) -> None:
|
||||
print(
|
||||
"\n[image generation configs from global_image_generation_configs (YAML-static)]"
|
||||
)
|
||||
for cfg in config.GLOBAL_IMAGE_GEN_CONFIGS:
|
||||
if _is_or_dynamic(cfg):
|
||||
continue
|
||||
result = ProbeResult(
|
||||
label=str(cfg.get("name") or cfg.get("model_name")),
|
||||
surface="image-gen",
|
||||
config_id=cfg.get("id"),
|
||||
)
|
||||
# Image gen configs don't have a "supports_image_input" flag;
|
||||
# the catalog tracks output, not input. Mark capability as None
|
||||
# (skip) for the report.
|
||||
if live:
|
||||
t0 = time.perf_counter()
|
||||
ok, note = await _live_image_gen_call(cfg)
|
||||
result.live_ok = ok
|
||||
result.live_note = note
|
||||
result.duration_s = time.perf_counter() - t0
|
||||
report.add(result)
|
||||
|
||||
|
||||
async def probe_openrouter_catalog(report: Report, *, live: bool) -> None:
|
||||
"""Sample one chat (vision-capable), one vision, one image-gen model
|
||||
from the live OpenRouter catalogue. Doesn't iterate the full pool
|
||||
(would be hundreds of probes); just validates the integration end-
|
||||
to-end on a representative model from each surface."""
|
||||
print("\n[OpenRouter integration: sampled probes]")
|
||||
settings = config.OPENROUTER_INTEGRATION_SETTINGS
|
||||
if not settings:
|
||||
report.add(
|
||||
ProbeResult(
|
||||
label="OpenRouter integration",
|
||||
surface="openrouter",
|
||||
config_id="settings",
|
||||
capability_ok=None,
|
||||
capability_note="openrouter_integration disabled in YAML — skipping",
|
||||
live_ok=None,
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
service = OpenRouterIntegrationService.get_instance()
|
||||
or_chat = [
|
||||
c
|
||||
for c in config.GLOBAL_LLM_CONFIGS
|
||||
if c.get("provider") == "OPENROUTER" and c.get("supports_image_input")
|
||||
]
|
||||
or_vision = [
|
||||
c for c in config.GLOBAL_VISION_LLM_CONFIGS if c.get("provider") == "OPENROUTER"
|
||||
]
|
||||
or_image_gen = [
|
||||
c for c in config.GLOBAL_IMAGE_GEN_CONFIGS if c.get("provider") == "OPENROUTER"
|
||||
]
|
||||
|
||||
# Pick one representative per provider family per surface so a single
|
||||
# broken vendor (e.g. Anthropic key revoked, Google quota exceeded)
|
||||
# surfaces independently of the others. Each needle matches the
|
||||
# OpenRouter ``model_name`` prefix; the first match wins.
|
||||
def _pick_first(pool: list[dict], needle: str) -> dict | None:
|
||||
for c in pool:
|
||||
if (c.get("model_name") or "").lower().startswith(needle):
|
||||
return c
|
||||
return None
|
||||
|
||||
chat_picks = [
|
||||
("or-chat", _pick_first(or_chat, "openai/gpt-4o")),
|
||||
("or-chat", _pick_first(or_chat, "anthropic/claude")),
|
||||
("or-chat", _pick_first(or_chat, "google/gemini-2.5-flash")),
|
||||
]
|
||||
vision_picks = [
|
||||
("or-vision", _pick_first(or_vision, "openai/gpt-4o")),
|
||||
("or-vision", _pick_first(or_vision, "anthropic/claude")),
|
||||
("or-vision", _pick_first(or_vision, "google/gemini-2.5-flash")),
|
||||
]
|
||||
image_picks = [
|
||||
("or-image", _pick_first(or_image_gen, "google/gemini-2.5-flash-image")),
|
||||
# OpenRouter publishes OpenAI image models as ``openai/gpt-5-image*``
|
||||
# / ``openai/gpt-5.4-image-2`` (no ``gpt-image`` literal). Match
|
||||
# the actual prefix.
|
||||
("or-image", _pick_first(or_image_gen, "openai/gpt-5-image")),
|
||||
]
|
||||
|
||||
print(
|
||||
f" catalog: chat={len(or_chat)} vision={len(or_vision)} image_gen={len(or_image_gen)} "
|
||||
f"(service initialized={service.is_initialized() if hasattr(service, 'is_initialized') else 'n/a'})"
|
||||
)
|
||||
|
||||
for surface, picked in chat_picks + vision_picks + image_picks:
|
||||
if not picked:
|
||||
report.add(
|
||||
ProbeResult(
|
||||
label=f"<no candidate for {surface}>",
|
||||
surface=surface,
|
||||
config_id="-",
|
||||
capability_ok=None,
|
||||
capability_note="no candidate found in OR catalog",
|
||||
)
|
||||
)
|
||||
continue
|
||||
runner = (
|
||||
_live_image_gen_call if surface == "or-image" else _live_chat_image_call
|
||||
)
|
||||
result = ProbeResult(
|
||||
label=str(picked.get("model_name")),
|
||||
surface=surface,
|
||||
config_id=picked.get("id"),
|
||||
)
|
||||
if surface != "or-image":
|
||||
cap_ok, cap_note = _probe_chat_capability(picked)
|
||||
result.capability_ok = cap_ok
|
||||
result.capability_note = cap_note
|
||||
if live:
|
||||
t0 = time.perf_counter()
|
||||
ok, note = await runner(picked)
|
||||
result.live_ok = ok
|
||||
result.live_note = note
|
||||
result.duration_s = time.perf_counter() - t0
|
||||
report.add(result)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def main(args: argparse.Namespace) -> int:
|
||||
print("Loaded global configs:")
|
||||
print(f" chat: {len(config.GLOBAL_LLM_CONFIGS)} entries")
|
||||
print(f" vision: {len(config.GLOBAL_VISION_LLM_CONFIGS)} entries")
|
||||
print(f" image-gen: {len(config.GLOBAL_IMAGE_GEN_CONFIGS)} entries")
|
||||
print(f" OR settings present: {bool(config.OPENROUTER_INTEGRATION_SETTINGS)}")
|
||||
|
||||
# Initialize the OpenRouter integration so the catalog is populated
|
||||
# (this is what main.py does at startup). It's idempotent.
|
||||
if config.OPENROUTER_INTEGRATION_SETTINGS:
|
||||
try:
|
||||
from app.config import initialize_openrouter_integration
|
||||
|
||||
initialize_openrouter_integration()
|
||||
except Exception as exc:
|
||||
print(f" WARNING: OpenRouter integration init failed: {exc}")
|
||||
|
||||
print(
|
||||
f"\nMode: {'LIVE (will hit providers)' if args.live else 'DRY (capability only)'}"
|
||||
)
|
||||
|
||||
report = Report()
|
||||
if not args.skip_chat:
|
||||
await probe_chat_configs(report, live=args.live)
|
||||
if not args.skip_vision:
|
||||
await probe_vision_configs(report, live=args.live)
|
||||
if not args.skip_image_gen:
|
||||
await probe_image_gen_configs(report, live=args.live)
|
||||
if not args.skip_openrouter:
|
||||
await probe_openrouter_catalog(report, live=args.live)
|
||||
|
||||
failed = report.render()
|
||||
return 1 if failed else 0
|
||||
|
||||
|
||||
def _parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
"--no-live",
|
||||
dest="live",
|
||||
action="store_false",
|
||||
help="Skip live API calls — capability resolver only.",
|
||||
)
|
||||
parser.set_defaults(live=True)
|
||||
parser.add_argument("--skip-chat", action="store_true")
|
||||
parser.add_argument("--skip-vision", action="store_true")
|
||||
parser.add_argument("--skip-image-gen", action="store_true")
|
||||
parser.add_argument("--skip-openrouter", action="store_true")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = _parse_args()
|
||||
sys.exit(asyncio.run(main(args)))
|
||||
|
|
@ -0,0 +1,110 @@
|
|||
"""Unit tests for ``supports_image_input`` derivation on BYOK chat config
|
||||
endpoints (``GET /new-llm-configs`` list, ``GET /new-llm-configs/{id}``).
|
||||
|
||||
There is no DB column for ``supports_image_input`` on
|
||||
``NewLLMConfig`` — the value is resolved at the API boundary by
|
||||
``derive_supports_image_input`` so the new-chat selector / streaming
|
||||
task can read the same field shape regardless of source (BYOK vs YAML
|
||||
vs OpenRouter dynamic). Default-allow on unknown so we don't lock the
|
||||
user out of their own model choice.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from app.db import LiteLLMProvider
|
||||
from app.routes import new_llm_config_routes
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def _byok_row(
|
||||
*,
|
||||
id_: int,
|
||||
model_name: str,
|
||||
base_model: str | None = None,
|
||||
provider: LiteLLMProvider = LiteLLMProvider.OPENAI,
|
||||
custom_provider: str | None = None,
|
||||
) -> object:
|
||||
"""Mimic the SQLAlchemy row's attribute surface; ``model_validate``
|
||||
walks ``from_attributes=True`` so a ``SimpleNamespace`` is enough.
|
||||
|
||||
``provider`` is a real ``LiteLLMProvider`` enum value so Pydantic's
|
||||
enum validator accepts it — same as the ORM row would carry."""
|
||||
return SimpleNamespace(
|
||||
id=id_,
|
||||
name=f"BYOK-{id_}",
|
||||
description=None,
|
||||
provider=provider,
|
||||
custom_provider=custom_provider,
|
||||
model_name=model_name,
|
||||
api_key="sk-byok",
|
||||
api_base=None,
|
||||
litellm_params={"base_model": base_model} if base_model else None,
|
||||
system_instructions="",
|
||||
use_default_system_instructions=True,
|
||||
citations_enabled=True,
|
||||
created_at=datetime.now(tz=UTC),
|
||||
search_space_id=42,
|
||||
user_id=uuid4(),
|
||||
)
|
||||
|
||||
|
||||
def test_serialize_byok_known_vision_model_resolves_true():
|
||||
"""The catalog resolver consults LiteLLM's map for ``gpt-4o`` ->
|
||||
True. The serialized row carries that value through to the
|
||||
``NewLLMConfigRead`` schema."""
|
||||
row = _byok_row(id_=1, model_name="gpt-4o")
|
||||
serialized = new_llm_config_routes._serialize_byok_config(row)
|
||||
|
||||
assert serialized.supports_image_input is True
|
||||
assert serialized.id == 1
|
||||
assert serialized.model_name == "gpt-4o"
|
||||
|
||||
|
||||
def test_serialize_byok_unknown_model_default_allows():
|
||||
"""Unknown / unmapped: default-allow. The streaming-task safety net
|
||||
is the actual block, and it requires LiteLLM to *explicitly* say
|
||||
text-only — so a brand new BYOK model should not be pre-judged."""
|
||||
row = _byok_row(
|
||||
id_=2,
|
||||
model_name="brand-new-model-x9-unmapped",
|
||||
provider=LiteLLMProvider.CUSTOM,
|
||||
custom_provider="brand_new_proxy",
|
||||
)
|
||||
serialized = new_llm_config_routes._serialize_byok_config(row)
|
||||
|
||||
assert serialized.supports_image_input is True
|
||||
|
||||
|
||||
def test_serialize_byok_uses_base_model_when_present():
|
||||
"""Azure-style: ``model_name`` is the deployment id, ``base_model``
|
||||
inside ``litellm_params`` is the canonical sku LiteLLM knows. The
|
||||
helper must consult ``base_model`` first or unrecognised deployment
|
||||
ids would shadow the real capability."""
|
||||
row = _byok_row(
|
||||
id_=3,
|
||||
model_name="my-azure-deployment-id-no-litellm-knows-this",
|
||||
base_model="gpt-4o",
|
||||
provider=LiteLLMProvider.AZURE_OPENAI,
|
||||
)
|
||||
serialized = new_llm_config_routes._serialize_byok_config(row)
|
||||
|
||||
assert serialized.supports_image_input is True
|
||||
|
||||
|
||||
def test_serialize_byok_returns_pydantic_read_model():
|
||||
"""The route now returns ``NewLLMConfigRead`` (not the raw ORM) so
|
||||
the schema additions are guaranteed to be present in the API
|
||||
surface. This guards against a future regression where someone
|
||||
deletes the augmentation step and falls back to ORM passthrough."""
|
||||
from app.schemas import NewLLMConfigRead
|
||||
|
||||
row = _byok_row(id_=4, model_name="gpt-4o")
|
||||
serialized = new_llm_config_routes._serialize_byok_config(row)
|
||||
assert isinstance(serialized, NewLLMConfigRead)
|
||||
|
|
@ -0,0 +1,184 @@
|
|||
"""Unit tests for ``is_premium`` derivation on the global image-gen and
|
||||
vision-LLM list endpoints.
|
||||
|
||||
Chat globals (``GET /global-llm-configs``) already emit
|
||||
``is_premium = (billing_tier == "premium")``. Image and vision did not,
|
||||
which made the new-chat ``model-selector`` render the Free/Premium badge
|
||||
on the Chat tab but skip it on the Image and Vision tabs (the selector
|
||||
keys its badge logic off ``is_premium``). These tests pin parity:
|
||||
|
||||
* YAML free entry → ``is_premium=False``
|
||||
* YAML premium entry → ``is_premium=True``
|
||||
* OpenRouter dynamic premium entry → ``is_premium=True``
|
||||
* Auto stub (always emitted when at least one config is present)
|
||||
→ ``is_premium=False``
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
_IMAGE_FIXTURE: list[dict] = [
|
||||
{
|
||||
"id": -1,
|
||||
"name": "DALL-E 3",
|
||||
"provider": "OPENAI",
|
||||
"model_name": "dall-e-3",
|
||||
"api_key": "sk-test",
|
||||
"billing_tier": "free",
|
||||
},
|
||||
{
|
||||
"id": -2,
|
||||
"name": "GPT-Image 1 (premium)",
|
||||
"provider": "OPENAI",
|
||||
"model_name": "gpt-image-1",
|
||||
"api_key": "sk-test",
|
||||
"billing_tier": "premium",
|
||||
},
|
||||
{
|
||||
"id": -20_001,
|
||||
"name": "google/gemini-2.5-flash-image (OpenRouter)",
|
||||
"provider": "OPENROUTER",
|
||||
"model_name": "google/gemini-2.5-flash-image",
|
||||
"api_key": "sk-or-test",
|
||||
"api_base": "https://openrouter.ai/api/v1",
|
||||
"billing_tier": "premium",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
_VISION_FIXTURE: list[dict] = [
|
||||
{
|
||||
"id": -1,
|
||||
"name": "GPT-4o Vision",
|
||||
"provider": "OPENAI",
|
||||
"model_name": "gpt-4o",
|
||||
"api_key": "sk-test",
|
||||
"billing_tier": "free",
|
||||
},
|
||||
{
|
||||
"id": -2,
|
||||
"name": "Claude 3.5 Sonnet (premium)",
|
||||
"provider": "ANTHROPIC",
|
||||
"model_name": "claude-3-5-sonnet",
|
||||
"api_key": "sk-ant-test",
|
||||
"billing_tier": "premium",
|
||||
},
|
||||
{
|
||||
"id": -30_001,
|
||||
"name": "openai/gpt-4o (OpenRouter)",
|
||||
"provider": "OPENROUTER",
|
||||
"model_name": "openai/gpt-4o",
|
||||
"api_key": "sk-or-test",
|
||||
"api_base": "https://openrouter.ai/api/v1",
|
||||
"billing_tier": "premium",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Image generation
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_global_image_gen_configs_emit_is_premium(monkeypatch):
|
||||
"""Each emitted config must carry ``is_premium`` derived server-side
|
||||
from ``billing_tier``. The Auto stub is always free.
|
||||
"""
|
||||
from app.config import config
|
||||
from app.routes import image_generation_routes
|
||||
|
||||
monkeypatch.setattr(
|
||||
config, "GLOBAL_IMAGE_GEN_CONFIGS", _IMAGE_FIXTURE, raising=False
|
||||
)
|
||||
|
||||
payload = await image_generation_routes.get_global_image_gen_configs(user=None)
|
||||
|
||||
by_id = {c["id"]: c for c in payload}
|
||||
|
||||
# Auto stub is always emitted when at least one global config exists,
|
||||
# and it must always declare itself free (Auto-mode billing-tier
|
||||
# surfacing is a separate follow-up).
|
||||
assert 0 in by_id, "Auto stub should be emitted when at least one config exists"
|
||||
assert by_id[0]["is_premium"] is False
|
||||
assert by_id[0]["billing_tier"] == "free"
|
||||
|
||||
# YAML free entry — ``is_premium=False``
|
||||
assert by_id[-1]["is_premium"] is False
|
||||
assert by_id[-1]["billing_tier"] == "free"
|
||||
|
||||
# YAML premium entry — ``is_premium=True``
|
||||
assert by_id[-2]["is_premium"] is True
|
||||
assert by_id[-2]["billing_tier"] == "premium"
|
||||
|
||||
# OpenRouter dynamic premium entry — same field, same derivation
|
||||
assert by_id[-20_001]["is_premium"] is True
|
||||
assert by_id[-20_001]["billing_tier"] == "premium"
|
||||
|
||||
# Every emitted dict (including Auto) must have the field — never missing.
|
||||
for cfg in payload:
|
||||
assert "is_premium" in cfg, f"is_premium missing from {cfg.get('id')}"
|
||||
assert isinstance(cfg["is_premium"], bool)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_global_image_gen_configs_no_globals_no_auto_stub(monkeypatch):
|
||||
"""When there are no global configs at all, the endpoint emits an
|
||||
empty list (no Auto stub) — Auto mode would have nothing to route to.
|
||||
"""
|
||||
from app.config import config
|
||||
from app.routes import image_generation_routes
|
||||
|
||||
monkeypatch.setattr(config, "GLOBAL_IMAGE_GEN_CONFIGS", [], raising=False)
|
||||
payload = await image_generation_routes.get_global_image_gen_configs(user=None)
|
||||
assert payload == []
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Vision LLM
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_global_vision_llm_configs_emit_is_premium(monkeypatch):
|
||||
from app.config import config
|
||||
from app.routes import vision_llm_routes
|
||||
|
||||
monkeypatch.setattr(
|
||||
config, "GLOBAL_VISION_LLM_CONFIGS", _VISION_FIXTURE, raising=False
|
||||
)
|
||||
|
||||
payload = await vision_llm_routes.get_global_vision_llm_configs(user=None)
|
||||
|
||||
by_id = {c["id"]: c for c in payload}
|
||||
|
||||
assert 0 in by_id, "Auto stub should be emitted when at least one config exists"
|
||||
assert by_id[0]["is_premium"] is False
|
||||
assert by_id[0]["billing_tier"] == "free"
|
||||
|
||||
assert by_id[-1]["is_premium"] is False
|
||||
assert by_id[-1]["billing_tier"] == "free"
|
||||
|
||||
assert by_id[-2]["is_premium"] is True
|
||||
assert by_id[-2]["billing_tier"] == "premium"
|
||||
|
||||
assert by_id[-30_001]["is_premium"] is True
|
||||
assert by_id[-30_001]["billing_tier"] == "premium"
|
||||
|
||||
for cfg in payload:
|
||||
assert "is_premium" in cfg, f"is_premium missing from {cfg.get('id')}"
|
||||
assert isinstance(cfg["is_premium"], bool)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_global_vision_llm_configs_no_globals_no_auto_stub(monkeypatch):
|
||||
from app.config import config
|
||||
from app.routes import vision_llm_routes
|
||||
|
||||
monkeypatch.setattr(config, "GLOBAL_VISION_LLM_CONFIGS", [], raising=False)
|
||||
payload = await vision_llm_routes.get_global_vision_llm_configs(user=None)
|
||||
assert payload == []
|
||||
|
|
@ -0,0 +1,106 @@
|
|||
"""Unit tests for ``supports_image_input`` derivation on the chat global
|
||||
config endpoint (``GET /global-new-llm-configs``).
|
||||
|
||||
Resolution order (matches ``new_llm_config_routes.get_global_new_llm_configs``):
|
||||
|
||||
1. Explicit ``supports_image_input`` on the cfg dict (set by the YAML
|
||||
loader for operator overrides, or by the OpenRouter integration from
|
||||
``architecture.input_modalities``) — wins.
|
||||
2. ``derive_supports_image_input`` helper — default-allow on unknown
|
||||
models, only False when LiteLLM / OR modalities are definitive.
|
||||
|
||||
The flag is purely informational at the API boundary. The streaming
|
||||
task safety net (``is_known_text_only_chat_model``) is the actual block,
|
||||
and it requires LiteLLM to *explicitly* mark the model as text-only.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
_FIXTURE: list[dict] = [
|
||||
{
|
||||
"id": -1,
|
||||
"name": "GPT-4o (explicit true)",
|
||||
"description": "vision-capable, explicit YAML override",
|
||||
"provider": "OPENAI",
|
||||
"model_name": "gpt-4o",
|
||||
"api_key": "sk-test",
|
||||
"billing_tier": "free",
|
||||
"supports_image_input": True,
|
||||
},
|
||||
{
|
||||
"id": -2,
|
||||
"name": "DeepSeek V3 (explicit false)",
|
||||
"description": "OpenRouter dynamic — modality-derived false",
|
||||
"provider": "OPENROUTER",
|
||||
"model_name": "deepseek/deepseek-v3.2-exp",
|
||||
"api_key": "sk-or-test",
|
||||
"api_base": "https://openrouter.ai/api/v1",
|
||||
"billing_tier": "free",
|
||||
"supports_image_input": False,
|
||||
},
|
||||
{
|
||||
"id": -10_010,
|
||||
"name": "Unannotated GPT-4o",
|
||||
"description": "no flag set — resolver should derive True via LiteLLM",
|
||||
"provider": "OPENAI",
|
||||
"model_name": "gpt-4o",
|
||||
"api_key": "sk-test",
|
||||
"billing_tier": "free",
|
||||
# supports_image_input intentionally absent
|
||||
},
|
||||
{
|
||||
"id": -10_011,
|
||||
"name": "Unannotated unknown model",
|
||||
"description": "unmapped — default-allow True",
|
||||
"provider": "CUSTOM",
|
||||
"custom_provider": "brand_new_proxy",
|
||||
"model_name": "brand-new-model-x9",
|
||||
"api_key": "sk-test",
|
||||
"billing_tier": "free",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_global_new_llm_configs_emit_supports_image_input(monkeypatch):
|
||||
"""Each emitted chat config carries ``supports_image_input`` as a
|
||||
bool. Explicit values win; unannotated entries are resolved via the
|
||||
helper (default-allow True)."""
|
||||
from app.config import config
|
||||
from app.routes import new_llm_config_routes
|
||||
|
||||
monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", _FIXTURE, raising=False)
|
||||
|
||||
payload = await new_llm_config_routes.get_global_new_llm_configs(user=None)
|
||||
by_id = {c["id"]: c for c in payload}
|
||||
|
||||
# Auto stub: optimistic True so the user can keep Auto selected with
|
||||
# vision-capable deployments somewhere in the pool.
|
||||
assert 0 in by_id, "Auto stub should be emitted when configs exist"
|
||||
assert by_id[0]["supports_image_input"] is True
|
||||
assert by_id[0]["is_auto_mode"] is True
|
||||
|
||||
# Explicit True is preserved.
|
||||
assert by_id[-1]["supports_image_input"] is True
|
||||
|
||||
# Explicit False is preserved (the exact failure mode the safety net
|
||||
# guards against — DeepSeek V3 over OpenRouter would 404 with "No
|
||||
# endpoints found that support image input").
|
||||
assert by_id[-2]["supports_image_input"] is False
|
||||
|
||||
# Unannotated GPT-4o: resolver consults LiteLLM, which says vision.
|
||||
assert by_id[-10_010]["supports_image_input"] is True
|
||||
|
||||
# Unknown / unmapped model: default-allow rather than pre-judge.
|
||||
assert by_id[-10_011]["supports_image_input"] is True
|
||||
|
||||
for cfg in payload:
|
||||
assert "supports_image_input" in cfg, (
|
||||
f"supports_image_input missing from {cfg.get('id')}"
|
||||
)
|
||||
assert isinstance(cfg["supports_image_input"], bool)
|
||||
|
|
@ -0,0 +1,286 @@
|
|||
"""Image-aware extension of the Auto-pin resolver.
|
||||
|
||||
When the current chat turn carries an ``image_url`` block, the pin
|
||||
resolver must:
|
||||
|
||||
1. Filter the candidate pool to vision-capable cfgs so a freshly
|
||||
selected pin can never be text-only.
|
||||
2. Treat any existing pin whose capability is False as invalid (force
|
||||
re-pin), even when it would otherwise be reused as the thread's
|
||||
stable model.
|
||||
3. Raise ``ValueError`` (mapped to the friendly
|
||||
``MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT`` SSE error in the streaming
|
||||
task) when no vision-capable cfg is available — instead of silently
|
||||
pinning text-only and 404-ing at the provider.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.auto_model_pin_service import (
|
||||
clear_healthy,
|
||||
clear_runtime_cooldown,
|
||||
resolve_or_get_pinned_llm_config_id,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_caches():
|
||||
clear_runtime_cooldown()
|
||||
clear_healthy()
|
||||
yield
|
||||
clear_runtime_cooldown()
|
||||
clear_healthy()
|
||||
|
||||
|
||||
@dataclass
|
||||
class _FakeQuotaResult:
|
||||
allowed: bool
|
||||
|
||||
|
||||
class _FakeExecResult:
|
||||
def __init__(self, thread):
|
||||
self._thread = thread
|
||||
|
||||
def unique(self):
|
||||
return self
|
||||
|
||||
def scalar_one_or_none(self):
|
||||
return self._thread
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
def __init__(self, thread):
|
||||
self.thread = thread
|
||||
self.commit_count = 0
|
||||
|
||||
async def execute(self, _stmt):
|
||||
return _FakeExecResult(self.thread)
|
||||
|
||||
async def commit(self):
|
||||
self.commit_count += 1
|
||||
|
||||
|
||||
def _thread(*, pinned: int | None = None):
|
||||
return SimpleNamespace(id=1, search_space_id=10, pinned_llm_config_id=pinned)
|
||||
|
||||
|
||||
def _vision_cfg(id_: int, *, tier: str = "free", quality: int = 80) -> dict:
|
||||
return {
|
||||
"id": id_,
|
||||
"provider": "OPENAI",
|
||||
"model_name": f"vision-{id_}",
|
||||
"api_key": "k",
|
||||
"billing_tier": tier,
|
||||
"supports_image_input": True,
|
||||
"auto_pin_tier": "A",
|
||||
"quality_score": quality,
|
||||
}
|
||||
|
||||
|
||||
def _text_only_cfg(id_: int, *, tier: str = "free", quality: int = 90) -> dict:
|
||||
return {
|
||||
"id": id_,
|
||||
"provider": "OPENAI",
|
||||
"model_name": f"text-{id_}",
|
||||
"api_key": "k",
|
||||
"billing_tier": tier,
|
||||
# Higher quality than the vision cfgs — so a bug that ignores
|
||||
# the image flag would surface as the resolver picking this one.
|
||||
"supports_image_input": False,
|
||||
"auto_pin_tier": "A",
|
||||
"quality_score": quality,
|
||||
}
|
||||
|
||||
|
||||
async def _premium_allowed(*_args, **_kwargs):
|
||||
return _FakeQuotaResult(allowed=True)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_image_turn_filters_out_text_only_candidates(monkeypatch):
|
||||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread())
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[_text_only_cfg(-1), _vision_cfg(-2)],
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
|
||||
_premium_allowed,
|
||||
)
|
||||
|
||||
result = await resolve_or_get_pinned_llm_config_id(
|
||||
session,
|
||||
thread_id=1,
|
||||
search_space_id=10,
|
||||
user_id=None,
|
||||
selected_llm_config_id=0,
|
||||
requires_image_input=True,
|
||||
)
|
||||
|
||||
assert result.resolved_llm_config_id == -2
|
||||
# The thread should be pinned to the vision cfg even though the
|
||||
# text-only cfg has a higher quality score.
|
||||
assert session.thread.pinned_llm_config_id == -2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_image_turn_force_repins_stale_text_only_pin(monkeypatch):
|
||||
"""An existing text-only pin must be invalidated when the next turn
|
||||
requires image input. The non-image path would happily reuse it."""
|
||||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread(pinned=-1))
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[_text_only_cfg(-1), _vision_cfg(-2)],
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
|
||||
_premium_allowed,
|
||||
)
|
||||
|
||||
result = await resolve_or_get_pinned_llm_config_id(
|
||||
session,
|
||||
thread_id=1,
|
||||
search_space_id=10,
|
||||
user_id=None,
|
||||
selected_llm_config_id=0,
|
||||
requires_image_input=True,
|
||||
)
|
||||
|
||||
assert result.resolved_llm_config_id == -2
|
||||
assert result.from_existing_pin is False
|
||||
assert session.thread.pinned_llm_config_id == -2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_image_turn_reuses_existing_vision_pin(monkeypatch):
|
||||
"""If the thread is already pinned to a vision-capable cfg, reuse it
|
||||
— same as the non-image path. Image-aware filtering must not force
|
||||
spurious re-pins."""
|
||||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread(pinned=-2))
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[_text_only_cfg(-1), _vision_cfg(-2), _vision_cfg(-3, quality=70)],
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
|
||||
_premium_allowed,
|
||||
)
|
||||
|
||||
result = await resolve_or_get_pinned_llm_config_id(
|
||||
session,
|
||||
thread_id=1,
|
||||
search_space_id=10,
|
||||
user_id=None,
|
||||
selected_llm_config_id=0,
|
||||
requires_image_input=True,
|
||||
)
|
||||
|
||||
assert result.resolved_llm_config_id == -2
|
||||
assert result.from_existing_pin is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_image_turn_with_no_vision_candidates_raises(monkeypatch):
|
||||
"""The friendly-error path: no vision-capable cfg in the pool -> raise
|
||||
``ValueError`` whose message contains ``vision-capable`` so the
|
||||
streaming task can map it to ``MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT``."""
|
||||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread())
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[_text_only_cfg(-1), _text_only_cfg(-2)],
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
|
||||
_premium_allowed,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="vision-capable"):
|
||||
await resolve_or_get_pinned_llm_config_id(
|
||||
session,
|
||||
thread_id=1,
|
||||
search_space_id=10,
|
||||
user_id=None,
|
||||
selected_llm_config_id=0,
|
||||
requires_image_input=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_image_turn_keeps_text_only_in_pool(monkeypatch):
|
||||
"""Regression guard: the image flag must default False and not affect
|
||||
a normal text-only turn — text-only cfgs remain selectable."""
|
||||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread())
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[_text_only_cfg(-1)],
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
|
||||
_premium_allowed,
|
||||
)
|
||||
|
||||
result = await resolve_or_get_pinned_llm_config_id(
|
||||
session,
|
||||
thread_id=1,
|
||||
search_space_id=10,
|
||||
user_id=None,
|
||||
selected_llm_config_id=0,
|
||||
)
|
||||
assert result.resolved_llm_config_id == -1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_image_turn_unannotated_cfg_resolves_via_helper(monkeypatch):
|
||||
"""A YAML cfg that omits ``supports_image_input`` falls through to
|
||||
``derive_supports_image_input`` (LiteLLM-driven). For ``gpt-4o``
|
||||
that returns True, so the cfg should be a valid candidate."""
|
||||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread())
|
||||
cfg_unannotated_vision = {
|
||||
"id": -2,
|
||||
"provider": "OPENAI",
|
||||
"model_name": "gpt-4o", # known vision model in LiteLLM map
|
||||
"api_key": "k",
|
||||
"billing_tier": "free",
|
||||
"auto_pin_tier": "A",
|
||||
"quality_score": 80,
|
||||
# NOTE: no supports_image_input key
|
||||
}
|
||||
monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [cfg_unannotated_vision])
|
||||
monkeypatch.setattr(
|
||||
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
|
||||
_premium_allowed,
|
||||
)
|
||||
|
||||
result = await resolve_or_get_pinned_llm_config_id(
|
||||
session,
|
||||
thread_id=1,
|
||||
search_space_id=10,
|
||||
user_id=None,
|
||||
selected_llm_config_id=0,
|
||||
requires_image_input=True,
|
||||
)
|
||||
assert result.resolved_llm_config_id == -2
|
||||
|
|
@ -15,6 +15,7 @@ vision LLM extraction:
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
|
@ -57,6 +58,9 @@ class _FakeSession:
|
|||
async def commit(self) -> None:
|
||||
self.committed = True
|
||||
|
||||
async def rollback(self) -> None:
|
||||
pass
|
||||
|
||||
async def close(self) -> None:
|
||||
pass
|
||||
|
||||
|
|
@ -71,7 +75,9 @@ async def _fake_shielded_session():
|
|||
_SESSIONS_USED: list[_FakeSession] = []
|
||||
|
||||
|
||||
def _patch_isolation_layer(monkeypatch, *, reserve_result, finalize_result=None):
|
||||
def _patch_isolation_layer(
|
||||
monkeypatch, *, reserve_result, finalize_result=None, finalize_exc=None
|
||||
):
|
||||
"""Wire fake reserve/finalize/release/session helpers."""
|
||||
_SESSIONS_USED.clear()
|
||||
reserve_calls: list[dict[str, Any]] = []
|
||||
|
|
@ -91,6 +97,8 @@ def _patch_isolation_layer(monkeypatch, *, reserve_result, finalize_result=None)
|
|||
async def _fake_finalize(
|
||||
*, db_session, user_id, request_id, actual_micros, reserved_micros
|
||||
):
|
||||
if finalize_exc is not None:
|
||||
raise finalize_exc
|
||||
finalize_calls.append(
|
||||
{
|
||||
"user_id": user_id,
|
||||
|
|
@ -343,6 +351,125 @@ async def test_premium_uses_estimator_when_no_micros_override(monkeypatch):
|
|||
assert spies["reserve"][0]["reserve_micros"] == 12_345
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_premium_finalize_failure_propagates_and_releases(monkeypatch):
|
||||
from app.services.billable_calls import BillingSettlementError, billable_call
|
||||
|
||||
class _FinalizeError(RuntimeError):
|
||||
pass
|
||||
|
||||
spies = _patch_isolation_layer(
|
||||
monkeypatch,
|
||||
reserve_result=_FakeQuotaResult(allowed=True),
|
||||
finalize_exc=_FinalizeError("db finalize failed"),
|
||||
)
|
||||
user_id = uuid4()
|
||||
|
||||
with pytest.raises(BillingSettlementError):
|
||||
async with billable_call(
|
||||
user_id=user_id,
|
||||
search_space_id=42,
|
||||
billing_tier="premium",
|
||||
base_model="openai/gpt-image-1",
|
||||
quota_reserve_micros_override=50_000,
|
||||
usage_type="image_generation",
|
||||
) as acc:
|
||||
acc.add(
|
||||
model="openai/gpt-image-1",
|
||||
prompt_tokens=0,
|
||||
completion_tokens=0,
|
||||
total_tokens=0,
|
||||
cost_micros=40_000,
|
||||
call_kind="image_generation",
|
||||
)
|
||||
|
||||
assert len(spies["reserve"]) == 1
|
||||
assert len(spies["release"]) == 1
|
||||
assert spies["record"] == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_premium_audit_commit_hang_times_out_after_finalize(monkeypatch):
|
||||
from app.services.billable_calls import billable_call
|
||||
|
||||
spies = _patch_isolation_layer(
|
||||
monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
|
||||
)
|
||||
user_id = uuid4()
|
||||
|
||||
class _HangingCommitSession(_FakeSession):
|
||||
async def commit(self) -> None:
|
||||
await asyncio.sleep(60)
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def _hanging_session_factory():
|
||||
s = _HangingCommitSession()
|
||||
_SESSIONS_USED.append(s)
|
||||
yield s
|
||||
|
||||
async with billable_call(
|
||||
user_id=user_id,
|
||||
search_space_id=42,
|
||||
billing_tier="premium",
|
||||
base_model="openai/gpt-image-1",
|
||||
quota_reserve_micros_override=50_000,
|
||||
usage_type="image_generation",
|
||||
billable_session_factory=_hanging_session_factory,
|
||||
audit_timeout_seconds=0.01,
|
||||
) as acc:
|
||||
acc.add(
|
||||
model="openai/gpt-image-1",
|
||||
prompt_tokens=0,
|
||||
completion_tokens=0,
|
||||
total_tokens=0,
|
||||
cost_micros=40_000,
|
||||
call_kind="image_generation",
|
||||
)
|
||||
|
||||
assert len(spies["reserve"]) == 1
|
||||
assert len(spies["finalize"]) == 1
|
||||
assert len(spies["record"]) == 1
|
||||
assert spies["release"] == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_free_audit_failure_is_best_effort(monkeypatch):
|
||||
from app.services.billable_calls import billable_call
|
||||
|
||||
spies = _patch_isolation_layer(
|
||||
monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
|
||||
)
|
||||
|
||||
async def _failing_record(_session, **_kwargs):
|
||||
raise RuntimeError("audit insert failed")
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.services.billable_calls.record_token_usage",
|
||||
_failing_record,
|
||||
raising=False,
|
||||
)
|
||||
|
||||
async with billable_call(
|
||||
user_id=uuid4(),
|
||||
search_space_id=42,
|
||||
billing_tier="free",
|
||||
base_model="openai/gpt-image-1",
|
||||
usage_type="image_generation",
|
||||
audit_timeout_seconds=0.01,
|
||||
) as acc:
|
||||
acc.add(
|
||||
model="openai/gpt-image-1",
|
||||
prompt_tokens=0,
|
||||
completion_tokens=0,
|
||||
total_tokens=0,
|
||||
cost_micros=37_000,
|
||||
call_kind="image_generation",
|
||||
)
|
||||
|
||||
assert spies["reserve"] == []
|
||||
assert spies["finalize"] == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Podcast / video-presentation usage_type coverage
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -387,7 +514,7 @@ async def test_free_podcast_path_audits_with_podcast_usage_type(monkeypatch):
|
|||
assert len(spies["record"]) == 1
|
||||
row = spies["record"][0]
|
||||
assert row["usage_type"] == "podcast_generation"
|
||||
assert row["thread_id"] == 99
|
||||
assert row["thread_id"] is None
|
||||
assert row["search_space_id"] == 42
|
||||
assert row["call_details"] == {"podcast_id": 7, "title": "Test Podcast"}
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,177 @@
|
|||
"""Defense-in-depth: image-gen call sites must not let an empty
|
||||
``api_base`` fall through to LiteLLM's module-global ``litellm.api_base``.
|
||||
|
||||
The bug repro: an OpenRouter image-gen config ships
|
||||
``api_base=""``. The pre-fix call site in
|
||||
``image_generation_routes._execute_image_generation`` did
|
||||
``if cfg.get("api_base"): kwargs["api_base"] = cfg["api_base"]`` which
|
||||
silently dropped the empty string. LiteLLM then fell back to
|
||||
``litellm.api_base`` (commonly inherited from ``AZURE_OPENAI_ENDPOINT``)
|
||||
and OpenRouter's ``image_generation/transformation`` appended
|
||||
``/chat/completions`` to it → 404 ``Resource not found``.
|
||||
|
||||
This test pins the post-fix behaviour: with an empty ``api_base`` in
|
||||
the config, the call site MUST set ``api_base`` to OpenRouter's public
|
||||
URL instead of leaving it unset.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_global_openrouter_image_gen_sets_api_base_when_config_empty():
|
||||
"""The global-config branch (``config_id < 0``) of
|
||||
``_execute_image_generation`` must apply the resolver and pin
|
||||
``api_base`` to OpenRouter when the config ships an empty string.
|
||||
"""
|
||||
from app.routes import image_generation_routes
|
||||
|
||||
cfg = {
|
||||
"id": -20_001,
|
||||
"name": "GPT Image 1 (OpenRouter)",
|
||||
"provider": "OPENROUTER",
|
||||
"model_name": "openai/gpt-image-1",
|
||||
"api_key": "sk-or-test",
|
||||
"api_base": "", # the original bug shape
|
||||
"api_version": None,
|
||||
"litellm_params": {},
|
||||
}
|
||||
|
||||
captured: dict = {}
|
||||
|
||||
async def fake_aimage_generation(**kwargs):
|
||||
captured.update(kwargs)
|
||||
return MagicMock(model_dump=lambda: {"data": []}, _hidden_params={})
|
||||
|
||||
image_gen = MagicMock()
|
||||
image_gen.image_generation_config_id = cfg["id"]
|
||||
image_gen.prompt = "test"
|
||||
image_gen.n = 1
|
||||
image_gen.quality = None
|
||||
image_gen.size = None
|
||||
image_gen.style = None
|
||||
image_gen.response_format = None
|
||||
image_gen.model = None
|
||||
|
||||
search_space = MagicMock()
|
||||
search_space.image_generation_config_id = cfg["id"]
|
||||
session = MagicMock()
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
image_generation_routes,
|
||||
"_get_global_image_gen_config",
|
||||
return_value=cfg,
|
||||
),
|
||||
patch.object(
|
||||
image_generation_routes,
|
||||
"aimage_generation",
|
||||
side_effect=fake_aimage_generation,
|
||||
),
|
||||
):
|
||||
await image_generation_routes._execute_image_generation(
|
||||
session=session, image_gen=image_gen, search_space=search_space
|
||||
)
|
||||
|
||||
# The whole point of the fix: even with empty ``api_base`` in the
|
||||
# config, we forward OpenRouter's public URL so the call doesn't
|
||||
# inherit an Azure endpoint.
|
||||
assert captured.get("api_base") == "https://openrouter.ai/api/v1"
|
||||
assert captured["model"] == "openrouter/openai/gpt-image-1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_image_tool_global_sets_api_base_when_config_empty():
|
||||
"""Same defense at the agent tool entry point — both surfaces share
|
||||
the same OpenRouter config payloads."""
|
||||
from app.agents.new_chat.tools import generate_image as gi_module
|
||||
|
||||
cfg = {
|
||||
"id": -20_001,
|
||||
"name": "GPT Image 1 (OpenRouter)",
|
||||
"provider": "OPENROUTER",
|
||||
"model_name": "openai/gpt-image-1",
|
||||
"api_key": "sk-or-test",
|
||||
"api_base": "",
|
||||
"api_version": None,
|
||||
"litellm_params": {},
|
||||
}
|
||||
|
||||
captured: dict = {}
|
||||
|
||||
async def fake_aimage_generation(**kwargs):
|
||||
captured.update(kwargs)
|
||||
response = MagicMock()
|
||||
response.model_dump.return_value = {
|
||||
"data": [{"url": "https://example.com/x.png"}]
|
||||
}
|
||||
response._hidden_params = {"model": "openrouter/openai/gpt-image-1"}
|
||||
return response
|
||||
|
||||
search_space = MagicMock()
|
||||
search_space.id = 1
|
||||
search_space.image_generation_config_id = cfg["id"]
|
||||
|
||||
session_cm = AsyncMock()
|
||||
session = AsyncMock()
|
||||
session_cm.__aenter__.return_value = session
|
||||
|
||||
scalars = MagicMock()
|
||||
scalars.first.return_value = search_space
|
||||
exec_result = MagicMock()
|
||||
exec_result.scalars.return_value = scalars
|
||||
session.execute.return_value = exec_result
|
||||
session.add = MagicMock()
|
||||
session.commit = AsyncMock()
|
||||
session.refresh = AsyncMock()
|
||||
|
||||
# ``refresh(db_image_gen)`` needs to populate ``id`` for token URL fallback.
|
||||
async def _refresh(obj):
|
||||
obj.id = 1
|
||||
|
||||
session.refresh.side_effect = _refresh
|
||||
|
||||
with (
|
||||
patch.object(gi_module, "shielded_async_session", return_value=session_cm),
|
||||
patch.object(gi_module, "_get_global_image_gen_config", return_value=cfg),
|
||||
patch.object(
|
||||
gi_module, "aimage_generation", side_effect=fake_aimage_generation
|
||||
),
|
||||
patch.object(
|
||||
gi_module, "is_image_gen_auto_mode", side_effect=lambda cid: cid == 0
|
||||
),
|
||||
):
|
||||
tool = gi_module.create_generate_image_tool(
|
||||
search_space_id=1, db_session=MagicMock()
|
||||
)
|
||||
await tool.ainvoke({"prompt": "a cat", "n": 1})
|
||||
|
||||
assert captured.get("api_base") == "https://openrouter.ai/api/v1"
|
||||
assert captured["model"] == "openrouter/openai/gpt-image-1"
|
||||
|
||||
|
||||
def test_image_gen_router_deployment_sets_api_base_when_config_empty():
|
||||
"""The Auto-mode router pool must also resolve ``api_base`` when an
|
||||
OpenRouter config ships an empty string. The deployment dict is fed
|
||||
straight to ``litellm.Router``, so a missing ``api_base`` would
|
||||
leak the same way as the direct call sites.
|
||||
"""
|
||||
from app.services.image_gen_router_service import ImageGenRouterService
|
||||
|
||||
deployment = ImageGenRouterService._config_to_deployment(
|
||||
{
|
||||
"model_name": "openai/gpt-image-1",
|
||||
"provider": "OPENROUTER",
|
||||
"api_key": "sk-or-test",
|
||||
"api_base": "",
|
||||
}
|
||||
)
|
||||
assert deployment is not None
|
||||
assert deployment["litellm_params"]["api_base"] == "https://openrouter.ai/api/v1"
|
||||
assert deployment["litellm_params"]["model"] == "openrouter/openai/gpt-image-1"
|
||||
|
|
@ -265,6 +265,10 @@ def test_generate_image_gen_configs_filters_by_image_output():
|
|||
assert c["billing_tier"] in {"free", "premium"}
|
||||
assert c["provider"] == "OPENROUTER"
|
||||
assert c[_OPENROUTER_DYNAMIC_MARKER] is True
|
||||
# Defense-in-depth: emit the OpenRouter base URL at source so a
|
||||
# downstream call site that forgets ``resolve_api_base`` still
|
||||
# doesn't 404 against an inherited Azure endpoint.
|
||||
assert c["api_base"] == "https://openrouter.ai/api/v1"
|
||||
|
||||
|
||||
def test_generate_image_gen_configs_assigns_image_id_offset():
|
||||
|
|
@ -342,6 +346,10 @@ def test_generate_vision_llm_configs_filters_by_image_input_text_output():
|
|||
assert cfg["input_cost_per_token"] == pytest.approx(5e-6)
|
||||
assert cfg["output_cost_per_token"] == pytest.approx(15e-6)
|
||||
assert cfg[_OPENROUTER_DYNAMIC_MARKER] is True
|
||||
# Defense-in-depth: emit the OpenRouter base URL at source so a
|
||||
# downstream call site that forgets ``resolve_api_base`` still
|
||||
# doesn't inherit an Azure endpoint.
|
||||
assert cfg["api_base"] == "https://openrouter.ai/api/v1"
|
||||
|
||||
|
||||
def test_generate_vision_llm_configs_drops_chat_only_filters():
|
||||
|
|
|
|||
107
surfsense_backend/tests/unit/services/test_provider_api_base.py
Normal file
107
surfsense_backend/tests/unit/services/test_provider_api_base.py
Normal file
|
|
@ -0,0 +1,107 @@
|
|||
"""Unit tests for the shared ``api_base`` resolver.
|
||||
|
||||
The cascade exists so vision and image-gen call sites can't silently
|
||||
inherit ``litellm.api_base`` (commonly set by ``AZURE_OPENAI_ENDPOINT``)
|
||||
when an OpenRouter / Groq / etc. config ships an empty string. See
|
||||
``provider_api_base`` module docstring for the original repro
|
||||
(OpenRouter image-gen 404-ing against an Azure endpoint).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.provider_api_base import (
|
||||
PROVIDER_DEFAULT_API_BASE,
|
||||
PROVIDER_KEY_DEFAULT_API_BASE,
|
||||
resolve_api_base,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def test_config_value_wins_over_defaults():
|
||||
"""A non-empty config value is always returned verbatim, even when the
|
||||
provider has a default — the operator gets the last word."""
|
||||
result = resolve_api_base(
|
||||
provider="OPENROUTER",
|
||||
provider_prefix="openrouter",
|
||||
config_api_base="https://my-openrouter-mirror.example.com/v1",
|
||||
)
|
||||
assert result == "https://my-openrouter-mirror.example.com/v1"
|
||||
|
||||
|
||||
def test_provider_key_default_when_config_missing():
|
||||
"""``DEEPSEEK`` shares the ``openai`` LiteLLM prefix but has its own
|
||||
base URL — the provider-key map must take precedence over the prefix
|
||||
map so DeepSeek requests don't go to OpenAI."""
|
||||
result = resolve_api_base(
|
||||
provider="DEEPSEEK",
|
||||
provider_prefix="openai",
|
||||
config_api_base=None,
|
||||
)
|
||||
assert result == PROVIDER_KEY_DEFAULT_API_BASE["DEEPSEEK"]
|
||||
|
||||
|
||||
def test_provider_prefix_default_when_no_key_default():
|
||||
result = resolve_api_base(
|
||||
provider="OPENROUTER",
|
||||
provider_prefix="openrouter",
|
||||
config_api_base=None,
|
||||
)
|
||||
assert result == PROVIDER_DEFAULT_API_BASE["openrouter"]
|
||||
|
||||
|
||||
def test_unknown_provider_returns_none():
|
||||
"""When neither map matches we return ``None`` so the caller can let
|
||||
LiteLLM apply its own provider-integration default (Azure deployment
|
||||
URL, custom-provider URL, etc.)."""
|
||||
result = resolve_api_base(
|
||||
provider="SOMETHING_NEW",
|
||||
provider_prefix="something_new",
|
||||
config_api_base=None,
|
||||
)
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_empty_string_config_treated_as_missing():
|
||||
"""The original bug: OpenRouter dynamic configs ship ``api_base=""``
|
||||
and downstream call sites use ``if cfg.get("api_base"):`` — empty
|
||||
strings are falsy in Python but the cascade has to step in anyway."""
|
||||
result = resolve_api_base(
|
||||
provider="OPENROUTER",
|
||||
provider_prefix="openrouter",
|
||||
config_api_base="",
|
||||
)
|
||||
assert result == PROVIDER_DEFAULT_API_BASE["openrouter"]
|
||||
|
||||
|
||||
def test_whitespace_only_config_treated_as_missing():
|
||||
"""A config value of ``" "`` is a configuration mistake — treat it
|
||||
as missing instead of forwarding whitespace to LiteLLM (which would
|
||||
almost certainly 404)."""
|
||||
result = resolve_api_base(
|
||||
provider="OPENROUTER",
|
||||
provider_prefix="openrouter",
|
||||
config_api_base=" ",
|
||||
)
|
||||
assert result == PROVIDER_DEFAULT_API_BASE["openrouter"]
|
||||
|
||||
|
||||
def test_provider_case_insensitive():
|
||||
"""Some call sites pass the provider lowercase (DB enum value), others
|
||||
uppercase (YAML key). Both must resolve."""
|
||||
upper = resolve_api_base(
|
||||
provider="DEEPSEEK", provider_prefix="openai", config_api_base=None
|
||||
)
|
||||
lower = resolve_api_base(
|
||||
provider="deepseek", provider_prefix="openai", config_api_base=None
|
||||
)
|
||||
assert upper == lower == PROVIDER_KEY_DEFAULT_API_BASE["DEEPSEEK"]
|
||||
|
||||
|
||||
def test_all_inputs_none_returns_none():
|
||||
assert (
|
||||
resolve_api_base(provider=None, provider_prefix=None, config_api_base=None)
|
||||
is None
|
||||
)
|
||||
|
|
@ -0,0 +1,244 @@
|
|||
"""Unit tests for the shared chat-image capability resolver.
|
||||
|
||||
Two resolvers, two intents:
|
||||
|
||||
- ``derive_supports_image_input`` — best-effort True for the catalog and
|
||||
selector. Default-allow on unknown / unmapped models. The streaming
|
||||
task safety net never sees this value directly.
|
||||
|
||||
- ``is_known_text_only_chat_model`` — strict opt-out for the safety net.
|
||||
Returns True only when LiteLLM's model map *explicitly* sets
|
||||
``supports_vision=False``. Anything else (missing key, exception,
|
||||
True) returns False so the request flows through to the provider.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.provider_capabilities import (
|
||||
derive_supports_image_input,
|
||||
is_known_text_only_chat_model,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# derive_supports_image_input — OpenRouter modalities path (authoritative)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_or_modalities_with_image_returns_true():
|
||||
assert (
|
||||
derive_supports_image_input(
|
||||
provider="OPENROUTER",
|
||||
model_name="openai/gpt-4o",
|
||||
openrouter_input_modalities=["text", "image"],
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
|
||||
def test_or_modalities_text_only_returns_false():
|
||||
assert (
|
||||
derive_supports_image_input(
|
||||
provider="OPENROUTER",
|
||||
model_name="deepseek/deepseek-v3.2-exp",
|
||||
openrouter_input_modalities=["text"],
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
|
||||
def test_or_modalities_empty_list_returns_false():
|
||||
"""OR explicitly publishing an empty modality list is a definitive
|
||||
'no inputs at all' signal — treat as False rather than falling back
|
||||
to LiteLLM."""
|
||||
assert (
|
||||
derive_supports_image_input(
|
||||
provider="OPENROUTER",
|
||||
model_name="weird/empty-modalities",
|
||||
openrouter_input_modalities=[],
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
|
||||
def test_or_modalities_none_falls_through_to_litellm():
|
||||
"""``None`` (missing key) is *not* a definitive signal — fall through
|
||||
to LiteLLM. Using ``openai/gpt-4o`` which is in LiteLLM's map."""
|
||||
assert (
|
||||
derive_supports_image_input(
|
||||
provider="OPENAI",
|
||||
model_name="gpt-4o",
|
||||
openrouter_input_modalities=None,
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# derive_supports_image_input — LiteLLM model-map path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_litellm_known_vision_model_returns_true():
|
||||
assert (
|
||||
derive_supports_image_input(
|
||||
provider="OPENAI",
|
||||
model_name="gpt-4o",
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
|
||||
def test_litellm_base_model_wins_over_model_name():
|
||||
"""Azure-style entries pass model_name=deployment_id and put the
|
||||
canonical sku in litellm_params.base_model. The resolver must
|
||||
consult base_model first or the deployment id (which LiteLLM
|
||||
doesn't know) would shadow the real capability."""
|
||||
assert (
|
||||
derive_supports_image_input(
|
||||
provider="AZURE_OPENAI",
|
||||
model_name="my-azure-deployment-id",
|
||||
base_model="gpt-4o",
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
|
||||
def test_litellm_unknown_model_default_allows():
|
||||
"""Default-allow on unknown — the safety net is the actual block."""
|
||||
assert (
|
||||
derive_supports_image_input(
|
||||
provider="CUSTOM",
|
||||
model_name="brand-new-model-x9-unmapped",
|
||||
custom_provider="brand_new_proxy",
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
|
||||
def test_litellm_known_text_only_returns_false():
|
||||
"""A model that LiteLLM explicitly knows is text-only resolves to
|
||||
False even via the catalog resolver. ``deepseek-chat`` (the
|
||||
DeepSeek-V3 chat sku) is in the map without supports_vision and
|
||||
LiteLLM's `supports_vision` returns False."""
|
||||
# Sanity: confirm the helper's negative path. We use a small model
|
||||
# known not to support vision per the map.
|
||||
result = derive_supports_image_input(
|
||||
provider="DEEPSEEK",
|
||||
model_name="deepseek-chat",
|
||||
)
|
||||
# We accept either False (LiteLLM said explicit no) or True
|
||||
# (default-allow if the entry isn't mapped on this version) — the
|
||||
# invariant is that the resolver never *raises* on a known-text-only
|
||||
# provider/model. The behaviour-binding assertion lives in
|
||||
# ``test_is_known_text_only_chat_model_explicit_false`` below.
|
||||
assert isinstance(result, bool)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# is_known_text_only_chat_model — strict opt-out semantics
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_is_known_text_only_returns_false_for_vision_model():
|
||||
assert (
|
||||
is_known_text_only_chat_model(
|
||||
provider="OPENAI",
|
||||
model_name="gpt-4o",
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
|
||||
def test_is_known_text_only_returns_false_for_unknown_model():
|
||||
"""Strict opt-out: missing from the map ≠ text-only. The safety net
|
||||
must NOT fire for an unmapped model — that's the regression we're
|
||||
fixing."""
|
||||
assert (
|
||||
is_known_text_only_chat_model(
|
||||
provider="CUSTOM",
|
||||
model_name="brand-new-model-x9-unmapped",
|
||||
custom_provider="brand_new_proxy",
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
|
||||
def test_is_known_text_only_returns_false_when_lookup_raises(monkeypatch):
|
||||
"""LiteLLM's ``get_model_info`` raises freely on parse errors. The
|
||||
helper swallows the exception and returns False so the safety net
|
||||
doesn't fire on a transient lookup failure."""
|
||||
import app.services.provider_capabilities as pc
|
||||
|
||||
def _raise(**_kwargs):
|
||||
raise ValueError("intentional test failure")
|
||||
|
||||
monkeypatch.setattr(pc.litellm, "get_model_info", _raise)
|
||||
|
||||
assert (
|
||||
is_known_text_only_chat_model(
|
||||
provider="OPENAI",
|
||||
model_name="gpt-4o",
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
|
||||
def test_is_known_text_only_returns_true_on_explicit_false(monkeypatch):
|
||||
"""Stub LiteLLM's ``get_model_info`` to return an explicit False so
|
||||
we exercise the opt-out path deterministically. Using a stub keeps
|
||||
the test stable across LiteLLM map updates."""
|
||||
import app.services.provider_capabilities as pc
|
||||
|
||||
def _info(**_kwargs):
|
||||
return {"supports_vision": False, "max_input_tokens": 8192}
|
||||
|
||||
monkeypatch.setattr(pc.litellm, "get_model_info", _info)
|
||||
|
||||
assert (
|
||||
is_known_text_only_chat_model(
|
||||
provider="OPENAI",
|
||||
model_name="any-model",
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
|
||||
def test_is_known_text_only_returns_false_on_supports_vision_true(monkeypatch):
|
||||
import app.services.provider_capabilities as pc
|
||||
|
||||
def _info(**_kwargs):
|
||||
return {"supports_vision": True}
|
||||
|
||||
monkeypatch.setattr(pc.litellm, "get_model_info", _info)
|
||||
|
||||
assert (
|
||||
is_known_text_only_chat_model(
|
||||
provider="OPENAI",
|
||||
model_name="any-model",
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
|
||||
def test_is_known_text_only_returns_false_on_missing_key(monkeypatch):
|
||||
"""A model entry without ``supports_vision`` at all is treated as
|
||||
'unknown' — strict opt-out means False."""
|
||||
import app.services.provider_capabilities as pc
|
||||
|
||||
def _info(**_kwargs):
|
||||
return {"max_input_tokens": 8192} # no supports_vision
|
||||
|
||||
monkeypatch.setattr(pc.litellm, "get_model_info", _info)
|
||||
|
||||
assert (
|
||||
is_known_text_only_chat_model(
|
||||
provider="OPENAI",
|
||||
model_name="any-model",
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
|
@ -0,0 +1,281 @@
|
|||
"""Unit tests for the chat-catalog ``supports_image_input`` capability flag.
|
||||
|
||||
Capability is sourced from two places, in order of preference:
|
||||
|
||||
1. ``architecture.input_modalities`` for dynamic OpenRouter chat configs
|
||||
(authoritative — OpenRouter publishes per-model modalities directly).
|
||||
2. LiteLLM's authoritative model map (``litellm.supports_vision``) for
|
||||
YAML / BYOK configs that don't carry an explicit operator override.
|
||||
|
||||
The catalog default is *True* (conservative-allow): an unknown / unmapped
|
||||
model is not pre-judged. The streaming-task safety net
|
||||
(``is_known_text_only_chat_model``) is the only place a False actually
|
||||
blocks a request — and it requires LiteLLM to *explicitly* mark the model
|
||||
as text-only.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.openrouter_integration_service import (
|
||||
_OPENROUTER_DYNAMIC_MARKER,
|
||||
_generate_configs,
|
||||
_supports_image_input,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
_SETTINGS_BASE: dict = {
|
||||
"api_key": "sk-or-test",
|
||||
"id_offset": -10_000,
|
||||
"rpm": 200,
|
||||
"tpm": 1_000_000,
|
||||
"free_rpm": 20,
|
||||
"free_tpm": 100_000,
|
||||
"anonymous_enabled_paid": False,
|
||||
"anonymous_enabled_free": True,
|
||||
"quota_reserve_tokens": 4000,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _supports_image_input helper (OpenRouter modalities)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_supports_image_input_true_for_multimodal():
|
||||
assert (
|
||||
_supports_image_input(
|
||||
{
|
||||
"id": "openai/gpt-4o",
|
||||
"architecture": {
|
||||
"input_modalities": ["text", "image"],
|
||||
"output_modalities": ["text"],
|
||||
},
|
||||
}
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
|
||||
def test_supports_image_input_false_for_text_only():
|
||||
"""The exact failure mode the safety net guards against — DeepSeek V3
|
||||
is a text-in/text-out model and would 404 if forwarded image_url."""
|
||||
assert (
|
||||
_supports_image_input(
|
||||
{
|
||||
"id": "deepseek/deepseek-v3.2-exp",
|
||||
"architecture": {
|
||||
"input_modalities": ["text"],
|
||||
"output_modalities": ["text"],
|
||||
},
|
||||
}
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
|
||||
def test_supports_image_input_false_when_modalities_missing():
|
||||
"""Defensive: missing architecture is treated as text-only at the
|
||||
OpenRouter helper level. The wider catalog resolver
|
||||
(`derive_supports_image_input`) only consults modalities when they
|
||||
are non-empty, otherwise it falls back to LiteLLM."""
|
||||
assert _supports_image_input({"id": "weird/model"}) is False
|
||||
assert _supports_image_input({"id": "weird/model", "architecture": {}}) is False
|
||||
assert (
|
||||
_supports_image_input(
|
||||
{"id": "weird/model", "architecture": {"input_modalities": None}}
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _generate_configs threads the flag onto every emitted chat config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_generate_configs_emits_supports_image_input():
|
||||
raw = [
|
||||
{
|
||||
"id": "openai/gpt-4o",
|
||||
"architecture": {
|
||||
"input_modalities": ["text", "image"],
|
||||
"output_modalities": ["text"],
|
||||
},
|
||||
"supported_parameters": ["tools"],
|
||||
"context_length": 200_000,
|
||||
"pricing": {"prompt": "0.000005", "completion": "0.000015"},
|
||||
},
|
||||
{
|
||||
"id": "deepseek/deepseek-v3.2-exp",
|
||||
"architecture": {
|
||||
"input_modalities": ["text"],
|
||||
"output_modalities": ["text"],
|
||||
},
|
||||
"supported_parameters": ["tools"],
|
||||
"context_length": 200_000,
|
||||
"pricing": {"prompt": "0.000003", "completion": "0.000015"},
|
||||
},
|
||||
]
|
||||
cfgs = _generate_configs(raw, dict(_SETTINGS_BASE))
|
||||
by_model = {c["model_name"]: c for c in cfgs}
|
||||
|
||||
gpt = by_model["openai/gpt-4o"]
|
||||
assert gpt["supports_image_input"] is True
|
||||
assert gpt[_OPENROUTER_DYNAMIC_MARKER] is True
|
||||
|
||||
deepseek = by_model["deepseek/deepseek-v3.2-exp"]
|
||||
assert deepseek["supports_image_input"] is False
|
||||
assert deepseek[_OPENROUTER_DYNAMIC_MARKER] is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# YAML loader: defer to derive_supports_image_input on unannotated entries
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_yaml_loader_resolves_unannotated_vision_model_to_true(tmp_path, monkeypatch):
|
||||
"""The regression case: an Azure GPT-5.x YAML entry without a
|
||||
``supports_image_input`` override should resolve to True via LiteLLM's
|
||||
model map (which says ``supports_vision: true``). Previously this
|
||||
defaulted to False, blocking every image turn for vision-capable
|
||||
YAML configs."""
|
||||
yaml_dir = tmp_path / "app" / "config"
|
||||
yaml_dir.mkdir(parents=True)
|
||||
(yaml_dir / "global_llm_config.yaml").write_text(
|
||||
"""
|
||||
global_llm_configs:
|
||||
- id: -2
|
||||
name: Azure GPT-4o
|
||||
provider: AZURE_OPENAI
|
||||
model_name: gpt-4o
|
||||
api_key: sk-test
|
||||
""",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
from app import config as config_module
|
||||
|
||||
monkeypatch.setattr(config_module, "BASE_DIR", tmp_path)
|
||||
|
||||
configs = config_module.load_global_llm_configs()
|
||||
assert len(configs) == 1
|
||||
assert configs[0]["supports_image_input"] is True
|
||||
|
||||
|
||||
def test_yaml_loader_respects_explicit_supports_image_input(tmp_path, monkeypatch):
|
||||
yaml_dir = tmp_path / "app" / "config"
|
||||
yaml_dir.mkdir(parents=True)
|
||||
(yaml_dir / "global_llm_config.yaml").write_text(
|
||||
"""
|
||||
global_llm_configs:
|
||||
- id: -1
|
||||
name: GPT-4o
|
||||
provider: OPENAI
|
||||
model_name: gpt-4o
|
||||
api_key: sk-test
|
||||
supports_image_input: false
|
||||
""",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
from app import config as config_module
|
||||
|
||||
monkeypatch.setattr(config_module, "BASE_DIR", tmp_path)
|
||||
|
||||
configs = config_module.load_global_llm_configs()
|
||||
assert len(configs) == 1
|
||||
# Operator override always wins, even against LiteLLM's True.
|
||||
assert configs[0]["supports_image_input"] is False
|
||||
|
||||
|
||||
def test_yaml_loader_unknown_model_default_allows(tmp_path, monkeypatch):
|
||||
"""Unknown / unmapped model in YAML: default-allow. The streaming
|
||||
safety net (which requires an explicit-False from LiteLLM) is the
|
||||
only place a real block happens, so we don't lock the user out of
|
||||
a freshly added third-party entry the catalog can't introspect."""
|
||||
yaml_dir = tmp_path / "app" / "config"
|
||||
yaml_dir.mkdir(parents=True)
|
||||
(yaml_dir / "global_llm_config.yaml").write_text(
|
||||
"""
|
||||
global_llm_configs:
|
||||
- id: -1
|
||||
name: Some Brand New Model
|
||||
provider: CUSTOM
|
||||
custom_provider: brand_new_proxy
|
||||
model_name: brand-new-model-x9
|
||||
api_key: sk-test
|
||||
""",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
from app import config as config_module
|
||||
|
||||
monkeypatch.setattr(config_module, "BASE_DIR", tmp_path)
|
||||
|
||||
configs = config_module.load_global_llm_configs()
|
||||
assert len(configs) == 1
|
||||
assert configs[0]["supports_image_input"] is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AgentConfig threads the flag through both YAML and Auto / BYOK
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_agent_config_from_yaml_explicit_overrides_resolver():
|
||||
from app.agents.new_chat.llm_config import AgentConfig
|
||||
|
||||
cfg_text_only = AgentConfig.from_yaml_config(
|
||||
{
|
||||
"id": -1,
|
||||
"name": "Text Only Override",
|
||||
"provider": "openai",
|
||||
"model_name": "gpt-4o", # Capable per LiteLLM, but operator says no.
|
||||
"api_key": "sk-test",
|
||||
"supports_image_input": False,
|
||||
}
|
||||
)
|
||||
cfg_explicit_vision = AgentConfig.from_yaml_config(
|
||||
{
|
||||
"id": -2,
|
||||
"name": "GPT-4o",
|
||||
"provider": "openai",
|
||||
"model_name": "gpt-4o",
|
||||
"api_key": "sk-test",
|
||||
"supports_image_input": True,
|
||||
}
|
||||
)
|
||||
assert cfg_text_only.supports_image_input is False
|
||||
assert cfg_explicit_vision.supports_image_input is True
|
||||
|
||||
|
||||
def test_agent_config_from_yaml_unannotated_uses_resolver():
|
||||
"""Without an explicit YAML key, AgentConfig defers to the catalog
|
||||
resolver — for ``gpt-4o`` LiteLLM's map says supports_vision=True."""
|
||||
from app.agents.new_chat.llm_config import AgentConfig
|
||||
|
||||
cfg = AgentConfig.from_yaml_config(
|
||||
{
|
||||
"id": -1,
|
||||
"name": "GPT-4o (no override)",
|
||||
"provider": "openai",
|
||||
"model_name": "gpt-4o",
|
||||
"api_key": "sk-test",
|
||||
}
|
||||
)
|
||||
assert cfg.supports_image_input is True
|
||||
|
||||
|
||||
def test_agent_config_auto_mode_supports_image_input():
|
||||
"""Auto routes across the pool. We optimistically allow image input
|
||||
so users can keep their selection on Auto with a vision-capable
|
||||
deployment somewhere in the pool. The router's own `allowed_fails`
|
||||
handles non-vision deployments via fallback."""
|
||||
from app.agents.new_chat.llm_config import AgentConfig
|
||||
|
||||
auto = AgentConfig.from_auto_mode()
|
||||
assert auto.supports_image_input is True
|
||||
|
|
@ -0,0 +1,89 @@
|
|||
"""Defense-in-depth: vision-LLM resolution must not leak ``api_base``
|
||||
defaults from ``litellm.api_base`` either.
|
||||
|
||||
Vision shares the same shape as image-gen — global YAML / OpenRouter
|
||||
dynamic configs ship ``api_base=""`` and the pre-fix ``get_vision_llm``
|
||||
call sites would silently drop the empty string and inherit
|
||||
``AZURE_OPENAI_ENDPOINT``. ``ChatLiteLLM(...)`` doesn't 404 on
|
||||
construction so we test the kwargs we hand to it instead.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_vision_llm_global_openrouter_sets_api_base():
|
||||
"""Global negative-ID branch: an OpenRouter vision config with
|
||||
``api_base=""`` must end up calling ``SanitizedChatLiteLLM`` with
|
||||
``api_base="https://openrouter.ai/api/v1"`` — never an empty string,
|
||||
never silently absent."""
|
||||
from app.services import llm_service
|
||||
|
||||
cfg = {
|
||||
"id": -30_001,
|
||||
"name": "GPT-4o Vision (OpenRouter)",
|
||||
"provider": "OPENROUTER",
|
||||
"model_name": "openai/gpt-4o",
|
||||
"api_key": "sk-or-test",
|
||||
"api_base": "",
|
||||
"api_version": None,
|
||||
"litellm_params": {},
|
||||
"billing_tier": "free",
|
||||
}
|
||||
|
||||
search_space = MagicMock()
|
||||
search_space.id = 1
|
||||
search_space.user_id = "user-x"
|
||||
search_space.vision_llm_config_id = cfg["id"]
|
||||
|
||||
session = AsyncMock()
|
||||
scalars = MagicMock()
|
||||
scalars.first.return_value = search_space
|
||||
result = MagicMock()
|
||||
result.scalars.return_value = scalars
|
||||
session.execute.return_value = result
|
||||
|
||||
captured: dict = {}
|
||||
|
||||
class FakeSanitized:
|
||||
def __init__(self, **kwargs):
|
||||
captured.update(kwargs)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"app.services.vision_llm_router_service.get_global_vision_llm_config",
|
||||
return_value=cfg,
|
||||
),
|
||||
patch(
|
||||
"app.agents.new_chat.llm_config.SanitizedChatLiteLLM",
|
||||
new=FakeSanitized,
|
||||
),
|
||||
):
|
||||
await llm_service.get_vision_llm(session=session, search_space_id=1)
|
||||
|
||||
assert captured.get("api_base") == "https://openrouter.ai/api/v1"
|
||||
assert captured["model"] == "openrouter/openai/gpt-4o"
|
||||
|
||||
|
||||
def test_vision_router_deployment_sets_api_base_when_config_empty():
|
||||
"""Auto-mode vision router: deployments are fed to ``litellm.Router``,
|
||||
so the resolver has to apply at deployment construction time too."""
|
||||
from app.services.vision_llm_router_service import VisionLLMRouterService
|
||||
|
||||
deployment = VisionLLMRouterService._config_to_deployment(
|
||||
{
|
||||
"model_name": "openai/gpt-4o",
|
||||
"provider": "OPENROUTER",
|
||||
"api_key": "sk-or-test",
|
||||
"api_base": "",
|
||||
}
|
||||
)
|
||||
assert deployment is not None
|
||||
assert deployment["litellm_params"]["api_base"] == "https://openrouter.ai/api/v1"
|
||||
assert deployment["litellm_params"]["model"] == "openrouter/openai/gpt-4o"
|
||||
318
surfsense_backend/tests/unit/tasks/test_celery_async_runner.py
Normal file
318
surfsense_backend/tests/unit/tasks/test_celery_async_runner.py
Normal file
|
|
@ -0,0 +1,318 @@
|
|||
"""Regression tests for ``run_async_celery_task``.
|
||||
|
||||
These tests pin down the production bug observed on 2026-05-02 where
|
||||
the video-presentation Celery task hung at ``[billable_call] finalize``
|
||||
because the shared ``app.db.engine`` had pooled asyncpg connections
|
||||
bound to a *previous* task's now-closed event loop. Reusing such a
|
||||
connection on a fresh loop crashes inside ``pool_pre_ping`` with::
|
||||
|
||||
AttributeError: 'NoneType' object has no attribute 'send'
|
||||
|
||||
(the proactor is None because the loop is gone) and can hang forever
|
||||
inside the asyncpg ``Connection._cancel`` cleanup coroutine.
|
||||
|
||||
The fix is ``run_async_celery_task``: a small helper that runs every
|
||||
async celery task body inside a fresh event loop and disposes the
|
||||
shared engine pool both before (defends against a previous task that
|
||||
crashed) and after (releases connections we opened on this loop).
|
||||
|
||||
Tests here exercise the helper with a stub engine that records
|
||||
``dispose()`` calls and panics if a coroutine produced by one loop is
|
||||
awaited on another — mirroring the real asyncpg behaviour.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import gc
|
||||
import sys
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Stub engine that emulates the asyncpg-on-stale-loop crash
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _StaleLoopEngine:
|
||||
"""Tiny stand-in for ``app.db.engine`` that tracks dispose() calls.
|
||||
|
||||
``dispose()`` is async (matches ``AsyncEngine.dispose``) and records
|
||||
the running event loop id so tests can assert it ran on *each*
|
||||
fresh loop.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.dispose_loop_ids: list[int] = []
|
||||
|
||||
async def dispose(self) -> None:
|
||||
loop = asyncio.get_running_loop()
|
||||
self.dispose_loop_ids.append(id(loop))
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _patch_shared_engine(stub: _StaleLoopEngine) -> Iterator[None]:
|
||||
"""Patch ``from app.db import engine as shared_engine`` lookup.
|
||||
|
||||
The helper imports lazily inside the function body, so we have to
|
||||
patch the attribute on the already-loaded ``app.db`` module.
|
||||
"""
|
||||
import app.db as app_db
|
||||
|
||||
original = getattr(app_db, "engine", None)
|
||||
app_db.engine = stub # type: ignore[attr-defined]
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if original is None:
|
||||
with pytest.raises(AttributeError):
|
||||
_ = app_db.engine
|
||||
else:
|
||||
app_db.engine = original # type: ignore[attr-defined]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_runner_returns_value_and_disposes_engine_around_call() -> None:
|
||||
"""Happy path: the coroutine result is returned, and the shared
|
||||
engine is disposed both before and after the task body runs.
|
||||
"""
|
||||
from app.tasks.celery_tasks import run_async_celery_task
|
||||
|
||||
stub = _StaleLoopEngine()
|
||||
|
||||
async def _body() -> str:
|
||||
# Engine should already have been disposed once before we run.
|
||||
assert len(stub.dispose_loop_ids) == 1
|
||||
return "ok"
|
||||
|
||||
with _patch_shared_engine(stub):
|
||||
result = run_async_celery_task(_body)
|
||||
|
||||
assert result == "ok"
|
||||
# Once before the body, once after (in finally).
|
||||
assert len(stub.dispose_loop_ids) == 2
|
||||
# Both disposes ran on the SAME (fresh) loop the task body used.
|
||||
assert stub.dispose_loop_ids[0] == stub.dispose_loop_ids[1]
|
||||
|
||||
|
||||
def test_runner_creates_fresh_loop_per_invocation() -> None:
|
||||
"""Each call must spin its own loop. Without this guarantee a
|
||||
previous task's loop would be reused and the asyncpg-stale-loop
|
||||
crash would never be avoided.
|
||||
"""
|
||||
import app.tasks.celery_tasks as celery_tasks_pkg
|
||||
|
||||
stub = _StaleLoopEngine()
|
||||
new_loop_calls = 0
|
||||
closed_loops: list[bool] = []
|
||||
|
||||
real_new_event_loop = asyncio.new_event_loop
|
||||
|
||||
def _counting_new_loop() -> asyncio.AbstractEventLoop:
|
||||
nonlocal new_loop_calls
|
||||
new_loop_calls += 1
|
||||
loop = real_new_event_loop()
|
||||
# Hook close() so we can verify each loop was closed properly
|
||||
# before the next one was created.
|
||||
original_close = loop.close
|
||||
|
||||
def _tracked_close() -> None:
|
||||
closed_loops.append(True)
|
||||
original_close()
|
||||
|
||||
loop.close = _tracked_close # type: ignore[method-assign]
|
||||
return loop
|
||||
|
||||
async def _body() -> None:
|
||||
# Loop is alive and current at body execution time.
|
||||
running = asyncio.get_running_loop()
|
||||
assert not running.is_closed()
|
||||
|
||||
with (
|
||||
_patch_shared_engine(stub),
|
||||
patch.object(asyncio, "new_event_loop", _counting_new_loop),
|
||||
):
|
||||
for _ in range(3):
|
||||
celery_tasks_pkg.run_async_celery_task(_body)
|
||||
|
||||
assert new_loop_calls == 3
|
||||
assert closed_loops == [True, True, True]
|
||||
# Each invocation disposed twice (before + after).
|
||||
assert len(stub.dispose_loop_ids) == 6
|
||||
|
||||
|
||||
def test_runner_disposes_engine_even_when_body_raises() -> None:
|
||||
"""Cleanup MUST run on the failure path too — otherwise stale
|
||||
connections leak into the next task and cause the original hang.
|
||||
"""
|
||||
from app.tasks.celery_tasks import run_async_celery_task
|
||||
|
||||
stub = _StaleLoopEngine()
|
||||
|
||||
class _BoomError(RuntimeError):
|
||||
pass
|
||||
|
||||
async def _body() -> None:
|
||||
raise _BoomError("kaboom")
|
||||
|
||||
with _patch_shared_engine(stub), pytest.raises(_BoomError):
|
||||
run_async_celery_task(_body)
|
||||
|
||||
assert len(stub.dispose_loop_ids) == 2 # before + after still ran
|
||||
|
||||
|
||||
def test_runner_swallows_dispose_errors() -> None:
|
||||
"""A flaky engine.dispose() must NEVER take down a celery task.
|
||||
|
||||
Production scenario: the very first dispose (before the body runs)
|
||||
might hit a partially-initialised engine; the helper logs and
|
||||
moves on. The task body still runs; the result is still returned.
|
||||
"""
|
||||
from app.tasks.celery_tasks import run_async_celery_task
|
||||
|
||||
class _AngryEngine:
|
||||
def __init__(self) -> None:
|
||||
self.calls = 0
|
||||
|
||||
async def dispose(self) -> None:
|
||||
self.calls += 1
|
||||
raise RuntimeError("dispose() blew up")
|
||||
|
||||
stub = _AngryEngine()
|
||||
|
||||
async def _body() -> int:
|
||||
return 42
|
||||
|
||||
with _patch_shared_engine(stub):
|
||||
assert run_async_celery_task(_body) == 42
|
||||
|
||||
assert stub.calls == 2 # before + after both attempted
|
||||
|
||||
|
||||
def test_runner_propagates_value_from_async_body() -> None:
|
||||
"""Sanity: pass-through of any pickleable celery return value."""
|
||||
from app.tasks.celery_tasks import run_async_celery_task
|
||||
|
||||
stub = _StaleLoopEngine()
|
||||
|
||||
async def _body() -> dict[str, object]:
|
||||
return {"status": "ready", "video_presentation_id": 19}
|
||||
|
||||
with _patch_shared_engine(stub):
|
||||
out = run_async_celery_task(_body)
|
||||
|
||||
assert out == {"status": "ready", "video_presentation_id": 19}
|
||||
|
||||
|
||||
def test_video_presentation_task_uses_runner_helper() -> None:
|
||||
"""Defence-in-depth: confirm the celery task module imports
|
||||
``run_async_celery_task``. If a future refactor inlines a
|
||||
``loop = asyncio.new_event_loop(); ... loop.close()`` block again,
|
||||
the original hang will return.
|
||||
"""
|
||||
# The module's task body should not contain a manual new_event_loop
|
||||
# call — that's exactly what the helper exists to centralise.
|
||||
import inspect
|
||||
|
||||
from app.tasks.celery_tasks import video_presentation_tasks
|
||||
|
||||
src = inspect.getsource(video_presentation_tasks)
|
||||
assert "run_async_celery_task" in src, (
|
||||
"video_presentation_tasks.py must use run_async_celery_task; "
|
||||
"manual asyncio.new_event_loop() in a celery task hangs on the "
|
||||
"shared SQLAlchemy pool when reused across tasks."
|
||||
)
|
||||
assert "asyncio.new_event_loop" not in src, (
|
||||
"video_presentation_tasks.py contains a raw asyncio.new_event_loop "
|
||||
"call — route every async task through run_async_celery_task to "
|
||||
"avoid the stale-pool hang."
|
||||
)
|
||||
|
||||
|
||||
def test_podcast_task_uses_runner_helper() -> None:
|
||||
"""Symmetric assertion for the podcast task — same root cause, same
|
||||
fix, same regression risk.
|
||||
"""
|
||||
import inspect
|
||||
|
||||
from app.tasks.celery_tasks import podcast_tasks
|
||||
|
||||
src = inspect.getsource(podcast_tasks)
|
||||
assert "run_async_celery_task" in src
|
||||
assert "asyncio.new_event_loop" not in src
|
||||
|
||||
|
||||
def test_runner_runs_shutdown_asyncgens_before_close() -> None:
|
||||
"""If the task body created any async generators that didn't get
|
||||
fully iterated, we must still call ``loop.shutdown_asyncgens()``
|
||||
before closing — otherwise we leak event-loop bound resources
|
||||
that re-emerge as ``RuntimeError: Event loop is closed`` later.
|
||||
"""
|
||||
from app.tasks.celery_tasks import run_async_celery_task
|
||||
|
||||
stub = _StaleLoopEngine()
|
||||
|
||||
async def _agen():
|
||||
try:
|
||||
yield 1
|
||||
yield 2
|
||||
finally:
|
||||
pass
|
||||
|
||||
async def _body() -> None:
|
||||
# Iterate the agen partially, then leave it dangling — exactly
|
||||
# the situation shutdown_asyncgens() is designed to clean up.
|
||||
async for v in _agen():
|
||||
if v == 1:
|
||||
break
|
||||
|
||||
with _patch_shared_engine(stub):
|
||||
run_async_celery_task(_body)
|
||||
|
||||
# By the time the helper returns, garbage collection + shutdown_asyncgens
|
||||
# should have ensured no live async-gen references remain. We don't
|
||||
# assert agen.closed directly (it depends on GC ordering); the real
|
||||
# contract is "no warnings, no event-loop-closed errors". A successful
|
||||
# second invocation proves the loop was cleaned up properly.
|
||||
with _patch_shared_engine(stub):
|
||||
run_async_celery_task(_body)
|
||||
|
||||
# Force a GC pass to surface any 'coroutine was never awaited'
|
||||
# warnings that would indicate the cleanup is broken.
|
||||
gc.collect()
|
||||
|
||||
|
||||
def test_runner_uses_proactor_loop_on_windows() -> None:
|
||||
"""On Windows the celery worker preselects a Proactor policy so
|
||||
subprocess (ffmpeg) calls work. The helper must not silently fall
|
||||
back to a Selector loop and re-break video/podcast generation.
|
||||
"""
|
||||
if not sys.platform.startswith("win"):
|
||||
pytest.skip("Windows-specific event-loop policy assertion")
|
||||
|
||||
from app.tasks.celery_tasks import run_async_celery_task
|
||||
|
||||
stub = _StaleLoopEngine()
|
||||
|
||||
# Mirror the policy set at the top of every Windows celery task.
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
|
||||
|
||||
observed: list[str] = []
|
||||
|
||||
async def _body() -> None:
|
||||
observed.append(type(asyncio.get_running_loop()).__name__)
|
||||
|
||||
with _patch_shared_engine(stub):
|
||||
run_async_celery_task(_body)
|
||||
|
||||
assert observed == ["ProactorEventLoop"]
|
||||
|
|
@ -113,6 +113,19 @@ async def _denying_billable_call(**kwargs):
|
|||
yield SimpleNamespace() # pragma: no cover — for grammar only
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def _settlement_failing_billable_call(**kwargs):
|
||||
from app.services.billable_calls import BillingSettlementError
|
||||
|
||||
_CALL_LOG.append(kwargs)
|
||||
yield SimpleNamespace()
|
||||
raise BillingSettlementError(
|
||||
usage_type=kwargs.get("usage_type", "?"),
|
||||
user_id=kwargs["user_id"],
|
||||
cause=RuntimeError("finalize failed"),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -187,8 +200,15 @@ async def test_billable_call_invoked_with_correct_kwargs_for_free_config(monkeyp
|
|||
call["quota_reserve_micros_override"]
|
||||
== app_config.QUOTA_DEFAULT_PODCAST_RESERVE_MICROS
|
||||
)
|
||||
assert call["thread_id"] == 99
|
||||
assert call["call_details"] == {"podcast_id": 7, "title": "Test Podcast"}
|
||||
# Background artifact audit rows intentionally omit the TokenUsage.thread_id
|
||||
# FK to avoid coupling Celery audit commits to an active chat transaction.
|
||||
assert "thread_id" not in call
|
||||
assert call["call_details"] == {
|
||||
"podcast_id": 7,
|
||||
"title": "Test Podcast",
|
||||
"thread_id": 99,
|
||||
}
|
||||
assert callable(call["billable_session_factory"])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -279,6 +299,49 @@ async def test_quota_insufficient_marks_podcast_failed_and_skips_graph(monkeypat
|
|||
assert graph_invoked == [] # Graph never ran on denied reservation.
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_billing_settlement_failure_marks_podcast_failed(monkeypatch):
|
||||
from app.db import PodcastStatus
|
||||
from app.tasks.celery_tasks import podcast_tasks
|
||||
|
||||
podcast = _make_podcast(podcast_id=10)
|
||||
session = _FakeSession(podcast)
|
||||
monkeypatch.setattr(
|
||||
podcast_tasks,
|
||||
"get_celery_session_maker",
|
||||
lambda: _FakeSessionMaker(session),
|
||||
)
|
||||
|
||||
async def _fake_resolver(sess, search_space_id, *, thread_id=None):
|
||||
return uuid4(), "premium", "gpt-5.4"
|
||||
|
||||
monkeypatch.setattr(
|
||||
podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
podcast_tasks, "billable_call", _settlement_failing_billable_call
|
||||
)
|
||||
|
||||
async def _fake_graph_invoke(state, config):
|
||||
return {"podcast_transcript": [], "final_podcast_file_path": "x.wav"}
|
||||
|
||||
monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke)
|
||||
|
||||
result = await podcast_tasks._generate_content_podcast(
|
||||
podcast_id=10,
|
||||
source_content="hi",
|
||||
search_space_id=555,
|
||||
user_prompt=None,
|
||||
)
|
||||
|
||||
assert result == {
|
||||
"status": "failed",
|
||||
"podcast_id": 10,
|
||||
"reason": "billing_settlement_failed",
|
||||
}
|
||||
assert podcast.status == PodcastStatus.FAILED
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolver_failure_marks_podcast_failed(monkeypatch):
|
||||
"""If the resolver raises (e.g. search-space deleted), the task fails
|
||||
|
|
|
|||
|
|
@ -0,0 +1,119 @@
|
|||
"""Predicate-level test for the chat streaming safety net.
|
||||
|
||||
The safety net in ``stream_new_chat`` rejects an image turn early with
|
||||
a friendly ``MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT`` SSE error when the
|
||||
selected model is *known* to be text-only. The earlier round of this
|
||||
work used a strict opt-in flag (``supports_image_input`` defaulting to
|
||||
False on every YAML entry) which blocked vision-capable Azure GPT-5.x
|
||||
deployments — this is the regression we're fixing.
|
||||
|
||||
The new predicate is :func:`is_known_text_only_chat_model`, which
|
||||
returns True only when LiteLLM's authoritative model map *explicitly*
|
||||
sets ``supports_vision=False``. Anything else (vision True, missing
|
||||
key, exception) returns False so the request flows through to the
|
||||
provider.
|
||||
|
||||
We exercise the predicate directly here rather than driving the full
|
||||
``stream_new_chat`` generator — covering the gate in isolation keeps
|
||||
the test focused on the regression while the generator's wider behavior
|
||||
is exercised by the integration suite.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.provider_capabilities import is_known_text_only_chat_model
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def test_safety_net_does_not_fire_for_azure_gpt_4o():
|
||||
"""Regression: ``azure/gpt-4o`` (and the GPT-5.x variants) is
|
||||
vision-capable per LiteLLM's model map. The previous round's
|
||||
blanket-False default blocked it; the new predicate must NOT mark
|
||||
it text-only."""
|
||||
assert (
|
||||
is_known_text_only_chat_model(
|
||||
provider="AZURE_OPENAI",
|
||||
model_name="my-azure-deployment",
|
||||
base_model="gpt-4o",
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
|
||||
def test_safety_net_does_not_fire_for_unknown_model():
|
||||
"""Default-pass on unknown — the safety net only blocks definitive
|
||||
text-only confirmations. A freshly added third-party model that
|
||||
LiteLLM doesn't know about must flow through to the provider."""
|
||||
assert (
|
||||
is_known_text_only_chat_model(
|
||||
provider="CUSTOM",
|
||||
custom_provider="brand_new_proxy",
|
||||
model_name="brand-new-model-x9",
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
|
||||
def test_safety_net_does_not_fire_when_lookup_raises(monkeypatch):
|
||||
"""Transient ``litellm.get_model_info`` exception ≠ block. The
|
||||
helper swallows the error and treats it as 'unknown' → False."""
|
||||
import app.services.provider_capabilities as pc
|
||||
|
||||
def _raise(**_kwargs):
|
||||
raise RuntimeError("intentional test failure")
|
||||
|
||||
monkeypatch.setattr(pc.litellm, "get_model_info", _raise)
|
||||
|
||||
assert (
|
||||
is_known_text_only_chat_model(
|
||||
provider="OPENAI",
|
||||
model_name="gpt-4o",
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
|
||||
def test_safety_net_fires_only_on_explicit_false(monkeypatch):
|
||||
"""Stub LiteLLM to assert the only path that returns True is the
|
||||
explicit ``supports_vision=False`` case. Anything else (True,
|
||||
None, missing key) returns False from the predicate."""
|
||||
import app.services.provider_capabilities as pc
|
||||
|
||||
def _info_explicit_false(**_kwargs):
|
||||
return {"supports_vision": False, "max_input_tokens": 8192}
|
||||
|
||||
monkeypatch.setattr(pc.litellm, "get_model_info", _info_explicit_false)
|
||||
assert (
|
||||
is_known_text_only_chat_model(
|
||||
provider="OPENAI",
|
||||
model_name="text-only-stub",
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
def _info_true(**_kwargs):
|
||||
return {"supports_vision": True}
|
||||
|
||||
monkeypatch.setattr(pc.litellm, "get_model_info", _info_true)
|
||||
assert (
|
||||
is_known_text_only_chat_model(
|
||||
provider="OPENAI",
|
||||
model_name="vision-stub",
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
def _info_missing(**_kwargs):
|
||||
return {"max_input_tokens": 8192}
|
||||
|
||||
monkeypatch.setattr(pc.litellm, "get_model_info", _info_missing)
|
||||
assert (
|
||||
is_known_text_only_chat_model(
|
||||
provider="OPENAI",
|
||||
model_name="missing-key-stub",
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
|
@ -105,6 +105,19 @@ async def _denying_billable_call(**kwargs):
|
|||
yield SimpleNamespace() # pragma: no cover
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def _settlement_failing_billable_call(**kwargs):
|
||||
from app.services.billable_calls import BillingSettlementError
|
||||
|
||||
_CALL_LOG.append(kwargs)
|
||||
yield SimpleNamespace()
|
||||
raise BillingSettlementError(
|
||||
usage_type=kwargs.get("usage_type", "?"),
|
||||
user_id=kwargs["user_id"],
|
||||
cause=RuntimeError("finalize failed"),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -176,11 +189,15 @@ async def test_billable_call_invoked_with_correct_kwargs_for_free_config(monkeyp
|
|||
call["quota_reserve_micros_override"]
|
||||
== app_config.QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS
|
||||
)
|
||||
assert call["thread_id"] == 99
|
||||
# Background artifact audit rows intentionally omit the TokenUsage.thread_id
|
||||
# FK to avoid coupling Celery audit commits to an active chat transaction.
|
||||
assert "thread_id" not in call
|
||||
assert call["call_details"] == {
|
||||
"video_presentation_id": 11,
|
||||
"title": "Test Presentation",
|
||||
"thread_id": 99,
|
||||
}
|
||||
assert callable(call["billable_session_factory"])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -280,6 +297,57 @@ async def test_quota_insufficient_marks_video_failed_and_skips_graph(monkeypatch
|
|||
assert graph_invoked == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_billing_settlement_failure_marks_video_failed(monkeypatch):
|
||||
from app.db import VideoPresentationStatus
|
||||
from app.tasks.celery_tasks import video_presentation_tasks
|
||||
|
||||
video = _make_video(video_id=14)
|
||||
session = _FakeSession(video)
|
||||
monkeypatch.setattr(
|
||||
video_presentation_tasks,
|
||||
"get_celery_session_maker",
|
||||
lambda: _FakeSessionMaker(session),
|
||||
)
|
||||
|
||||
async def _fake_resolver(sess, search_space_id, *, thread_id=None):
|
||||
return uuid4(), "premium", "gpt-5.4"
|
||||
|
||||
monkeypatch.setattr(
|
||||
video_presentation_tasks,
|
||||
"_resolve_agent_billing_for_search_space",
|
||||
_fake_resolver,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
video_presentation_tasks,
|
||||
"billable_call",
|
||||
_settlement_failing_billable_call,
|
||||
)
|
||||
|
||||
async def _fake_graph_invoke(state, config):
|
||||
return {"slides": [], "slide_audio_results": [], "slide_scene_codes": []}
|
||||
|
||||
monkeypatch.setattr(
|
||||
video_presentation_tasks.video_presentation_graph,
|
||||
"ainvoke",
|
||||
_fake_graph_invoke,
|
||||
)
|
||||
|
||||
result = await video_presentation_tasks._generate_video_presentation(
|
||||
video_presentation_id=14,
|
||||
source_content="content",
|
||||
search_space_id=777,
|
||||
user_prompt=None,
|
||||
)
|
||||
|
||||
assert result == {
|
||||
"status": "failed",
|
||||
"video_presentation_id": 14,
|
||||
"reason": "billing_settlement_failed",
|
||||
}
|
||||
assert video.status == VideoPresentationStatus.FAILED
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolver_failure_marks_video_failed(monkeypatch):
|
||||
from app.db import VideoPresentationStatus
|
||||
|
|
|
|||
|
|
@ -477,9 +477,7 @@ const MessageInfoDropdown: FC = () => {
|
|||
</span>
|
||||
<span className="text-xs text-muted-foreground">
|
||||
{counts.total_tokens.toLocaleString()} tokens
|
||||
{costMicros && costMicros > 0
|
||||
? ` · ${formatTurnCost(costMicros)}`
|
||||
: ""}
|
||||
{costMicros && costMicros > 0 ? ` · ${formatTurnCost(costMicros)}` : ""}
|
||||
</span>
|
||||
</ActionBarMorePrimitive.Item>
|
||||
);
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ import {
|
|||
import type React from "react";
|
||||
import { Fragment, useCallback, useEffect, useMemo, useRef, useState } from "react";
|
||||
import { toast } from "sonner";
|
||||
import { pendingUserImageDataUrlsAtom } from "@/atoms/chat/pending-user-images.atom";
|
||||
import {
|
||||
globalImageGenConfigsAtom,
|
||||
imageGenConfigsAtom,
|
||||
|
|
@ -461,6 +462,18 @@ export function ModelSelector({
|
|||
const { data: visionUserConfigs, isLoading: visionUserLoading } =
|
||||
useAtomValue(visionLLMConfigsAtom);
|
||||
|
||||
// Pending image attachments on the composer. Used to surface an
|
||||
// amber "No image" hint on chat models the catalog reports as
|
||||
// non-vision (`supports_image_input=false`) when the next message
|
||||
// will carry an image. The hint is purely advisory: selection,
|
||||
// focus, and click handling are unaffected. The backend's safety
|
||||
// net (`is_known_text_only_chat_model`) is the actual block, and
|
||||
// it only fires when LiteLLM *explicitly* marks a model as
|
||||
// text-only — so a model that's secretly capable but hasn't been
|
||||
// annotated will still flow through to the provider.
|
||||
const pendingUserImageUrls = useAtomValue(pendingUserImageDataUrlsAtom);
|
||||
const hasPendingImages = pendingUserImageUrls.length > 0;
|
||||
|
||||
const isLoading =
|
||||
llmUserLoading ||
|
||||
llmGlobalLoading ||
|
||||
|
|
@ -984,6 +997,21 @@ export function ModelSelector({
|
|||
const isSelected = getSelectedId() === config.id;
|
||||
const isFocused = focusedIndex === index;
|
||||
const hasCitations = "citations_enabled" in config && !!config.citations_enabled;
|
||||
// Chat-tab only: surface an amber "No image" hint when the
|
||||
// composer carries images and the catalog reports the model as
|
||||
// non-vision. This is purely advisory — selection is *not*
|
||||
// blocked. The backend's narrow safety net
|
||||
// (`is_known_text_only_chat_model`) is the source of truth for
|
||||
// rejecting image turns, and it only fires when LiteLLM
|
||||
// explicitly marks the model as text-only. A model surfaced as
|
||||
// `supports_image_input=false` here may still be capable in
|
||||
// practice (unknown / unmapped LiteLLM entry), so we let the
|
||||
// user pick it and the provider response decide.
|
||||
const isImageIncompatibleChatModel =
|
||||
activeTab === "llm" &&
|
||||
hasPendingImages &&
|
||||
"supports_image_input" in config &&
|
||||
(config as Record<string, unknown>).supports_image_input === false;
|
||||
|
||||
return (
|
||||
<div
|
||||
|
|
@ -992,6 +1020,11 @@ export function ModelSelector({
|
|||
role="option"
|
||||
tabIndex={isMobile ? -1 : 0}
|
||||
aria-selected={isSelected}
|
||||
title={
|
||||
isImageIncompatibleChatModel
|
||||
? "This model is reported as text-only. You can still pick it; the provider may reject image turns."
|
||||
: undefined
|
||||
}
|
||||
onClick={() => handleSelectItem(item)}
|
||||
onKeyDown={
|
||||
isMobile
|
||||
|
|
@ -1005,9 +1038,8 @@ export function ModelSelector({
|
|||
}
|
||||
onMouseEnter={() => setFocusedIndex(index)}
|
||||
className={cn(
|
||||
"group flex items-center gap-2.5 px-3 py-2 rounded-xl cursor-pointer",
|
||||
"transition-all duration-150 mx-2",
|
||||
"hover:bg-accent/40",
|
||||
"group flex items-center gap-2.5 px-3 py-2 rounded-xl",
|
||||
"transition-all duration-150 mx-2 cursor-pointer hover:bg-accent/40",
|
||||
isSelected && "bg-primary/6 dark:bg-primary/8",
|
||||
isFocused && "bg-accent/50"
|
||||
)}
|
||||
|
|
@ -1053,6 +1085,14 @@ export function ModelSelector({
|
|||
Free
|
||||
</Badge>
|
||||
) : null}
|
||||
{isImageIncompatibleChatModel && (
|
||||
<Badge
|
||||
variant="secondary"
|
||||
className="text-[9px] px-1 py-0 h-3.5 bg-amber-100 text-amber-700 dark:bg-amber-900/50 dark:text-amber-300 border-0"
|
||||
>
|
||||
No image
|
||||
</Badge>
|
||||
)}
|
||||
</div>
|
||||
<div className="flex items-center gap-1.5 mt-0.5">
|
||||
<span className="text-xs text-muted-foreground truncate">
|
||||
|
|
|
|||
|
|
@ -250,8 +250,8 @@ function PricingFAQ() {
|
|||
Frequently Asked Questions
|
||||
</h2>
|
||||
<p className="mx-auto mt-4 max-w-2xl text-lg text-muted-foreground">
|
||||
Everything you need to know about SurfSense pages, premium credit, and billing.
|
||||
Can't find what you need? Reach out at{" "}
|
||||
Everything you need to know about SurfSense pages, premium credit, and billing. Can't
|
||||
find what you need? Reach out at{" "}
|
||||
<a href="mailto:rohan@surfsense.com" className="text-blue-500 underline">
|
||||
rohan@surfsense.com
|
||||
</a>
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ import {
|
|||
AlertDialogTitle,
|
||||
} from "@/components/ui/alert-dialog";
|
||||
import { Avatar, AvatarFallback, AvatarImage } from "@/components/ui/avatar";
|
||||
import { Badge } from "@/components/ui/badge";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Card, CardContent } from "@/components/ui/card";
|
||||
import { Skeleton } from "@/components/ui/skeleton";
|
||||
|
|
@ -190,8 +191,7 @@ export function ImageModelManager({ searchSpaceId }: ImageModelManagerProps) {
|
|||
? "model"
|
||||
: "models"}
|
||||
</span>{" "}
|
||||
available from your administrator.{" "}
|
||||
{(() => {
|
||||
available from your administrator. {(() => {
|
||||
const nonAuto = globalConfigs.filter(
|
||||
(g) => !("is_auto_mode" in g && g.is_auto_mode)
|
||||
);
|
||||
|
|
@ -214,6 +214,75 @@ export function ImageModelManager({ searchSpaceId }: ImageModelManagerProps) {
|
|||
</Alert>
|
||||
)}
|
||||
|
||||
{/* Global Image Models — read-only cards with per-model Free/Premium
|
||||
badges. Mirrors the badge palette used by the chat role selector
|
||||
(`llm-role-manager.tsx`) so the meaning is consistent across
|
||||
every model-configuration surface (chat / image / vision). */}
|
||||
{!isLoading &&
|
||||
globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length > 0 && (
|
||||
<div className="space-y-3">
|
||||
<h3 className="text-xs md:text-sm font-semibold text-muted-foreground">
|
||||
Global Image Models
|
||||
</h3>
|
||||
<div className="grid gap-3 grid-cols-1 sm:grid-cols-2 xl:grid-cols-3">
|
||||
{globalConfigs
|
||||
.filter((g) => !("is_auto_mode" in g && g.is_auto_mode))
|
||||
.map((cfg) => {
|
||||
const billingTier =
|
||||
("billing_tier" in cfg &&
|
||||
typeof (cfg as { billing_tier?: string }).billing_tier === "string" &&
|
||||
(cfg as { billing_tier?: string }).billing_tier) ||
|
||||
"free";
|
||||
const isPremium = billingTier === "premium";
|
||||
return (
|
||||
<Card
|
||||
key={cfg.id}
|
||||
className="border-border/60 bg-muted/20 overflow-hidden h-full"
|
||||
>
|
||||
<CardContent className="p-4 flex flex-col gap-3 h-full">
|
||||
<div className="flex items-center gap-2 min-w-0">
|
||||
<div className="shrink-0">
|
||||
{getProviderIcon(cfg.provider, { className: "size-4" })}
|
||||
</div>
|
||||
<div className="min-w-0 flex-1 flex items-center gap-1.5">
|
||||
<h4 className="text-sm font-semibold tracking-tight truncate">
|
||||
{cfg.name}
|
||||
</h4>
|
||||
{isPremium ? (
|
||||
<Badge
|
||||
variant="secondary"
|
||||
className="text-[8px] md:text-[9px] shrink-0 bg-purple-100 text-purple-700 dark:bg-purple-900/50 dark:text-purple-300 border-0"
|
||||
>
|
||||
Premium
|
||||
</Badge>
|
||||
) : (
|
||||
<Badge
|
||||
variant="secondary"
|
||||
className="text-[8px] md:text-[9px] shrink-0 bg-emerald-100 text-emerald-700 dark:bg-emerald-900/50 dark:text-emerald-300 border-0"
|
||||
>
|
||||
Free
|
||||
</Badge>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
{cfg.description && (
|
||||
<p className="text-[11px] text-muted-foreground/70 line-clamp-2">
|
||||
{cfg.description}
|
||||
</p>
|
||||
)}
|
||||
<div className="flex items-center pt-2 border-t border-border/40 mt-auto">
|
||||
<span className="text-[11px] text-muted-foreground/60 truncate">
|
||||
{cfg.model_name}
|
||||
</span>
|
||||
</div>
|
||||
</CardContent>
|
||||
</Card>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Loading Skeleton */}
|
||||
{isLoading && (
|
||||
<div className="space-y-4 md:space-y-6">
|
||||
|
|
|
|||
|
|
@ -70,9 +70,7 @@ export function MorePagesContent() {
|
|||
<div className="w-full space-y-5">
|
||||
<div className="text-center">
|
||||
<h2 className="text-xl font-bold tracking-tight">Get Free Pages</h2>
|
||||
<p className="mt-1 text-sm text-muted-foreground">
|
||||
Earn bonus pages by completing tasks
|
||||
</p>
|
||||
<p className="mt-1 text-sm text-muted-foreground">Earn bonus pages by completing tasks</p>
|
||||
</div>
|
||||
|
||||
<div className="space-y-2">
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ import {
|
|||
AlertDialogTitle,
|
||||
} from "@/components/ui/alert-dialog";
|
||||
import { Avatar, AvatarFallback, AvatarImage } from "@/components/ui/avatar";
|
||||
import { Badge } from "@/components/ui/badge";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Card, CardContent } from "@/components/ui/card";
|
||||
import { Skeleton } from "@/components/ui/skeleton";
|
||||
|
|
@ -191,8 +192,7 @@ export function VisionModelManager({ searchSpaceId }: VisionModelManagerProps) {
|
|||
? "model"
|
||||
: "models"}
|
||||
</span>{" "}
|
||||
available from your administrator.{" "}
|
||||
{(() => {
|
||||
available from your administrator. {(() => {
|
||||
const nonAuto = globalConfigs.filter(
|
||||
(g) => !("is_auto_mode" in g && g.is_auto_mode)
|
||||
);
|
||||
|
|
@ -215,6 +215,75 @@ export function VisionModelManager({ searchSpaceId }: VisionModelManagerProps) {
|
|||
</Alert>
|
||||
)}
|
||||
|
||||
{/* Global Vision Models — read-only cards with per-model Free/Premium
|
||||
badges. Mirrors the badge palette used by the chat role selector
|
||||
(`llm-role-manager.tsx`) so the meaning is consistent across
|
||||
every model-configuration surface (chat / image / vision). */}
|
||||
{!isLoading &&
|
||||
globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length > 0 && (
|
||||
<div className="space-y-3">
|
||||
<h3 className="text-xs md:text-sm font-semibold text-muted-foreground">
|
||||
Global Vision Models
|
||||
</h3>
|
||||
<div className="grid gap-3 grid-cols-1 sm:grid-cols-2 xl:grid-cols-3">
|
||||
{globalConfigs
|
||||
.filter((g) => !("is_auto_mode" in g && g.is_auto_mode))
|
||||
.map((cfg) => {
|
||||
const billingTier =
|
||||
("billing_tier" in cfg &&
|
||||
typeof (cfg as { billing_tier?: string }).billing_tier === "string" &&
|
||||
(cfg as { billing_tier?: string }).billing_tier) ||
|
||||
"free";
|
||||
const isPremium = billingTier === "premium";
|
||||
return (
|
||||
<Card
|
||||
key={cfg.id}
|
||||
className="border-border/60 bg-muted/20 overflow-hidden h-full"
|
||||
>
|
||||
<CardContent className="p-4 flex flex-col gap-3 h-full">
|
||||
<div className="flex items-center gap-2 min-w-0">
|
||||
<div className="shrink-0">
|
||||
{getProviderIcon(cfg.provider, { className: "size-4" })}
|
||||
</div>
|
||||
<div className="min-w-0 flex-1 flex items-center gap-1.5">
|
||||
<h4 className="text-sm font-semibold tracking-tight truncate">
|
||||
{cfg.name}
|
||||
</h4>
|
||||
{isPremium ? (
|
||||
<Badge
|
||||
variant="secondary"
|
||||
className="text-[8px] md:text-[9px] shrink-0 bg-purple-100 text-purple-700 dark:bg-purple-900/50 dark:text-purple-300 border-0"
|
||||
>
|
||||
Premium
|
||||
</Badge>
|
||||
) : (
|
||||
<Badge
|
||||
variant="secondary"
|
||||
className="text-[8px] md:text-[9px] shrink-0 bg-emerald-100 text-emerald-700 dark:bg-emerald-900/50 dark:text-emerald-300 border-0"
|
||||
>
|
||||
Free
|
||||
</Badge>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
{cfg.description && (
|
||||
<p className="text-[11px] text-muted-foreground/70 line-clamp-2">
|
||||
{cfg.description}
|
||||
</p>
|
||||
)}
|
||||
<div className="flex items-center pt-2 border-t border-border/40 mt-auto">
|
||||
<span className="text-[11px] text-muted-foreground/60 truncate">
|
||||
{cfg.model_name}
|
||||
</span>
|
||||
</div>
|
||||
</CardContent>
|
||||
</Card>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{isLoading && (
|
||||
<div className="space-y-4 md:space-y-6">
|
||||
<div className="space-y-4">
|
||||
|
|
|
|||
|
|
@ -416,9 +416,19 @@ export const GeneratePodcastToolUI = ({
|
|||
return <PodcastErrorState title={title} error={result.error || "Generation failed"} />;
|
||||
}
|
||||
|
||||
// Already generating - show simple warning, don't create another poller
|
||||
// The FIRST tool call will display the podcast when ready
|
||||
// (new: "generating", legacy: "already_generating")
|
||||
// Pending/generating rows have a stable podcast_id, so the card can poll
|
||||
// independently while the chat stream finishes.
|
||||
if (
|
||||
(result.status === "pending" ||
|
||||
result.status === "generating" ||
|
||||
result.status === "processing") &&
|
||||
result.podcast_id
|
||||
) {
|
||||
return <PodcastStatusPoller podcastId={result.podcast_id} title={result.title || title} />;
|
||||
}
|
||||
|
||||
// Legacy duplicate/no-ID result - show a simple warning, don't create
|
||||
// another poller. The first tool call will display the podcast when ready.
|
||||
if (result.status === "generating" || result.status === "already_generating") {
|
||||
return (
|
||||
<div className="my-4 max-w-lg overflow-hidden rounded-2xl border bg-muted/30 select-none">
|
||||
|
|
@ -432,11 +442,6 @@ export const GeneratePodcastToolUI = ({
|
|||
);
|
||||
}
|
||||
|
||||
// Pending - poll for completion (new: "pending" with podcast_id)
|
||||
if (result.status === "pending" && result.podcast_id) {
|
||||
return <PodcastStatusPoller podcastId={result.podcast_id} title={result.title || title} />;
|
||||
}
|
||||
|
||||
// Ready with podcast_id (new: "ready", legacy: "success")
|
||||
if ((result.status === "ready" || result.status === "success") && result.podcast_id) {
|
||||
return <PodcastPlayer podcastId={result.podcast_id} title={result.title || title} />;
|
||||
|
|
|
|||
|
|
@ -44,8 +44,8 @@ export function LoginGateProvider({ children }: { children: ReactNode }) {
|
|||
<DialogHeader>
|
||||
<DialogTitle>Create a free account to {feature}</DialogTitle>
|
||||
<DialogDescription>
|
||||
Get $5 of premium credit, save chat history, upload documents, use all AI tools,
|
||||
and connect 30+ integrations.
|
||||
Get $5 of premium credit, save chat history, upload documents, use all AI tools, and
|
||||
connect 30+ integrations.
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
<DialogFooter className="flex flex-col gap-2 sm:flex-row">
|
||||
|
|
|
|||
|
|
@ -65,6 +65,13 @@ export const newLLMConfig = z.object({
|
|||
created_at: z.string(),
|
||||
search_space_id: z.number(),
|
||||
user_id: z.string(),
|
||||
|
||||
// Capability flag — derived server-side at the route boundary from
|
||||
// LiteLLM's authoritative model map. There is no DB column. Default
|
||||
// `true` is the conservative-allow stance for unknown / unmapped
|
||||
// BYOK rows; the streaming-task safety net is the only place a
|
||||
// `false` actually blocks a request.
|
||||
supports_image_input: z.boolean().default(true),
|
||||
});
|
||||
|
||||
/**
|
||||
|
|
@ -74,11 +81,16 @@ export const newLLMConfigPublic = newLLMConfig.omit({ api_key: true });
|
|||
|
||||
/**
|
||||
* Create NewLLMConfig
|
||||
*
|
||||
* `supports_image_input` is omitted because it is derived server-side
|
||||
* from LiteLLM's model map at read time — there is no DB column to
|
||||
* persist a client-supplied value into.
|
||||
*/
|
||||
export const createNewLLMConfigRequest = newLLMConfig.omit({
|
||||
id: true,
|
||||
created_at: true,
|
||||
user_id: true,
|
||||
supports_image_input: true,
|
||||
});
|
||||
|
||||
export const createNewLLMConfigResponse = newLLMConfig;
|
||||
|
|
@ -114,6 +126,8 @@ export const updateNewLLMConfigRequest = z.object({
|
|||
created_at: true,
|
||||
search_space_id: true,
|
||||
user_id: true,
|
||||
// Derived server-side; not part of the writable surface.
|
||||
supports_image_input: true,
|
||||
})
|
||||
.partial(),
|
||||
});
|
||||
|
|
@ -172,6 +186,16 @@ export const globalNewLLMConfig = z.object({
|
|||
seo_title: z.string().nullable().optional(),
|
||||
seo_description: z.string().nullable().optional(),
|
||||
quota_reserve_tokens: z.number().nullable().optional(),
|
||||
// Capability flag — true when the model can accept image inputs.
|
||||
// Resolved server-side (OpenRouter dynamic configs use the OR
|
||||
// `architecture.input_modalities` field; YAML / BYOK use LiteLLM's
|
||||
// authoritative `supports_vision` map). The chat selector renders
|
||||
// an amber "No image" hint when this is false and there are
|
||||
// pending image attachments, but does not block selection — the
|
||||
// backend safety net only rejects when LiteLLM *explicitly* marks
|
||||
// the model as text-only, so unknown / new models still flow
|
||||
// through. Default `true` matches that conservative-allow stance.
|
||||
supports_image_input: z.boolean().default(true),
|
||||
});
|
||||
|
||||
export const getGlobalNewLLMConfigsResponse = z.array(globalNewLLMConfig);
|
||||
|
|
@ -259,6 +283,9 @@ export const globalImageGenConfig = z.object({
|
|||
is_global: z.literal(true),
|
||||
is_auto_mode: z.boolean().optional().default(false),
|
||||
billing_tier: z.string().default("free"),
|
||||
// Mirrors `globalNewLLMConfig.is_premium` so the new-chat selector's
|
||||
// Free/Premium badge logic lights up automatically for image-gen too.
|
||||
is_premium: z.boolean().default(false),
|
||||
quota_reserve_micros: z.number().nullable().optional(),
|
||||
});
|
||||
|
||||
|
|
@ -341,6 +368,9 @@ export const globalVisionLLMConfig = z.object({
|
|||
is_global: z.literal(true),
|
||||
is_auto_mode: z.boolean().optional().default(false),
|
||||
billing_tier: z.string().default("free"),
|
||||
// Mirrors `globalNewLLMConfig.is_premium` so the new-chat selector's
|
||||
// Free/Premium badge logic lights up automatically for vision too.
|
||||
is_premium: z.boolean().default(false),
|
||||
quota_reserve_tokens: z.number().nullable().optional(),
|
||||
input_cost_per_token: z.number().nullable().optional(),
|
||||
output_cost_per_token: z.number().nullable().optional(),
|
||||
|
|
|
|||
|
|
@ -18,6 +18,12 @@ const nextConfig: NextConfig = {
|
|||
},
|
||||
images: {
|
||||
remotePatterns: [
|
||||
{
|
||||
protocol: "http",
|
||||
hostname: "localhost",
|
||||
port: "8000",
|
||||
pathname: "/api/v1/image-generations/**",
|
||||
},
|
||||
{
|
||||
protocol: "https",
|
||||
hostname: "**",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue