diff --git a/surfsense_backend/app/agents/new_chat/llm_config.py b/surfsense_backend/app/agents/new_chat/llm_config.py index 99bb719f6..bc37bf1c4 100644 --- a/surfsense_backend/app/agents/new_chat/llm_config.py +++ b/surfsense_backend/app/agents/new_chat/llm_config.py @@ -90,41 +90,18 @@ class SanitizedChatLiteLLM(ChatLiteLLM): yield chunk -# Provider mapping for LiteLLM model string construction -PROVIDER_MAP = { - "OPENAI": "openai", - "ANTHROPIC": "anthropic", - "GROQ": "groq", - "COHERE": "cohere", - "GOOGLE": "gemini", - "OLLAMA": "ollama_chat", - "MISTRAL": "mistral", - "AZURE_OPENAI": "azure", - "OPENROUTER": "openrouter", - "XAI": "xai", - "BEDROCK": "bedrock", - "VERTEX_AI": "vertex_ai", - "TOGETHER_AI": "together_ai", - "FIREWORKS_AI": "fireworks_ai", - "DEEPSEEK": "openai", - "ALIBABA_QWEN": "openai", - "MOONSHOT": "openai", - "ZHIPU": "openai", - "GITHUB_MODELS": "github", - "REPLICATE": "replicate", - "PERPLEXITY": "perplexity", - "ANYSCALE": "anyscale", - "DEEPINFRA": "deepinfra", - "CEREBRAS": "cerebras", - "SAMBANOVA": "sambanova", - "AI21": "ai21", - "CLOUDFLARE": "cloudflare", - "DATABRICKS": "databricks", - "COMETAPI": "cometapi", - "HUGGINGFACE": "huggingface", - "MINIMAX": "openai", - "CUSTOM": "custom", -} +# Provider mapping for LiteLLM model string construction. +# +# Single source of truth lives in +# :mod:`app.services.provider_capabilities` so the YAML loader (which +# runs during ``app.config`` class-body init) can resolve provider +# prefixes without dragging the agent / tools tree into module load +# order. Re-exported here under the historical ``PROVIDER_MAP`` name +# so existing callers (``llm_router_service``, ``image_gen_router_service``, +# tests) keep working unchanged. +from app.services.provider_capabilities import ( # noqa: E402 + _PROVIDER_PREFIX_MAP as PROVIDER_MAP, +) def _attach_model_profile(llm: ChatLiteLLM, model_string: str) -> None: @@ -178,6 +155,17 @@ class AgentConfig: anonymous_enabled: bool = False quota_reserve_tokens: int | None = None + # Capability flag: best-effort True for the chat selector / catalog. + # Resolved via :func:`provider_capabilities.derive_supports_image_input` + # which prefers OpenRouter's ``architecture.input_modalities`` and + # otherwise consults LiteLLM's authoritative model map. Default True + # is the conservative-allow stance — the streaming-task safety net + # (``is_known_text_only_chat_model``) is the *only* place a False + # actually blocks a request. Setting this to False here without an + # authoritative source would silently hide vision-capable models + # (the regression we're fixing). + supports_image_input: bool = True + @classmethod def from_auto_mode(cls) -> "AgentConfig": """ @@ -203,6 +191,12 @@ class AgentConfig: is_premium=False, anonymous_enabled=False, quota_reserve_tokens=None, + # Auto routes across the configured pool, which usually + # contains at least one vision-capable deployment; the router + # will surface a 404 from a non-vision deployment as a normal + # ``allowed_fails`` event and fail over rather than blocking + # the request outright. + supports_image_input=True, ) @classmethod @@ -216,10 +210,24 @@ class AgentConfig: Returns: AgentConfig instance """ - return cls( - provider=config.provider.value + # Lazy import to avoid pulling provider_capabilities (and its + # transitive litellm import) into module-init order. 
+ from app.services.provider_capabilities import derive_supports_image_input + + provider_value = ( + config.provider.value if hasattr(config.provider, "value") - else str(config.provider), + else str(config.provider) + ) + litellm_params = config.litellm_params or {} + base_model = ( + litellm_params.get("base_model") + if isinstance(litellm_params, dict) + else None + ) + + return cls( + provider=provider_value, model_name=config.model_name, api_key=config.api_key, api_base=config.api_base, @@ -235,6 +243,16 @@ class AgentConfig: is_premium=False, anonymous_enabled=False, quota_reserve_tokens=None, + # BYOK rows have no operator-curated capability flag, so we + # ask LiteLLM (default-allow on unknown). The streaming + # safety net still blocks if the model is *explicitly* + # marked text-only. + supports_image_input=derive_supports_image_input( + provider=provider_value, + model_name=config.model_name, + base_model=base_model, + custom_provider=config.custom_provider, + ), ) @classmethod @@ -253,15 +271,46 @@ class AgentConfig: Returns: AgentConfig instance """ + # Lazy import to avoid pulling provider_capabilities (and its + # transitive litellm import) into module-init order. + from app.services.provider_capabilities import derive_supports_image_input + # Get system instructions from YAML, default to empty string system_instructions = yaml_config.get("system_instructions", "") + provider = yaml_config.get("provider", "").upper() + model_name = yaml_config.get("model_name", "") + custom_provider = yaml_config.get("custom_provider") + litellm_params = yaml_config.get("litellm_params") or {} + base_model = ( + litellm_params.get("base_model") + if isinstance(litellm_params, dict) + else None + ) + + # Explicit YAML override wins; otherwise derive from LiteLLM / + # OpenRouter modalities. The YAML loader already populates this + # field, but this method is also called from + # ``load_global_llm_config_by_id``'s file fallback (hot reload), + # so we re-derive here for safety. The bool() coercion preserves + # the loader's behaviour for explicit ``true`` / ``false`` + # strings that PyYAML may surface. 
+ if "supports_image_input" in yaml_config: + supports_image_input = bool(yaml_config.get("supports_image_input")) + else: + supports_image_input = derive_supports_image_input( + provider=provider, + model_name=model_name, + base_model=base_model, + custom_provider=custom_provider, + ) + return cls( - provider=yaml_config.get("provider", "").upper(), - model_name=yaml_config.get("model_name", ""), + provider=provider, + model_name=model_name, api_key=yaml_config.get("api_key", ""), api_base=yaml_config.get("api_base"), - custom_provider=yaml_config.get("custom_provider"), + custom_provider=custom_provider, litellm_params=yaml_config.get("litellm_params"), # Prompt configuration from YAML (with defaults for backwards compatibility) system_instructions=system_instructions if system_instructions else None, @@ -276,6 +325,7 @@ class AgentConfig: is_premium=yaml_config.get("billing_tier", "free") == "premium", anonymous_enabled=yaml_config.get("anonymous_enabled", False), quota_reserve_tokens=yaml_config.get("quota_reserve_tokens"), + supports_image_input=supports_image_input, ) diff --git a/surfsense_backend/app/agents/new_chat/tools/generate_image.py b/surfsense_backend/app/agents/new_chat/tools/generate_image.py index 3803fa39c..9e287ac51 100644 --- a/surfsense_backend/app/agents/new_chat/tools/generate_image.py +++ b/surfsense_backend/app/agents/new_chat/tools/generate_image.py @@ -31,6 +31,7 @@ from app.services.image_gen_router_service import ( ImageGenRouterService, is_image_gen_auto_mode, ) +from app.services.provider_api_base import resolve_api_base from app.utils.signed_image_urls import generate_image_token logger = logging.getLogger(__name__) @@ -49,12 +50,16 @@ _PROVIDER_MAP = { } +def _resolve_provider_prefix(provider: str, custom_provider: str | None) -> str: + if custom_provider: + return custom_provider + return _PROVIDER_MAP.get(provider.upper(), provider.lower()) + + def _build_model_string( provider: str, model_name: str, custom_provider: str | None ) -> str: - if custom_provider: - return f"{custom_provider}/{model_name}" - prefix = _PROVIDER_MAP.get(provider.upper(), provider.lower()) + prefix = _resolve_provider_prefix(provider, custom_provider) return f"{prefix}/{model_name}" @@ -146,14 +151,18 @@ def create_generate_image_tool( "error": f"Image generation config {config_id} not found" } - model_string = _build_model_string( - cfg.get("provider", ""), - cfg["model_name"], - cfg.get("custom_provider"), + provider_prefix = _resolve_provider_prefix( + cfg.get("provider", ""), cfg.get("custom_provider") ) + model_string = f"{provider_prefix}/{cfg['model_name']}" gen_kwargs["api_key"] = cfg.get("api_key") - if cfg.get("api_base"): - gen_kwargs["api_base"] = cfg["api_base"] + api_base = resolve_api_base( + provider=cfg.get("provider"), + provider_prefix=provider_prefix, + config_api_base=cfg.get("api_base"), + ) + if api_base: + gen_kwargs["api_base"] = api_base if cfg.get("api_version"): gen_kwargs["api_version"] = cfg["api_version"] if cfg.get("litellm_params"): @@ -175,14 +184,18 @@ def create_generate_image_tool( "error": f"Image generation config {config_id} not found" } - model_string = _build_model_string( - db_cfg.provider.value, - db_cfg.model_name, - db_cfg.custom_provider, + provider_prefix = _resolve_provider_prefix( + db_cfg.provider.value, db_cfg.custom_provider ) + model_string = f"{provider_prefix}/{db_cfg.model_name}" gen_kwargs["api_key"] = db_cfg.api_key - if db_cfg.api_base: - gen_kwargs["api_base"] = db_cfg.api_base + api_base = resolve_api_base( + 
provider=db_cfg.provider.value, + provider_prefix=provider_prefix, + config_api_base=db_cfg.api_base, + ) + if api_base: + gen_kwargs["api_base"] = api_base if db_cfg.api_version: gen_kwargs["api_version"] = db_cfg.api_version if db_cfg.litellm_params: diff --git a/surfsense_backend/app/config/__init__.py b/surfsense_backend/app/config/__init__.py index 2aeeafb34..97b4cf509 100644 --- a/surfsense_backend/app/config/__init__.py +++ b/surfsense_backend/app/config/__init__.py @@ -47,11 +47,37 @@ def load_global_llm_configs(): data = yaml.safe_load(f) configs = data.get("global_llm_configs", []) + # Lazy import keeps the `app.config` -> `app.services` edge one-way + # and matches the `provider_api_base` pattern used elsewhere. + from app.services.provider_capabilities import derive_supports_image_input + seen_slugs: dict[str, int] = {} for cfg in configs: cfg.setdefault("billing_tier", "free") cfg.setdefault("anonymous_enabled", False) cfg.setdefault("seo_enabled", False) + # Capability flag: explicit YAML override always wins. When the + # operator has not annotated the model, defer to LiteLLM's + # authoritative model map (`supports_vision`) which already + # knows GPT-5.x / GPT-4o / Claude 3.x / Gemini 2.x are + # vision-capable. Unknown / unmapped models default-allow so + # we don't lock the user out of a freshly added third-party + # entry; the streaming-task safety net (driven by + # `is_known_text_only_chat_model`) is the only place a False + # actually blocks a request. + if "supports_image_input" not in cfg: + litellm_params = cfg.get("litellm_params") or {} + base_model = ( + litellm_params.get("base_model") + if isinstance(litellm_params, dict) + else None + ) + cfg["supports_image_input"] = derive_supports_image_input( + provider=cfg.get("provider"), + model_name=cfg.get("model_name"), + base_model=base_model, + custom_provider=cfg.get("custom_provider"), + ) if cfg.get("seo_enabled") and cfg.get("seo_slug"): slug = cfg["seo_slug"] diff --git a/surfsense_backend/app/routes/image_generation_routes.py b/surfsense_backend/app/routes/image_generation_routes.py index 34ed80207..018234ad5 100644 --- a/surfsense_backend/app/routes/image_generation_routes.py +++ b/surfsense_backend/app/routes/image_generation_routes.py @@ -46,6 +46,7 @@ from app.services.image_gen_router_service import ( ImageGenRouterService, is_image_gen_auto_mode, ) +from app.services.provider_api_base import resolve_api_base from app.users import current_active_user from app.utils.rbac import check_permission from app.utils.signed_image_urls import verify_image_token @@ -87,14 +88,18 @@ def _get_global_image_gen_config(config_id: int) -> dict | None: return None +def _resolve_provider_prefix(provider: str, custom_provider: str | None) -> str: + """Resolve the LiteLLM provider prefix used in model strings.""" + if custom_provider: + return custom_provider + return _PROVIDER_MAP.get(provider.upper(), provider.lower()) + + def _build_model_string( provider: str, model_name: str, custom_provider: str | None ) -> str: """Build a litellm model string from provider + model_name.""" - if custom_provider: - return f"{custom_provider}/{model_name}" - prefix = _PROVIDER_MAP.get(provider.upper(), provider.lower()) - return f"{prefix}/{model_name}" + return f"{_resolve_provider_prefix(provider, custom_provider)}/{model_name}" async def _resolve_billing_for_image_gen( @@ -187,12 +192,18 @@ async def _execute_image_generation( if not cfg: raise ValueError(f"Global image generation config {config_id} not found") - model_string = 
_build_model_string( - cfg.get("provider", ""), cfg["model_name"], cfg.get("custom_provider") + provider_prefix = _resolve_provider_prefix( + cfg.get("provider", ""), cfg.get("custom_provider") ) + model_string = f"{provider_prefix}/{cfg['model_name']}" gen_kwargs["api_key"] = cfg.get("api_key") - if cfg.get("api_base"): - gen_kwargs["api_base"] = cfg["api_base"] + api_base = resolve_api_base( + provider=cfg.get("provider"), + provider_prefix=provider_prefix, + config_api_base=cfg.get("api_base"), + ) + if api_base: + gen_kwargs["api_base"] = api_base if cfg.get("api_version"): gen_kwargs["api_version"] = cfg["api_version"] if cfg.get("litellm_params"): @@ -214,12 +225,18 @@ async def _execute_image_generation( if not db_cfg: raise ValueError(f"Image generation config {config_id} not found") - model_string = _build_model_string( - db_cfg.provider.value, db_cfg.model_name, db_cfg.custom_provider + provider_prefix = _resolve_provider_prefix( + db_cfg.provider.value, db_cfg.custom_provider ) + model_string = f"{provider_prefix}/{db_cfg.model_name}" gen_kwargs["api_key"] = db_cfg.api_key - if db_cfg.api_base: - gen_kwargs["api_base"] = db_cfg.api_base + api_base = resolve_api_base( + provider=db_cfg.provider.value, + provider_prefix=provider_prefix, + config_api_base=db_cfg.api_base, + ) + if api_base: + gen_kwargs["api_base"] = api_base if db_cfg.api_version: gen_kwargs["api_version"] = db_cfg.api_version if db_cfg.litellm_params: @@ -277,10 +294,12 @@ async def get_global_image_gen_configs( # Auto mode currently treated as free until per-deployment # billing-tier surfacing lands (see _resolve_billing_for_image_gen). "billing_tier": "free", + "is_premium": False, } ) for cfg in global_configs: + billing_tier = str(cfg.get("billing_tier", "free")).lower() safe_configs.append( { "id": cfg.get("id"), @@ -293,7 +312,11 @@ async def get_global_image_gen_configs( "api_version": cfg.get("api_version") or None, "litellm_params": cfg.get("litellm_params", {}), "is_global": True, - "billing_tier": cfg.get("billing_tier", "free"), + "billing_tier": billing_tier, + # Mirror chat (``new_llm_config_routes``) so the new-chat + # selector's premium badge logic keys off the same + # field across chat / image / vision tabs. + "is_premium": billing_tier == "premium", "quota_reserve_micros": cfg.get("quota_reserve_micros"), } ) diff --git a/surfsense_backend/app/routes/new_llm_config_routes.py b/surfsense_backend/app/routes/new_llm_config_routes.py index 20779a309..e090a1a7c 100644 --- a/surfsense_backend/app/routes/new_llm_config_routes.py +++ b/surfsense_backend/app/routes/new_llm_config_routes.py @@ -29,6 +29,7 @@ from app.schemas import ( NewLLMConfigUpdate, ) from app.services.llm_service import validate_llm_config +from app.services.provider_capabilities import derive_supports_image_input from app.users import current_active_user from app.utils.rbac import check_permission @@ -36,6 +37,39 @@ router = APIRouter() logger = logging.getLogger(__name__) +def _serialize_byok_config(config: NewLLMConfig) -> NewLLMConfigRead: + """Augment a BYOK chat config row with the derived ``supports_image_input``. + + There is no DB column for ``supports_image_input`` — the value is + resolved at the API boundary from LiteLLM's authoritative model map + (default-allow on unknown). Returning ``NewLLMConfigRead`` here keeps + the response shape consistent across list / detail / create / update + endpoints without having to remember to set the field at every call + site. 
+ """ + provider_value = ( + config.provider.value + if hasattr(config.provider, "value") + else str(config.provider) + ) + litellm_params = config.litellm_params or {} + base_model = ( + litellm_params.get("base_model") if isinstance(litellm_params, dict) else None + ) + supports_image_input = derive_supports_image_input( + provider=provider_value, + model_name=config.model_name, + base_model=base_model, + custom_provider=config.custom_provider, + ) + # ``model_validate`` runs the Pydantic conversion using the ORM + # attribute access path enabled by ``ConfigDict(from_attributes=True)``, + # then we layer the derived field on. ``model_copy(update=...)`` keeps + # the surface immutable from the caller's perspective. + base_read = NewLLMConfigRead.model_validate(config) + return base_read.model_copy(update={"supports_image_input": supports_image_input}) + + # ============================================================================= # Global Configs Routes # ============================================================================= @@ -84,11 +118,41 @@ async def get_global_new_llm_configs( "seo_title": None, "seo_description": None, "quota_reserve_tokens": None, + # Auto routes across the configured pool, which usually + # includes at least one vision-capable deployment, so + # treat Auto as image-capable. The router itself will + # still pick a vision-capable deployment for messages + # carrying image_url blocks (LiteLLM Router falls back + # on ``404`` per its ``allowed_fails`` policy). + "supports_image_input": True, } ) # Add individual global configs for cfg in global_configs: + # Capability resolution: explicit value (YAML override or OR + # `_supports_image_input(model)` payload baked in by the + # OpenRouter integration service) wins. Fall back to the + # LiteLLM-driven helper which default-allows on unknown so + # we don't hide vision-capable models that happen to lack a + # YAML annotation. The streaming task safety net is the + # only place a False ever blocks. 
+ if "supports_image_input" in cfg: + supports_image_input = bool(cfg.get("supports_image_input")) + else: + cfg_litellm_params = cfg.get("litellm_params") or {} + cfg_base_model = ( + cfg_litellm_params.get("base_model") + if isinstance(cfg_litellm_params, dict) + else None + ) + supports_image_input = derive_supports_image_input( + provider=cfg.get("provider"), + model_name=cfg.get("model_name"), + base_model=cfg_base_model, + custom_provider=cfg.get("custom_provider"), + ) + safe_config = { "id": cfg.get("id"), "name": cfg.get("name"), @@ -113,6 +177,7 @@ async def get_global_new_llm_configs( "seo_title": cfg.get("seo_title"), "seo_description": cfg.get("seo_description"), "quota_reserve_tokens": cfg.get("quota_reserve_tokens"), + "supports_image_input": supports_image_input, } safe_configs.append(safe_config) @@ -171,7 +236,7 @@ async def create_new_llm_config( await session.commit() await session.refresh(db_config) - return db_config + return _serialize_byok_config(db_config) except HTTPException: raise @@ -213,7 +278,7 @@ async def list_new_llm_configs( .limit(limit) ) - return result.scalars().all() + return [_serialize_byok_config(cfg) for cfg in result.scalars().all()] except HTTPException: raise @@ -268,7 +333,7 @@ async def get_new_llm_config( "You don't have permission to view LLM configurations in this search space", ) - return config + return _serialize_byok_config(config) except HTTPException: raise @@ -360,7 +425,7 @@ async def update_new_llm_config( await session.commit() await session.refresh(config) - return config + return _serialize_byok_config(config) except HTTPException: raise diff --git a/surfsense_backend/app/routes/vision_llm_routes.py b/surfsense_backend/app/routes/vision_llm_routes.py index 4f7e9f725..e4f08f604 100644 --- a/surfsense_backend/app/routes/vision_llm_routes.py +++ b/surfsense_backend/app/routes/vision_llm_routes.py @@ -85,10 +85,12 @@ async def get_global_vision_llm_configs( # Auto mode treated as free until per-deployment billing-tier # surfacing lands; see ``get_vision_llm`` for parity. "billing_tier": "free", + "is_premium": False, } ) for cfg in global_configs: + billing_tier = str(cfg.get("billing_tier", "free")).lower() safe_configs.append( { "id": cfg.get("id"), @@ -101,7 +103,11 @@ async def get_global_vision_llm_configs( "api_version": cfg.get("api_version") or None, "litellm_params": cfg.get("litellm_params", {}), "is_global": True, - "billing_tier": cfg.get("billing_tier", "free"), + "billing_tier": billing_tier, + # Mirror chat (``new_llm_config_routes``) so the new-chat + # selector's premium badge logic keys off the same + # field across chat / image / vision tabs. + "is_premium": billing_tier == "premium", "quota_reserve_tokens": cfg.get("quota_reserve_tokens"), "input_cost_per_token": cfg.get("input_cost_per_token"), "output_cost_per_token": cfg.get("output_cost_per_token"), diff --git a/surfsense_backend/app/schemas/image_generation.py b/surfsense_backend/app/schemas/image_generation.py index facca7b86..4262b2b3f 100644 --- a/surfsense_backend/app/schemas/image_generation.py +++ b/surfsense_backend/app/schemas/image_generation.py @@ -241,6 +241,15 @@ class GlobalImageGenConfigRead(BaseModel): default="free", description="'free' or 'premium'. Premium debits the user's premium credit pool (USD-cost-based).", ) + is_premium: bool = Field( + default=False, + description=( + "Convenience boolean derived server-side from " + "``billing_tier == 'premium'``. 
The new-chat model selector " + "keys its Free/Premium badge off this field for parity with " + "chat (`GlobalLLMConfigRead.is_premium`)." + ), + ) quota_reserve_micros: int | None = Field( default=None, description=( diff --git a/surfsense_backend/app/schemas/new_llm_config.py b/surfsense_backend/app/schemas/new_llm_config.py index 9cc1fce58..e64478d38 100644 --- a/surfsense_backend/app/schemas/new_llm_config.py +++ b/surfsense_backend/app/schemas/new_llm_config.py @@ -92,6 +92,20 @@ class NewLLMConfigRead(NewLLMConfigBase): created_at: datetime search_space_id: int user_id: uuid.UUID + # Capability flag derived at the API boundary (no DB column). Default + # True matches the conservative-allow stance — a BYOK row that the + # route forgot to augment is not pre-judged. The streaming-task + # safety net is the only place a False actually blocks a request. + supports_image_input: bool = Field( + default=True, + description=( + "Whether the BYOK chat config can accept image inputs. Derived " + "at the route boundary from LiteLLM's authoritative model map " + "(``litellm.supports_vision``) — there is no DB column. " + "Default True is the conservative-allow stance for unknown / " + "unmapped models." + ), + ) model_config = ConfigDict(from_attributes=True) @@ -121,6 +135,15 @@ class NewLLMConfigPublic(BaseModel): created_at: datetime search_space_id: int user_id: uuid.UUID + # Capability flag derived at the API boundary (see NewLLMConfigRead). + supports_image_input: bool = Field( + default=True, + description=( + "Whether the BYOK chat config can accept image inputs. Derived " + "at the route boundary from LiteLLM's authoritative model map. " + "Default True is the conservative-allow stance." + ), + ) model_config = ConfigDict(from_attributes=True) @@ -172,6 +195,19 @@ class GlobalNewLLMConfigRead(BaseModel): seo_title: str | None = None seo_description: str | None = None quota_reserve_tokens: int | None = None + supports_image_input: bool = Field( + default=True, + description=( + "Whether the model accepts image inputs (multimodal vision). " + "Derived server-side: OpenRouter dynamic configs use " + "``architecture.input_modalities``; YAML / BYOK use LiteLLM's " + "authoritative model map (``litellm.supports_vision``). The " + "new-chat selector hints with a 'No image' badge when this is " + "False and there are pending image attachments. The streaming " + "task fails fast only when LiteLLM *explicitly* marks a model " + "as text-only — unknown / unmapped models default-allow." + ), + ) # ============================================================================= diff --git a/surfsense_backend/app/schemas/vision_llm.py b/surfsense_backend/app/schemas/vision_llm.py index e55333a9d..d0eeaf5c6 100644 --- a/surfsense_backend/app/schemas/vision_llm.py +++ b/surfsense_backend/app/schemas/vision_llm.py @@ -86,6 +86,15 @@ class GlobalVisionLLMConfigRead(BaseModel): default="free", description="'free' or 'premium'. Premium debits the user's premium credit pool (USD-cost-based).", ) + is_premium: bool = Field( + default=False, + description=( + "Convenience boolean derived server-side from " + "``billing_tier == 'premium'``. The new-chat model selector " + "keys its Free/Premium badge off this field for parity with " + "chat (`GlobalLLMConfigRead.is_premium`)." 
+ ), + ) quota_reserve_tokens: int | None = Field( default=None, description=( diff --git a/surfsense_backend/app/services/auto_model_pin_service.py b/surfsense_backend/app/services/auto_model_pin_service.py index 3a2c681b7..4f045ba02 100644 --- a/surfsense_backend/app/services/auto_model_pin_service.py +++ b/surfsense_backend/app/services/auto_model_pin_service.py @@ -163,13 +163,47 @@ def clear_healthy(config_id: int | None = None) -> None: _healthy_until.pop(int(config_id), None) -def _global_candidates() -> list[dict]: +def _cfg_supports_image_input(cfg: dict) -> bool: + """True if the global cfg can accept image inputs. + + Prefers the explicit ``supports_image_input`` flag (set by the YAML + loader / OpenRouter integration). Falls back to a LiteLLM lookup so + a YAML entry whose flag was somehow stripped doesn't get wrongly + excluded. Default-allows on unknown — the streaming-task safety net + is the actual block, not this filter. + """ + if "supports_image_input" in cfg: + return bool(cfg.get("supports_image_input")) + # Lazy import: provider_capabilities -> llm_config -> services chain; + # importing at module load would create an init-order cycle through + # ``app.config``. + from app.services.provider_capabilities import derive_supports_image_input + + cfg_litellm_params = cfg.get("litellm_params") or {} + base_model = ( + cfg_litellm_params.get("base_model") + if isinstance(cfg_litellm_params, dict) + else None + ) + return derive_supports_image_input( + provider=cfg.get("provider"), + model_name=cfg.get("model_name"), + base_model=base_model, + custom_provider=cfg.get("custom_provider"), + ) + + +def _global_candidates(*, requires_image_input: bool = False) -> list[dict]: """Return Auto-eligible global cfgs. Drops cfgs flagged ``health_gated`` (best non-null OpenRouter uptime below ``_HEALTH_GATE_UPTIME_PCT``) so chronically broken providers can't be picked as the thread's pin. Also excludes configs currently in runtime cooldown (e.g. temporary 429 bursts). + + When ``requires_image_input`` is True (image turn), additionally + filters out configs whose ``supports_image_input`` resolves to False + so a text-only deployment can't be pinned for an image request. """ candidates = [ cfg @@ -177,6 +211,7 @@ def _global_candidates() -> list[dict]: if _is_usable_global_config(cfg) and not cfg.get("health_gated") and not _is_runtime_cooled_down(int(cfg.get("id", 0))) + and (not requires_image_input or _cfg_supports_image_input(cfg)) ] return sorted(candidates, key=lambda c: int(c.get("id", 0))) @@ -237,11 +272,20 @@ async def resolve_or_get_pinned_llm_config_id( selected_llm_config_id: int, force_repin_free: bool = False, exclude_config_ids: set[int] | None = None, + requires_image_input: bool = False, ) -> AutoPinResolution: """Resolve Auto (Fastest) to one concrete config id and persist the pin. For non-auto selections, this function clears any existing pin and returns the selected id as-is. + + When ``requires_image_input`` is True (the current turn carries an + ``image_url`` block), the candidate pool is filtered to vision-capable + cfgs and any existing pin that can't accept image input is treated as + invalid (force re-pin). If no vision-capable cfg is available the + function raises ``ValueError`` so the streaming task surfaces the same + friendly ``MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT`` error instead of + silently routing the image to a text-only deployment. 
""" thread = ( ( @@ -274,14 +318,24 @@ async def resolve_or_get_pinned_llm_config_id( excluded_ids = {int(cid) for cid in (exclude_config_ids or set())} candidates = [ - c for c in _global_candidates() if int(c.get("id", 0)) not in excluded_ids + c + for c in _global_candidates(requires_image_input=requires_image_input) + if int(c.get("id", 0)) not in excluded_ids ] if not candidates: + if requires_image_input: + # Distinguish the "no vision-capable cfg" case from generic + # "no usable cfg" so the streaming task can map this to the + # MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT SSE error. + raise ValueError( + "No vision-capable global LLM configs are available for Auto mode" + ) raise ValueError("No usable global LLM configs are available for Auto mode") candidate_by_id = {int(c["id"]): c for c in candidates} # Reuse an existing valid pin without re-checking current quota (no silent - # tier switch), unless the caller explicitly requests a forced repin to free. + # tier switch), unless the caller explicitly requests a forced repin to free + # *or* the turn requires image input but the pin can't handle it. pinned_id = thread.pinned_llm_config_id if ( not force_repin_free @@ -311,6 +365,29 @@ async def resolve_or_get_pinned_llm_config_id( from_existing_pin=True, ) if pinned_id is not None: + # If the pin is *only* invalid because it can't handle the image + # turn (it's still a healthy, usable config in the broader pool), + # log that explicitly so operators can correlate the re-pin with + # the user's image attachment instead of suspecting a cooldown. + if requires_image_input: + try: + pinned_global = next( + c + for c in config.GLOBAL_LLM_CONFIGS + if int(c.get("id", 0)) == int(pinned_id) + ) + except StopIteration: + pinned_global = None + if pinned_global is not None and not _cfg_supports_image_input( + pinned_global + ): + logger.info( + "auto_pin_repinned_for_image thread_id=%s search_space_id=%s " + "previous_config_id=%s", + thread_id, + search_space_id, + pinned_id, + ) logger.info( "auto_pin_invalid thread_id=%s search_space_id=%s pinned_config_id=%s", thread_id, @@ -327,6 +404,10 @@ async def resolve_or_get_pinned_llm_config_id( eligible = [c for c in candidates if _tier_of(c) != "premium"] if not eligible: + if requires_image_input: + raise ValueError( + "Auto mode could not find a vision-capable LLM config for this user and quota state" + ) raise ValueError( "Auto mode could not find an eligible LLM config for this user and quota state" ) diff --git a/surfsense_backend/app/services/billable_calls.py b/surfsense_backend/app/services/billable_calls.py index f5ca9818e..92ccd6a78 100644 --- a/surfsense_backend/app/services/billable_calls.py +++ b/surfsense_backend/app/services/billable_calls.py @@ -10,12 +10,14 @@ vision-LLM wrapper used during indexing) don't have to re-implement it. KEY DESIGN POINTS (issue A, B): -1. **Session isolation.** ``billable_call`` takes *no* ``db_session`` - argument. All ``TokenQuotaService.premium_*`` calls and the audit-row - insert each run inside their own ``shielded_async_session()``. This - guarantees that a quota commit/rollback can never accidentally flush or - roll back rows the caller has staged in the request's main session - (e.g. a freshly-created ``ImageGeneration`` row). +1. **Session isolation.** ``billable_call`` takes no caller transaction. + All ``TokenQuotaService.premium_*`` calls and the audit-row insert run + inside their own session context. 
Route callers use + ``shielded_async_session()`` by default; Celery callers can provide a + worker-loop-safe session factory. This guarantees that quota + commit/rollback can never accidentally flush or roll back rows the caller + has staged in its main session (e.g. a freshly-created + ``ImageGeneration`` row). 2. **ContextVar safety.** The accumulator is scoped via :func:`scoped_turn` (which uses ``ContextVar.reset(token)``), so a @@ -36,9 +38,10 @@ KEY DESIGN POINTS (issue A, B): from __future__ import annotations +import asyncio import logging -from collections.abc import AsyncIterator -from contextlib import asynccontextmanager +from collections.abc import AsyncIterator, Callable +from contextlib import AbstractAsyncContextManager, asynccontextmanager, suppress from typing import Any from uuid import UUID, uuid4 @@ -58,6 +61,12 @@ from app.services.token_tracking_service import ( logger = logging.getLogger(__name__) +AUDIT_TIMEOUT_SECONDS = 10.0 +BACKGROUND_ARTIFACT_USAGE_TYPES = frozenset( + {"video_presentation_generation", "podcast_generation"} +) +BillableSessionFactory = Callable[[], AbstractAsyncContextManager[AsyncSession]] + class QuotaInsufficientError(Exception): """Raised when ``TokenQuotaService.premium_reserve`` denies a billable @@ -88,6 +97,124 @@ class QuotaInsufficientError(Exception): ) +class BillingSettlementError(Exception): + """Raised when a premium call completed but credit settlement failed.""" + + def __init__(self, *, usage_type: str, user_id: UUID, cause: Exception) -> None: + self.usage_type = usage_type + self.user_id = user_id + super().__init__( + f"Failed to settle premium credit for {usage_type} user={user_id}: {cause}" + ) + + +async def _rollback_safely(session: AsyncSession) -> None: + rollback = getattr(session, "rollback", None) + if rollback is not None: + with suppress(Exception): + await rollback() + + +async def _record_audit_best_effort( + *, + session_factory: BillableSessionFactory, + usage_type: str, + search_space_id: int, + user_id: UUID, + prompt_tokens: int, + completion_tokens: int, + total_tokens: int, + cost_micros: int, + model_breakdown: dict[str, Any], + call_details: dict[str, Any] | None, + thread_id: int | None, + message_id: int | None, + audit_label: str, + timeout_seconds: float = AUDIT_TIMEOUT_SECONDS, +) -> None: + """Persist a TokenUsage row without letting audit failure block callers. + + Premium settlement is mandatory, but TokenUsage is an audit trail. If the + audit insert or commit hangs, user-facing artifacts such as videos and + podcasts must still be able to transition to READY after settlement. 
+ """ + audit_thread_id = ( + None if usage_type in BACKGROUND_ARTIFACT_USAGE_TYPES else thread_id + ) + + async def _persist() -> None: + logger.info( + "[billable_call] audit start label=%s usage_type=%s user=%s thread=%s " + "total_tokens=%d cost_micros=%d", + audit_label, + usage_type, + user_id, + audit_thread_id, + total_tokens, + cost_micros, + ) + async with session_factory() as audit_session: + try: + await record_token_usage( + audit_session, + usage_type=usage_type, + search_space_id=search_space_id, + user_id=user_id, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + cost_micros=cost_micros, + model_breakdown=model_breakdown, + call_details=call_details, + thread_id=audit_thread_id, + message_id=message_id, + ) + logger.info( + "[billable_call] audit row staged label=%s usage_type=%s user=%s thread=%s", + audit_label, + usage_type, + user_id, + audit_thread_id, + ) + await audit_session.commit() + logger.info( + "[billable_call] audit commit OK label=%s usage_type=%s user=%s thread=%s", + audit_label, + usage_type, + user_id, + audit_thread_id, + ) + except BaseException: + await _rollback_safely(audit_session) + raise + + try: + await asyncio.wait_for(_persist(), timeout=timeout_seconds) + except TimeoutError: + logger.warning( + "[billable_call] audit timed out label=%s usage_type=%s user=%s thread=%s " + "timeout=%.1fs total_tokens=%d cost_micros=%d", + audit_label, + usage_type, + user_id, + audit_thread_id, + timeout_seconds, + total_tokens, + cost_micros, + ) + except Exception: + logger.exception( + "[billable_call] audit failed label=%s usage_type=%s user=%s thread=%s " + "total_tokens=%d cost_micros=%d", + audit_label, + usage_type, + user_id, + audit_thread_id, + total_tokens, + cost_micros, + ) + + @asynccontextmanager async def billable_call( *, @@ -101,6 +228,8 @@ async def billable_call( thread_id: int | None = None, message_id: int | None = None, call_details: dict[str, Any] | None = None, + billable_session_factory: BillableSessionFactory | None = None, + audit_timeout_seconds: float = AUDIT_TIMEOUT_SECONDS, ) -> AsyncIterator[TurnTokenAccumulator]: """Wrap a single billable LLM/image call. @@ -124,6 +253,13 @@ async def billable_call( thread_id, message_id: Optional FK columns on ``TokenUsage``. call_details: Optional per-call metadata (model name, parameters) forwarded to ``record_token_usage``. + billable_session_factory: Optional async context factory used for + reserve/finalize/release/audit sessions. Defaults to + ``shielded_async_session`` for route callers; Celery callers pass + a worker-loop-safe session factory. + audit_timeout_seconds: Upper bound for TokenUsage audit persistence. + Audit failure is best-effort and does not undo successful + settlement. Yields: The ``TurnTokenAccumulator`` scoped to this call. The caller invokes @@ -134,6 +270,7 @@ async def billable_call( QuotaInsufficientError: when premium and ``premium_reserve`` denies. """ is_premium = billing_tier == "premium" + session_factory = billable_session_factory or shielded_async_session async with scoped_turn() as acc: # ---------- Free path: just audit ------------------------------- @@ -143,30 +280,22 @@ async def billable_call( finally: # Always audit, even on exception, so we capture cost when # provider returns successfully but the caller raises later. 
- try: - async with shielded_async_session() as audit_session: - await record_token_usage( - audit_session, - usage_type=usage_type, - search_space_id=search_space_id, - user_id=user_id, - prompt_tokens=acc.total_prompt_tokens, - completion_tokens=acc.total_completion_tokens, - total_tokens=acc.grand_total, - cost_micros=acc.total_cost_micros, - model_breakdown=acc.per_message_summary(), - call_details=call_details, - thread_id=thread_id, - message_id=message_id, - ) - await audit_session.commit() - except Exception: - logger.exception( - "[billable_call] free-path audit insert failed for " - "usage_type=%s user_id=%s", - usage_type, - user_id, - ) + await _record_audit_best_effort( + session_factory=session_factory, + usage_type=usage_type, + search_space_id=search_space_id, + user_id=user_id, + prompt_tokens=acc.total_prompt_tokens, + completion_tokens=acc.total_completion_tokens, + total_tokens=acc.grand_total, + cost_micros=acc.total_cost_micros, + model_breakdown=acc.per_message_summary(), + call_details=call_details, + thread_id=thread_id, + message_id=message_id, + audit_label="free", + timeout_seconds=audit_timeout_seconds, + ) return # ---------- Premium path: reserve → execute → finalize ---------- @@ -180,7 +309,7 @@ async def billable_call( request_id = str(uuid4()) - async with shielded_async_session() as quota_session: + async with session_factory() as quota_session: reserve_result = await TokenQuotaService.premium_reserve( db_session=quota_session, user_id=user_id, @@ -222,7 +351,7 @@ async def billable_call( # from a downstream call, asyncio cancellation, etc.). We use # BaseException so cancellation also releases. try: - async with shielded_async_session() as quota_session: + async with session_factory() as quota_session: await TokenQuotaService.premium_release( db_session=quota_session, user_id=user_id, @@ -241,7 +370,16 @@ async def billable_call( # ---------- Success: finalize + audit ---------------------------- actual_micros = acc.total_cost_micros try: - async with shielded_async_session() as quota_session: + logger.info( + "[billable_call] finalize start user=%s usage_type=%s actual=%d " + "reserved=%d thread=%s", + user_id, + usage_type, + actual_micros, + reserve_micros, + thread_id, + ) + async with session_factory() as quota_session: final_result = await TokenQuotaService.premium_finalize( db_session=quota_session, user_id=user_id, @@ -260,7 +398,7 @@ async def billable_call( final_result.limit, final_result.remaining, ) - except Exception: + except Exception as finalize_exc: # Last-ditch: if finalize itself fails, we must at least release # so the reservation doesn't leak. 
logger.exception( @@ -269,7 +407,7 @@ async def billable_call( user_id, ) try: - async with shielded_async_session() as quota_session: + async with session_factory() as quota_session: await TokenQuotaService.premium_release( db_session=quota_session, user_id=user_id, @@ -281,31 +419,28 @@ async def billable_call( "for user=%s", user_id, ) + raise BillingSettlementError( + usage_type=usage_type, + user_id=user_id, + cause=finalize_exc, + ) from finalize_exc - try: - async with shielded_async_session() as audit_session: - await record_token_usage( - audit_session, - usage_type=usage_type, - search_space_id=search_space_id, - user_id=user_id, - prompt_tokens=acc.total_prompt_tokens, - completion_tokens=acc.total_completion_tokens, - total_tokens=acc.grand_total, - cost_micros=actual_micros, - model_breakdown=acc.per_message_summary(), - call_details=call_details, - thread_id=thread_id, - message_id=message_id, - ) - await audit_session.commit() - except Exception: - logger.exception( - "[billable_call] premium-path audit insert failed for " - "usage_type=%s user_id=%s (debit was applied)", - usage_type, - user_id, - ) + await _record_audit_best_effort( + session_factory=session_factory, + usage_type=usage_type, + search_space_id=search_space_id, + user_id=user_id, + prompt_tokens=acc.total_prompt_tokens, + completion_tokens=acc.total_completion_tokens, + total_tokens=acc.grand_total, + cost_micros=actual_micros, + model_breakdown=acc.per_message_summary(), + call_details=call_details, + thread_id=thread_id, + message_id=message_id, + audit_label="premium", + timeout_seconds=audit_timeout_seconds, + ) async def _resolve_agent_billing_for_search_space( @@ -419,6 +554,7 @@ async def _resolve_agent_billing_for_search_space( __all__ = [ + "BillingSettlementError", "QuotaInsufficientError", "_resolve_agent_billing_for_search_space", "billable_call", diff --git a/surfsense_backend/app/services/image_gen_router_service.py b/surfsense_backend/app/services/image_gen_router_service.py index f45a6ab63..b4de2a0bf 100644 --- a/surfsense_backend/app/services/image_gen_router_service.py +++ b/surfsense_backend/app/services/image_gen_router_service.py @@ -20,6 +20,8 @@ from typing import Any from litellm import Router from litellm.utils import ImageResponse +from app.services.provider_api_base import resolve_api_base + logger = logging.getLogger(__name__) # Special ID for Auto mode - uses router for load balancing @@ -152,12 +154,12 @@ class ImageGenRouterService: return None # Build model string + provider = config.get("provider", "").upper() if config.get("custom_provider"): - model_string = f"{config['custom_provider']}/{config['model_name']}" + provider_prefix = config["custom_provider"] else: - provider = config.get("provider", "").upper() provider_prefix = IMAGE_GEN_PROVIDER_MAP.get(provider, provider.lower()) - model_string = f"{provider_prefix}/{config['model_name']}" + model_string = f"{provider_prefix}/{config['model_name']}" # Build litellm params litellm_params: dict[str, Any] = { @@ -165,9 +167,16 @@ class ImageGenRouterService: "api_key": config.get("api_key"), } - # Add optional api_base - if config.get("api_base"): - litellm_params["api_base"] = config["api_base"] + # Resolve ``api_base`` so deployments don't silently inherit + # ``AZURE_OPENAI_ENDPOINT`` / ``OPENAI_API_BASE`` and 404 against + # the wrong provider (see ``provider_api_base`` docstring). 
+ api_base = resolve_api_base( + provider=provider, + provider_prefix=provider_prefix, + config_api_base=config.get("api_base"), + ) + if api_base: + litellm_params["api_base"] = api_base # Add api_version (required for Azure) if config.get("api_version"): diff --git a/surfsense_backend/app/services/llm_router_service.py b/surfsense_backend/app/services/llm_router_service.py index 1e9d235c8..d220aa346 100644 --- a/surfsense_backend/app/services/llm_router_service.py +++ b/surfsense_backend/app/services/llm_router_service.py @@ -140,8 +140,6 @@ PROVIDER_MAP = { # 404-ing against an inherited Azure endpoint). Re-exported here for # backward compatibility with any external import. from app.services.provider_api_base import ( # noqa: E402 - PROVIDER_DEFAULT_API_BASE, - PROVIDER_KEY_DEFAULT_API_BASE, resolve_api_base, ) diff --git a/surfsense_backend/app/services/llm_service.py b/surfsense_backend/app/services/llm_service.py index 72c10035d..ade202c72 100644 --- a/surfsense_backend/app/services/llm_service.py +++ b/surfsense_backend/app/services/llm_service.py @@ -16,6 +16,7 @@ from app.services.llm_router_service import ( get_auto_mode_llm, is_auto_mode, ) +from app.services.provider_api_base import resolve_api_base from app.services.token_tracking_service import token_tracker # Configure litellm to automatically drop unsupported parameters @@ -556,22 +557,26 @@ async def get_vision_llm( return None if global_cfg.get("custom_provider"): - model_string = ( - f"{global_cfg['custom_provider']}/{global_cfg['model_name']}" - ) + provider_prefix = global_cfg["custom_provider"] + model_string = f"{provider_prefix}/{global_cfg['model_name']}" else: - prefix = VISION_PROVIDER_MAP.get( + provider_prefix = VISION_PROVIDER_MAP.get( global_cfg["provider"].upper(), global_cfg["provider"].lower(), ) - model_string = f"{prefix}/{global_cfg['model_name']}" + model_string = f"{provider_prefix}/{global_cfg['model_name']}" litellm_kwargs = { "model": model_string, "api_key": global_cfg["api_key"], } - if global_cfg.get("api_base"): - litellm_kwargs["api_base"] = global_cfg["api_base"] + api_base = resolve_api_base( + provider=global_cfg.get("provider"), + provider_prefix=provider_prefix, + config_api_base=global_cfg.get("api_base"), + ) + if api_base: + litellm_kwargs["api_base"] = api_base if global_cfg.get("litellm_params"): litellm_kwargs.update(global_cfg["litellm_params"]) @@ -606,20 +611,26 @@ async def get_vision_llm( return None if vision_cfg.custom_provider: - model_string = f"{vision_cfg.custom_provider}/{vision_cfg.model_name}" + provider_prefix = vision_cfg.custom_provider + model_string = f"{provider_prefix}/{vision_cfg.model_name}" else: - prefix = VISION_PROVIDER_MAP.get( + provider_prefix = VISION_PROVIDER_MAP.get( vision_cfg.provider.value.upper(), vision_cfg.provider.value.lower(), ) - model_string = f"{prefix}/{vision_cfg.model_name}" + model_string = f"{provider_prefix}/{vision_cfg.model_name}" litellm_kwargs = { "model": model_string, "api_key": vision_cfg.api_key, } - if vision_cfg.api_base: - litellm_kwargs["api_base"] = vision_cfg.api_base + api_base = resolve_api_base( + provider=vision_cfg.provider.value, + provider_prefix=provider_prefix, + config_api_base=vision_cfg.api_base, + ) + if api_base: + litellm_kwargs["api_base"] = api_base if vision_cfg.litellm_params: litellm_kwargs.update(vision_cfg.litellm_params) diff --git a/surfsense_backend/app/services/openrouter_integration_service.py b/surfsense_backend/app/services/openrouter_integration_service.py index 0d030f04f..6454e2d58 
100644 --- a/surfsense_backend/app/services/openrouter_integration_service.py +++ b/surfsense_backend/app/services/openrouter_integration_service.py @@ -122,6 +122,24 @@ def _is_vision_input_model(model: dict) -> bool: return "image" in input_mods and "text" in output_mods +def _supports_image_input(model: dict) -> bool: + """Return True if the model accepts ``image`` in its input modalities. + + Differs from :func:`_is_vision_input_model` in that it does NOT + require text output — chat-tab models always emit text already (the + chat catalog filters by ``_is_text_output_model``), so the only + extra capability we need to track per chat config is whether the + model can ingest user-attached images. The chat selector and the + streaming task both key off this flag to prevent hitting an + OpenRouter 404 ``"No endpoints found that support image input"`` + when the user uploads an image and selects a text-only model + (DeepSeek V3, Llama 3.x base, etc.). + """ + arch = model.get("architecture", {}) or {} + input_mods = arch.get("input_modalities", []) or [] + return "image" in input_mods + + def _supports_tool_calling(model: dict) -> bool: """Return True if the model supports function/tool calling.""" supported = model.get("supported_parameters") or [] @@ -321,6 +339,13 @@ def _generate_configs( # account-wide quota, so per-deployment routing can't spread load # there — it just drains the shared bucket faster. "router_pool_eligible": tier == "premium", + # Capability flag derived from ``architecture.input_modalities``. + # Read by the new-chat selector to dim image-incompatible models + # when the user has pending image attachments, and by + # ``stream_new_chat`` as a fail-fast safety net before the + # OpenRouter request would otherwise 404 with + # ``"No endpoints found that support image input"``. + "supports_image_input": _supports_image_input(model), _OPENROUTER_DYNAMIC_MARKER: True, # Auto (Fastest) ranking metadata. ``quality_score`` is initialised # to the static score and gets re-blended with health on the next @@ -398,7 +423,12 @@ def _generate_image_gen_configs( "provider": "OPENROUTER", "model_name": model_id, "api_key": api_key, - "api_base": "", + # Pin to OpenRouter's public base URL so a downstream call site + # that forgets ``resolve_api_base`` still doesn't inherit + # ``AZURE_OPENAI_ENDPOINT`` and 404 on + # ``image_generation/transformation`` (defense-in-depth, see + # ``provider_api_base`` docstring). + "api_base": "https://openrouter.ai/api/v1", "api_version": None, "rpm": free_rpm if tier == "free" else rpm, "litellm_params": dict(litellm_params), @@ -477,7 +507,11 @@ def _generate_vision_llm_configs( "provider": "OPENROUTER", "model_name": model_id, "api_key": api_key, - "api_base": "", + # Pin to OpenRouter's public base URL so a downstream call site + # that forgets ``resolve_api_base`` still doesn't inherit + # ``AZURE_OPENAI_ENDPOINT`` (defense-in-depth, see + # ``provider_api_base`` docstring). + "api_base": "https://openrouter.ai/api/v1", "api_version": None, "rpm": free_rpm if tier == "free" else rpm, "tpm": free_tpm if tier == "free" else tpm, diff --git a/surfsense_backend/app/services/provider_api_base.py b/surfsense_backend/app/services/provider_api_base.py index 979d7d3a1..dca1f9462 100644 --- a/surfsense_backend/app/services/provider_api_base.py +++ b/surfsense_backend/app/services/provider_api_base.py @@ -17,7 +17,6 @@ source of truth without an inter-service circular import. 
from __future__ import annotations - PROVIDER_DEFAULT_API_BASE: dict[str, str] = { "openrouter": "https://openrouter.ai/api/v1", "groq": "https://api.groq.com/openai/v1", diff --git a/surfsense_backend/app/services/provider_capabilities.py b/surfsense_backend/app/services/provider_capabilities.py new file mode 100644 index 000000000..e9a1c33e1 --- /dev/null +++ b/surfsense_backend/app/services/provider_capabilities.py @@ -0,0 +1,280 @@ +"""Capability resolution shared by chat / image / vision call sites. + +Why this exists +--------------- +The chat catalog (YAML + dynamic OpenRouter + BYOK DB rows + Auto) needs a +single, authoritative answer to one question: *can this chat config accept +``image_url`` content blocks?* Without it, the new-chat selector can't badge +incompatible models and the streaming task can't fail fast with a friendly +error before sending an image to a text-only provider. + +Two functions, two intents: + +- :func:`derive_supports_image_input` — best-effort *True* for catalog and + UI surfacing. Default-allow: an unknown / unmapped model is treated as + capable so we never lock the user out of a freshly added or + third-party-hosted vision model. + +- :func:`is_known_text_only_chat_model` — strict opt-out for the streaming + task's safety net. Returns True only when LiteLLM's model map *explicitly* + sets ``supports_vision=False`` (or its bare-name variant does). Anything + else — missing key, lookup exception, ``supports_vision=True`` — returns + False so the request flows through to the provider. + +Implementation rule: only public LiteLLM symbols +------------------------------------------------ +``litellm.supports_vision`` and ``litellm.get_model_info`` are part of the +typed module surface (see ``litellm.__init__`` lazy stubs) and are stable +across releases. The private ``_is_explicitly_disabled_factory`` and +``_get_model_info_helper`` are intentionally avoided so a LiteLLM upgrade +can't silently break us. + +Why the previous round's strict YAML opt-in flag failed +------------------------------------------------------- +``supports_image_input: false`` was the YAML loader's setdefault. Operators +maintaining ``global_llm_config.yaml`` never set it, so every Azure / OpenAI +YAML chat model — including vision-capable GPT-5.x and GPT-4o — resolved to +False and the streaming gate rejected every image turn. Sourcing capability +from LiteLLM's authoritative model map (which already says +``azure/gpt-5.4 -> supports_vision=true``) removes that operator toil. +""" + +from __future__ import annotations + +import logging +from collections.abc import Iterable + +import litellm + +logger = logging.getLogger(__name__) + + +# Provider-name → LiteLLM model-prefix map. +# +# Owned here because ``app.services.provider_capabilities`` is the +# only edge that's safe to call from ``app.config``'s YAML loader at +# class-body init time. ``app.agents.new_chat.llm_config`` re-exports +# this constant under the historical ``PROVIDER_MAP`` name; placing the +# map there directly would re-introduce the +# ``app.config -> ... -> app.agents.new_chat.tools.generate_image -> +# app.config`` cycle that prompted the move. 
+_PROVIDER_PREFIX_MAP: dict[str, str] = { + "OPENAI": "openai", + "ANTHROPIC": "anthropic", + "GROQ": "groq", + "COHERE": "cohere", + "GOOGLE": "gemini", + "OLLAMA": "ollama_chat", + "MISTRAL": "mistral", + "AZURE_OPENAI": "azure", + "OPENROUTER": "openrouter", + "XAI": "xai", + "BEDROCK": "bedrock", + "VERTEX_AI": "vertex_ai", + "TOGETHER_AI": "together_ai", + "FIREWORKS_AI": "fireworks_ai", + "DEEPSEEK": "openai", + "ALIBABA_QWEN": "openai", + "MOONSHOT": "openai", + "ZHIPU": "openai", + "GITHUB_MODELS": "github", + "REPLICATE": "replicate", + "PERPLEXITY": "perplexity", + "ANYSCALE": "anyscale", + "DEEPINFRA": "deepinfra", + "CEREBRAS": "cerebras", + "SAMBANOVA": "sambanova", + "AI21": "ai21", + "CLOUDFLARE": "cloudflare", + "DATABRICKS": "databricks", + "COMETAPI": "cometapi", + "HUGGINGFACE": "huggingface", + "MINIMAX": "openai", + "CUSTOM": "custom", +} + + +def _candidate_model_strings( + *, + provider: str | None, + model_name: str | None, + base_model: str | None, + custom_provider: str | None, +) -> list[tuple[str, str | None]]: + """Return ``[(model_string, custom_llm_provider), ...]`` lookup candidates. + + LiteLLM's capability lookup is keyed by ``model`` + (optional) + ``custom_llm_provider``. Different config sources give us different + levels of detail, so we try the most-specific keys first and fall back + to bare model names so unannotated entries (e.g. an Azure deployment + pointing at ``gpt-5.4`` via ``litellm_params.base_model``) still hit the + map. Order matters — the first lookup that returns a definitive answer + wins for both helpers. + """ + candidates: list[tuple[str, str | None]] = [] + seen: set[tuple[str, str | None]] = set() + + def _add(model: str | None, llm_provider: str | None) -> None: + if not model: + return + key = (model, llm_provider) + if key in seen: + return + seen.add(key) + candidates.append(key) + + provider_prefix: str | None = None + if provider: + provider_prefix = _PROVIDER_PREFIX_MAP.get(provider.upper(), provider.lower()) + if custom_provider: + # ``custom_provider`` overrides everything for CUSTOM/proxy setups. + provider_prefix = custom_provider + + primary_model = base_model or model_name + bare_model = model_name + + # Most-specific first: provider-prefixed identifier with explicit + # custom_llm_provider so LiteLLM won't have to guess the provider via + # ``get_llm_provider``. + if primary_model and provider_prefix: + # e.g. "azure/gpt-5.4" + custom_llm_provider="azure" + if "/" in primary_model: + _add(primary_model, provider_prefix) + else: + _add(f"{provider_prefix}/{primary_model}", provider_prefix) + + # Bare base_model (or model_name) with provider hint — handles entries + # the upstream map keys without a provider prefix (most ``gpt-*`` and + # ``claude-*`` entries do this). + if primary_model: + _add(primary_model, provider_prefix) + + # Fallback to model_name when base_model differs (e.g. an Azure + # deployment whose model_name is the deployment id but base_model is the + # canonical OpenAI sku). 
+ if bare_model and bare_model != primary_model: + if provider_prefix and "/" not in bare_model: + _add(f"{provider_prefix}/{bare_model}", provider_prefix) + _add(bare_model, provider_prefix) + _add(bare_model, None) + + return candidates + + +def derive_supports_image_input( + *, + provider: str | None = None, + model_name: str | None = None, + base_model: str | None = None, + custom_provider: str | None = None, + openrouter_input_modalities: Iterable[str] | None = None, +) -> bool: + """Best-effort capability flag for the new-chat selector and catalog. + + Resolution order (first definitive answer wins): + + 1. ``openrouter_input_modalities`` (when provided as a non-empty + iterable). OpenRouter exposes ``architecture.input_modalities`` per + model and that's the authoritative source for OR dynamic configs. + 2. ``litellm.supports_vision`` against each candidate identifier from + :func:`_candidate_model_strings`. Returns True as soon as any + candidate confirms vision support. + 3. Default ``True`` — the conservative-allow stance. An unknown / + newly-added / third-party-hosted model is *not* pre-judged. The + streaming safety net (:func:`is_known_text_only_chat_model`) is the + only place a False ever blocks; everywhere else, a False here would + just hide a usable model from the user. + + Returns: + True if the model can plausibly accept image input, False only when + OpenRouter explicitly says it can't. + """ + if openrouter_input_modalities is not None: + modalities = list(openrouter_input_modalities) + if modalities: + return "image" in modalities + # Empty list explicitly published by OR — treat as "no image". + return False + + for model_string, custom_llm_provider in _candidate_model_strings( + provider=provider, + model_name=model_name, + base_model=base_model, + custom_provider=custom_provider, + ): + try: + if litellm.supports_vision( + model=model_string, custom_llm_provider=custom_llm_provider + ): + return True + except Exception as exc: + logger.debug( + "litellm.supports_vision raised for model=%s provider=%s: %s", + model_string, + custom_llm_provider, + exc, + ) + continue + + # Default-allow. ``is_known_text_only_chat_model`` is the strict gate. + return True + + +def is_known_text_only_chat_model( + *, + provider: str | None = None, + model_name: str | None = None, + base_model: str | None = None, + custom_provider: str | None = None, +) -> bool: + """Strict opt-out probe for the streaming-task safety net. + + Returns True only when LiteLLM's model map *explicitly* sets + ``supports_vision=False`` for at least one candidate identifier. Missing + key, lookup exception, or ``supports_vision=True`` all return False so + the streaming task lets the request through. This is the inverse-default + of :func:`derive_supports_image_input`. + + Why two functions + ----------------- + The selector wants "show me everything that's plausibly capable" — + default-allow. The safety net wants "block only when I'm certain it + can't" — default-pass. Mixing the two intents in a single function + leads to the regression we're fixing here. 
+ """ + for model_string, custom_llm_provider in _candidate_model_strings( + provider=provider, + model_name=model_name, + base_model=base_model, + custom_provider=custom_provider, + ): + try: + info = litellm.get_model_info( + model=model_string, custom_llm_provider=custom_llm_provider + ) + except Exception as exc: + logger.debug( + "litellm.get_model_info raised for model=%s provider=%s: %s", + model_string, + custom_llm_provider, + exc, + ) + continue + + # ``ModelInfo`` is a TypedDict (dict at runtime). ``supports_vision`` + # may be missing, None, True, or False. We only fire on explicit + # False — None / missing / True all mean "don't block". + try: + value = info.get("supports_vision") # type: ignore[union-attr] + except AttributeError: + value = None + if value is False: + return True + + return False + + +__all__ = [ + "derive_supports_image_input", + "is_known_text_only_chat_model", +] diff --git a/surfsense_backend/app/tasks/celery_tasks/__init__.py b/surfsense_backend/app/tasks/celery_tasks/__init__.py index 5b1f2cd13..b23359f36 100644 --- a/surfsense_backend/app/tasks/celery_tasks/__init__.py +++ b/surfsense_backend/app/tasks/celery_tasks/__init__.py @@ -1,10 +1,25 @@ -"""Celery tasks package.""" +"""Celery tasks package. + +Also hosts the small helpers every async celery task should use to +spin up its event loop. See :func:`run_async_celery_task` for the +canonical pattern. +""" + +from __future__ import annotations + +import asyncio +import contextlib +import logging +from collections.abc import Awaitable, Callable +from typing import TypeVar from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine from sqlalchemy.pool import NullPool from app.config import config +logger = logging.getLogger(__name__) + _celery_engine = None _celery_session_maker = None @@ -26,3 +41,86 @@ def get_celery_session_maker() -> async_sessionmaker: _celery_engine, expire_on_commit=False ) return _celery_session_maker + + +def _dispose_shared_db_engine(loop: asyncio.AbstractEventLoop) -> None: + """Drop the shared ``app.db.engine`` connection pool synchronously. + + The shared engine (used by ``shielded_async_session`` and most + routes / services) is a module-level singleton with a real pool. + Each celery task creates a fresh ``asyncio`` event loop; asyncpg + connections cache a reference to whichever loop opened them. When + a subsequent task's loop pulls a stale connection from the pool, + SQLAlchemy's ``pool_pre_ping`` checkout crashes with:: + + AttributeError: 'NoneType' object has no attribute 'send' + File ".../asyncio/proactor_events.py", line 402, in _loop_writing + self._write_fut = self._loop._proactor.send(self._sock, data) + + or hangs forever inside the asyncpg ``Connection._cancel`` cleanup + coroutine that can never run because its loop is gone. + + Disposing the engine forces the pool to drop every cached + connection so the next checkout opens a fresh one on the current + loop. Safe to call from a task's finally block; failure is logged + but never propagated. + """ + try: + from app.db import engine as shared_engine + + loop.run_until_complete(shared_engine.dispose()) + except Exception: + logger.warning("Shared DB engine dispose() failed", exc_info=True) + + +T = TypeVar("T") + + +def run_async_celery_task[T](coro_factory: Callable[[], Awaitable[T]]) -> T: + """Run an async coroutine inside a fresh event loop with proper + DB-engine cleanup. + + This is the canonical entry point for every async celery task. 
+ It performs three responsibilities that were previously copy-pasted + (incorrectly) across each task module: + + 1. Create a fresh ``asyncio`` loop and install it on the current + thread (celery's ``--pool=solo`` runs every task on the main + thread, but other pool types don't). + 2. Dispose the shared ``app.db.engine`` BEFORE the task runs so + any stale connections left over from a previous task's loop + are dropped — defends against tasks that crashed without + cleaning up. + 3. Dispose the shared engine AFTER the task runs so the + connections we opened on this loop are released before the + loop closes (avoids ``coroutine 'Connection._cancel' was + never awaited`` warnings and the next-task hang). + + Use as:: + + @celery_app.task(name="my_task", bind=True) + def my_task(self, *args): + return run_async_celery_task(lambda: _my_task_impl(*args)) + """ + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + # Defense-in-depth: prior task may have crashed before + # disposing. Idempotent — no-op if pool is already empty. + _dispose_shared_db_engine(loop) + return loop.run_until_complete(coro_factory()) + finally: + # Drop any connections this task opened so they don't leak + # into the next task's loop. + _dispose_shared_db_engine(loop) + with contextlib.suppress(Exception): + loop.run_until_complete(loop.shutdown_asyncgens()) + with contextlib.suppress(Exception): + asyncio.set_event_loop(None) + loop.close() + + +__all__ = [ + "get_celery_session_maker", + "run_async_celery_task", +] diff --git a/surfsense_backend/app/tasks/celery_tasks/connector_tasks.py b/surfsense_backend/app/tasks/celery_tasks/connector_tasks.py index fe1ac19d3..08d96cfa0 100644 --- a/surfsense_backend/app/tasks/celery_tasks/connector_tasks.py +++ b/surfsense_backend/app/tasks/celery_tasks/connector_tasks.py @@ -4,7 +4,7 @@ import logging import traceback from app.celery_app import celery_app -from app.tasks.celery_tasks import get_celery_session_maker +from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task logger = logging.getLogger(__name__) @@ -49,22 +49,15 @@ def index_notion_pages_task( end_date: str, ): """Celery task to index Notion pages.""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete( - _index_notion_pages( + return run_async_celery_task( + lambda: _index_notion_pages( connector_id, search_space_id, user_id, start_date, end_date ) ) except Exception as e: _handle_greenlet_error(e, "index_notion_pages", connector_id) raise - finally: - loop.close() async def _index_notion_pages( @@ -95,19 +88,11 @@ def index_github_repos_task( end_date: str, ): """Celery task to index GitHub repositories.""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _index_github_repos( - connector_id, search_space_id, user_id, start_date, end_date - ) + return run_async_celery_task( + lambda: _index_github_repos( + connector_id, search_space_id, user_id, start_date, end_date ) - finally: - loop.close() + ) async def _index_github_repos( @@ -138,19 +123,11 @@ def index_confluence_pages_task( end_date: str, ): """Celery task to index Confluence pages.""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _index_confluence_pages( - connector_id, search_space_id, user_id, start_date, end_date - ) + return run_async_celery_task( + lambda: _index_confluence_pages( + 
connector_id, search_space_id, user_id, start_date, end_date ) - finally: - loop.close() + ) async def _index_confluence_pages( @@ -181,22 +158,15 @@ def index_google_calendar_events_task( end_date: str, ): """Celery task to index Google Calendar events.""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete( - _index_google_calendar_events( + return run_async_celery_task( + lambda: _index_google_calendar_events( connector_id, search_space_id, user_id, start_date, end_date ) ) except Exception as e: _handle_greenlet_error(e, "index_google_calendar_events", connector_id) raise - finally: - loop.close() async def _index_google_calendar_events( @@ -227,19 +197,11 @@ def index_google_gmail_messages_task( end_date: str, ): """Celery task to index Google Gmail messages.""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _index_google_gmail_messages( - connector_id, search_space_id, user_id, start_date, end_date - ) + return run_async_celery_task( + lambda: _index_google_gmail_messages( + connector_id, search_space_id, user_id, start_date, end_date ) - finally: - loop.close() + ) async def _index_google_gmail_messages( @@ -269,22 +231,14 @@ def index_google_drive_files_task( items_dict: dict, # Dictionary with 'folders', 'files', and 'indexing_options' ): """Celery task to index Google Drive folders and files.""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _index_google_drive_files( - connector_id, - search_space_id, - user_id, - items_dict, - ) + return run_async_celery_task( + lambda: _index_google_drive_files( + connector_id, + search_space_id, + user_id, + items_dict, ) - finally: - loop.close() + ) async def _index_google_drive_files( @@ -317,22 +271,14 @@ def index_onedrive_files_task( items_dict: dict, ): """Celery task to index OneDrive folders and files.""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _index_onedrive_files( - connector_id, - search_space_id, - user_id, - items_dict, - ) + return run_async_celery_task( + lambda: _index_onedrive_files( + connector_id, + search_space_id, + user_id, + items_dict, ) - finally: - loop.close() + ) async def _index_onedrive_files( @@ -365,22 +311,14 @@ def index_dropbox_files_task( items_dict: dict, ): """Celery task to index Dropbox folders and files.""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _index_dropbox_files( - connector_id, - search_space_id, - user_id, - items_dict, - ) + return run_async_celery_task( + lambda: _index_dropbox_files( + connector_id, + search_space_id, + user_id, + items_dict, ) - finally: - loop.close() + ) async def _index_dropbox_files( @@ -414,19 +352,11 @@ def index_elasticsearch_documents_task( end_date: str, ): """Celery task to index Elasticsearch documents.""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _index_elasticsearch_documents( - connector_id, search_space_id, user_id, start_date, end_date - ) + return run_async_celery_task( + lambda: _index_elasticsearch_documents( + connector_id, search_space_id, user_id, start_date, end_date ) - finally: - loop.close() + ) async def _index_elasticsearch_documents( @@ -457,22 +387,15 @@ def index_crawled_urls_task( end_date: 
str, ): """Celery task to index Web page Urls.""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete( - _index_crawled_urls( + return run_async_celery_task( + lambda: _index_crawled_urls( connector_id, search_space_id, user_id, start_date, end_date ) ) except Exception as e: _handle_greenlet_error(e, "index_crawled_urls", connector_id) raise - finally: - loop.close() async def _index_crawled_urls( @@ -503,19 +426,11 @@ def index_bookstack_pages_task( end_date: str, ): """Celery task to index BookStack pages.""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _index_bookstack_pages( - connector_id, search_space_id, user_id, start_date, end_date - ) + return run_async_celery_task( + lambda: _index_bookstack_pages( + connector_id, search_space_id, user_id, start_date, end_date ) - finally: - loop.close() + ) async def _index_bookstack_pages( @@ -546,19 +461,11 @@ def index_composio_connector_task( end_date: str | None, ): """Celery task to index Composio connector content (Google Drive, Gmail, Calendar via Composio).""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _index_composio_connector( - connector_id, search_space_id, user_id, start_date, end_date - ) + return run_async_celery_task( + lambda: _index_composio_connector( + connector_id, search_space_id, user_id, start_date, end_date ) - finally: - loop.close() + ) async def _index_composio_connector( diff --git a/surfsense_backend/app/tasks/celery_tasks/document_reindex_tasks.py b/surfsense_backend/app/tasks/celery_tasks/document_reindex_tasks.py index c2dbe7700..5d6bde6c1 100644 --- a/surfsense_backend/app/tasks/celery_tasks/document_reindex_tasks.py +++ b/surfsense_backend/app/tasks/celery_tasks/document_reindex_tasks.py @@ -11,7 +11,7 @@ from app.db import Document from app.indexing_pipeline.adapters.file_upload_adapter import UploadDocumentAdapter from app.services.llm_service import get_user_long_context_llm from app.services.task_logging_service import TaskLoggingService -from app.tasks.celery_tasks import get_celery_session_maker +from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task logger = logging.getLogger(__name__) @@ -25,15 +25,7 @@ def reindex_document_task(self, document_id: int, user_id: str): document_id: ID of document to reindex user_id: ID of user who edited the document """ - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete(_reindex_document(document_id, user_id)) - finally: - loop.close() + return run_async_celery_task(lambda: _reindex_document(document_id, user_id)) async def _reindex_document(document_id: int, user_id: str): diff --git a/surfsense_backend/app/tasks/celery_tasks/document_tasks.py b/surfsense_backend/app/tasks/celery_tasks/document_tasks.py index 9d12f91f6..c78e376bd 100644 --- a/surfsense_backend/app/tasks/celery_tasks/document_tasks.py +++ b/surfsense_backend/app/tasks/celery_tasks/document_tasks.py @@ -11,7 +11,7 @@ from app.celery_app import celery_app from app.config import config from app.services.notification_service import NotificationService from app.services.task_logging_service import TaskLoggingService -from app.tasks.celery_tasks import get_celery_session_maker +from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task from 
app.tasks.connector_indexers.local_folder_indexer import ( index_local_folder, index_uploaded_files, @@ -105,12 +105,7 @@ async def _run_heartbeat_loop(notification_id: int): ) def delete_document_task(self, document_id: int): """Celery task to delete a document and its chunks in batches.""" - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete(_delete_document_background(document_id)) - finally: - loop.close() + return run_async_celery_task(lambda: _delete_document_background(document_id)) async def _delete_document_background(document_id: int) -> None: @@ -153,14 +148,9 @@ def delete_folder_documents_task( folder_subtree_ids: list[int] | None = None, ): """Celery task to delete documents first, then the folder rows.""" - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete( - _delete_folder_documents(document_ids, folder_subtree_ids) - ) - finally: - loop.close() + return run_async_celery_task( + lambda: _delete_folder_documents(document_ids, folder_subtree_ids) + ) async def _delete_folder_documents( @@ -209,12 +199,9 @@ async def _delete_folder_documents( ) def delete_search_space_task(self, search_space_id: int): """Celery task to delete a search space and heavy child rows in batches.""" - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete(_delete_search_space_background(search_space_id)) - finally: - loop.close() + return run_async_celery_task( + lambda: _delete_search_space_background(search_space_id) + ) async def _delete_search_space_background(search_space_id: int) -> None: @@ -269,18 +256,11 @@ def process_extension_document_task( search_space_id: ID of the search space user_id: ID of the user """ - # Create a new event loop for this task - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _process_extension_document( - individual_document_dict, search_space_id, user_id - ) + return run_async_celery_task( + lambda: _process_extension_document( + individual_document_dict, search_space_id, user_id ) - finally: - loop.close() + ) async def _process_extension_document( @@ -419,13 +399,9 @@ def process_youtube_video_task(self, url: str, search_space_id: int, user_id: st search_space_id: ID of the search space user_id: ID of the user """ - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete(_process_youtube_video(url, search_space_id, user_id)) - finally: - loop.close() + return run_async_celery_task( + lambda: _process_youtube_video(url, search_space_id, user_id) + ) async def _process_youtube_video(url: str, search_space_id: int, user_id: str): @@ -573,12 +549,9 @@ def process_file_upload_task( except Exception as e: logger.warning(f"[process_file_upload] Could not get file size: {e}") - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete( - _process_file_upload(file_path, filename, search_space_id, user_id) + run_async_celery_task( + lambda: _process_file_upload(file_path, filename, search_space_id, user_id) ) logger.info( f"[process_file_upload] Task completed successfully for: {filename}" @@ -589,8 +562,6 @@ def process_file_upload_task( f"Traceback:\n{traceback.format_exc()}" ) raise - finally: - loop.close() async def _process_file_upload( @@ -811,25 +782,17 @@ def process_file_upload_with_document_task( "File may have been removed before syncing could start." 
) # Mark document as failed since file is missing - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete( - _mark_document_failed( - document_id, - "File not found. Please re-upload the file.", - ) + run_async_celery_task( + lambda: _mark_document_failed( + document_id, + "File not found. Please re-upload the file.", ) - finally: - loop.close() + ) return - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete( - _process_file_with_document( + run_async_celery_task( + lambda: _process_file_with_document( document_id, temp_path, filename, @@ -849,8 +812,6 @@ def process_file_upload_with_document_task( f"Traceback:\n{traceback.format_exc()}" ) raise - finally: - loop.close() async def _mark_document_failed(document_id: int, reason: str): @@ -1119,22 +1080,16 @@ def process_circleback_meeting_task( search_space_id: ID of the search space connector_id: ID of the Circleback connector (for deletion support) """ - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _process_circleback_meeting( - meeting_id, - meeting_name, - markdown_content, - metadata, - search_space_id, - connector_id, - ) + return run_async_celery_task( + lambda: _process_circleback_meeting( + meeting_id, + meeting_name, + markdown_content, + metadata, + search_space_id, + connector_id, ) - finally: - loop.close() + ) async def _process_circleback_meeting( @@ -1291,25 +1246,19 @@ def index_local_folder_task( target_file_paths: list[str] | None = None, ): """Celery task to index a local folder. Config is passed directly — no connector row.""" - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _index_local_folder_async( - search_space_id=search_space_id, - user_id=user_id, - folder_path=folder_path, - folder_name=folder_name, - exclude_patterns=exclude_patterns, - file_extensions=file_extensions, - root_folder_id=root_folder_id, - enable_summary=enable_summary, - target_file_paths=target_file_paths, - ) + return run_async_celery_task( + lambda: _index_local_folder_async( + search_space_id=search_space_id, + user_id=user_id, + folder_path=folder_path, + folder_name=folder_name, + exclude_patterns=exclude_patterns, + file_extensions=file_extensions, + root_folder_id=root_folder_id, + enable_summary=enable_summary, + target_file_paths=target_file_paths, ) - finally: - loop.close() + ) async def _index_local_folder_async( @@ -1441,23 +1390,18 @@ def index_uploaded_folder_files_task( processing_mode: str = "basic", ): """Celery task to index files uploaded from the desktop app.""" - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete( - _index_uploaded_folder_files_async( - search_space_id=search_space_id, - user_id=user_id, - folder_name=folder_name, - root_folder_id=root_folder_id, - enable_summary=enable_summary, - file_mappings=file_mappings, - use_vision_llm=use_vision_llm, - processing_mode=processing_mode, - ) + return run_async_celery_task( + lambda: _index_uploaded_folder_files_async( + search_space_id=search_space_id, + user_id=user_id, + folder_name=folder_name, + root_folder_id=root_folder_id, + enable_summary=enable_summary, + file_mappings=file_mappings, + use_vision_llm=use_vision_llm, + processing_mode=processing_mode, ) - finally: - loop.close() + ) async def _index_uploaded_folder_files_async( @@ -1584,12 +1528,9 @@ def _ai_sort_lock_key(search_space_id: int) -> str: 
@celery_app.task(name="ai_sort_search_space", bind=True, max_retries=1) def ai_sort_search_space_task(self, search_space_id: int, user_id: str): """Full AI sort for all documents in a search space.""" - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete(_ai_sort_search_space_async(search_space_id, user_id)) - finally: - loop.close() + return run_async_celery_task( + lambda: _ai_sort_search_space_async(search_space_id, user_id) + ) async def _ai_sort_search_space_async(search_space_id: int, user_id: str): @@ -1639,14 +1580,9 @@ async def _ai_sort_search_space_async(search_space_id: int, user_id: str): ) def ai_sort_document_task(self, search_space_id: int, user_id: str, document_id: int): """Incremental AI sort for a single document after indexing.""" - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete( - _ai_sort_document_async(search_space_id, user_id, document_id) - ) - finally: - loop.close() + return run_async_celery_task( + lambda: _ai_sort_document_async(search_space_id, user_id, document_id) + ) async def _ai_sort_document_async(search_space_id: int, user_id: str, document_id: int): diff --git a/surfsense_backend/app/tasks/celery_tasks/obsidian_tasks.py b/surfsense_backend/app/tasks/celery_tasks/obsidian_tasks.py index 98b107af3..c6c8666f5 100644 --- a/surfsense_backend/app/tasks/celery_tasks/obsidian_tasks.py +++ b/surfsense_backend/app/tasks/celery_tasks/obsidian_tasks.py @@ -2,14 +2,13 @@ from __future__ import annotations -import asyncio import logging from app.celery_app import celery_app from app.db import SearchSourceConnector from app.schemas.obsidian_plugin import NotePayload from app.services.obsidian_plugin_indexer import upsert_note -from app.tasks.celery_tasks import get_celery_session_maker +from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task logger = logging.getLogger(__name__) @@ -22,18 +21,13 @@ def index_obsidian_attachment_task( user_id: str, ) -> None: """Process one Obsidian non-markdown attachment asynchronously.""" - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete( - _index_obsidian_attachment( - connector_id=connector_id, - payload_data=payload_data, - user_id=user_id, - ) + return run_async_celery_task( + lambda: _index_obsidian_attachment( + connector_id=connector_id, + payload_data=payload_data, + user_id=user_id, ) - finally: - loop.close() + ) async def _index_obsidian_attachment( diff --git a/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py b/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py index 937877473..8b311576e 100644 --- a/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py +++ b/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py @@ -3,6 +3,7 @@ import asyncio import logging import sys +from contextlib import asynccontextmanager from sqlalchemy import select @@ -12,11 +13,12 @@ from app.celery_app import celery_app from app.config import config as app_config from app.db import Podcast, PodcastStatus from app.services.billable_calls import ( + BillingSettlementError, QuotaInsufficientError, _resolve_agent_billing_for_search_space, billable_call, ) -from app.tasks.celery_tasks import get_celery_session_maker +from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task logger = logging.getLogger(__name__) @@ -34,6 +36,13 @@ if sys.platform.startswith("win"): # 
============================================================================= +@asynccontextmanager +async def _celery_billable_session(): + """Session factory used by billable_call inside the Celery worker loop.""" + async with get_celery_session_maker()() as session: + yield session + + @celery_app.task(name="generate_content_podcast", bind=True) def generate_content_podcast_task( self, @@ -46,27 +55,22 @@ def generate_content_podcast_task( Celery task to generate podcast from source content. Updates existing podcast record created by the tool. """ - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - result = loop.run_until_complete( - _generate_content_podcast( + return run_async_celery_task( + lambda: _generate_content_podcast( podcast_id, source_content, search_space_id, user_prompt, ) ) - loop.run_until_complete(loop.shutdown_asyncgens()) - return result except Exception as e: logger.error(f"Error generating content podcast: {e!s}") - loop.run_until_complete(_mark_podcast_failed(podcast_id)) + try: + run_async_celery_task(lambda: _mark_podcast_failed(podcast_id)) + except Exception: + logger.exception("Failed to mark podcast %s as failed", podcast_id) return {"status": "failed", "podcast_id": podcast_id} - finally: - asyncio.set_event_loop(None) - loop.close() async def _mark_podcast_failed(podcast_id: int) -> None: @@ -148,11 +152,12 @@ async def _generate_content_podcast( base_model=base_model, quota_reserve_micros_override=app_config.QUOTA_DEFAULT_PODCAST_RESERVE_MICROS, usage_type="podcast_generation", - thread_id=podcast.thread_id, call_details={ "podcast_id": podcast.id, "title": podcast.title, + "thread_id": podcast.thread_id, }, + billable_session_factory=_celery_billable_session, ): graph_result = await podcaster_graph.ainvoke( initial_state, config=graph_config @@ -173,6 +178,18 @@ async def _generate_content_podcast( "podcast_id": podcast.id, "reason": "premium_quota_exhausted", } + except BillingSettlementError: + logger.exception( + "Podcast %s: premium billing settlement failed", + podcast.id, + ) + podcast.status = PodcastStatus.FAILED + await session.commit() + return { + "status": "failed", + "podcast_id": podcast.id, + "reason": "billing_settlement_failed", + } podcast_transcript = graph_result.get("podcast_transcript", []) file_path = graph_result.get("final_podcast_file_path", "") @@ -194,7 +211,14 @@ async def _generate_content_podcast( podcast.podcast_transcript = serializable_transcript podcast.file_location = file_path podcast.status = PodcastStatus.READY + logger.info( + "Podcast %s: committing READY transcript_entries=%d file=%s", + podcast.id, + len(serializable_transcript), + file_path, + ) await session.commit() + logger.info("Podcast %s: READY commit complete", podcast.id) logger.info(f"Successfully generated podcast: {podcast.id}") diff --git a/surfsense_backend/app/tasks/celery_tasks/schedule_checker_task.py b/surfsense_backend/app/tasks/celery_tasks/schedule_checker_task.py index 373f04b48..e41251407 100644 --- a/surfsense_backend/app/tasks/celery_tasks/schedule_checker_task.py +++ b/surfsense_backend/app/tasks/celery_tasks/schedule_checker_task.py @@ -7,7 +7,7 @@ from sqlalchemy.future import select from app.celery_app import celery_app from app.db import Notification, SearchSourceConnector, SearchSourceConnectorType -from app.tasks.celery_tasks import get_celery_session_maker +from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task from app.utils.indexing_locks import is_connector_indexing_locked 
logger = logging.getLogger(__name__) @@ -20,15 +20,7 @@ def check_periodic_schedules_task(): This task runs every minute and triggers indexing for any connector whose next_scheduled_at time has passed. """ - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete(_check_and_trigger_schedules()) - finally: - loop.close() + return run_async_celery_task(_check_and_trigger_schedules) async def _check_and_trigger_schedules(): diff --git a/surfsense_backend/app/tasks/celery_tasks/stale_notification_cleanup_task.py b/surfsense_backend/app/tasks/celery_tasks/stale_notification_cleanup_task.py index e05ae9435..d51c85dee 100644 --- a/surfsense_backend/app/tasks/celery_tasks/stale_notification_cleanup_task.py +++ b/surfsense_backend/app/tasks/celery_tasks/stale_notification_cleanup_task.py @@ -34,7 +34,7 @@ from sqlalchemy.future import select from app.celery_app import celery_app from app.config import config from app.db import Document, DocumentStatus, Notification -from app.tasks.celery_tasks import get_celery_session_maker +from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task logger = logging.getLogger(__name__) @@ -69,16 +69,12 @@ def cleanup_stale_indexing_notifications_task(): Detection: Redis heartbeat key with 2-min TTL. Missing key = stale task. Also marks associated pending/processing documents as failed. """ - import asyncio - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) + async def _both() -> None: + await _cleanup_stale_notifications() + await _cleanup_stale_document_processing_notifications() - try: - loop.run_until_complete(_cleanup_stale_notifications()) - loop.run_until_complete(_cleanup_stale_document_processing_notifications()) - finally: - loop.close() + return run_async_celery_task(_both) async def _cleanup_stale_notifications(): diff --git a/surfsense_backend/app/tasks/celery_tasks/stripe_reconciliation_task.py b/surfsense_backend/app/tasks/celery_tasks/stripe_reconciliation_task.py index 3aee1a360..ace6ef7ca 100644 --- a/surfsense_backend/app/tasks/celery_tasks/stripe_reconciliation_task.py +++ b/surfsense_backend/app/tasks/celery_tasks/stripe_reconciliation_task.py @@ -2,7 +2,6 @@ from __future__ import annotations -import asyncio import logging from datetime import UTC, datetime, timedelta @@ -18,7 +17,7 @@ from app.db import ( PremiumTokenPurchaseStatus, ) from app.routes import stripe_routes -from app.tasks.celery_tasks import get_celery_session_maker +from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task logger = logging.getLogger(__name__) @@ -36,13 +35,7 @@ def get_stripe_client() -> StripeClient | None: @celery_app.task(name="reconcile_pending_stripe_page_purchases") def reconcile_pending_stripe_page_purchases_task(): """Recover paid purchases that were left pending due to missed webhook handling.""" - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete(_reconcile_pending_page_purchases()) - finally: - loop.close() + return run_async_celery_task(_reconcile_pending_page_purchases) async def _reconcile_pending_page_purchases() -> None: @@ -141,13 +134,7 @@ async def _reconcile_pending_page_purchases() -> None: @celery_app.task(name="reconcile_pending_stripe_token_purchases") def reconcile_pending_stripe_token_purchases_task(): """Recover paid token purchases that were left pending due to missed webhook handling.""" - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - 
loop.run_until_complete(_reconcile_pending_token_purchases()) - finally: - loop.close() + return run_async_celery_task(_reconcile_pending_token_purchases) async def _reconcile_pending_token_purchases() -> None: diff --git a/surfsense_backend/app/tasks/celery_tasks/video_presentation_tasks.py b/surfsense_backend/app/tasks/celery_tasks/video_presentation_tasks.py index 4f0c427d9..08f22140c 100644 --- a/surfsense_backend/app/tasks/celery_tasks/video_presentation_tasks.py +++ b/surfsense_backend/app/tasks/celery_tasks/video_presentation_tasks.py @@ -3,6 +3,7 @@ import asyncio import logging import sys +from contextlib import asynccontextmanager from sqlalchemy import select @@ -12,11 +13,12 @@ from app.celery_app import celery_app from app.config import config as app_config from app.db import VideoPresentation, VideoPresentationStatus from app.services.billable_calls import ( + BillingSettlementError, QuotaInsufficientError, _resolve_agent_billing_for_search_space, billable_call, ) -from app.tasks.celery_tasks import get_celery_session_maker +from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task logger = logging.getLogger(__name__) @@ -29,6 +31,13 @@ if sys.platform.startswith("win"): ) +@asynccontextmanager +async def _celery_billable_session(): + """Session factory used by billable_call inside the Celery worker loop.""" + async with get_celery_session_maker()() as session: + yield session + + @celery_app.task(name="generate_video_presentation", bind=True) def generate_video_presentation_task( self, @@ -41,27 +50,30 @@ def generate_video_presentation_task( Celery task to generate video presentation from source content. Updates existing video presentation record created by the tool. """ - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - result = loop.run_until_complete( - _generate_video_presentation( + return run_async_celery_task( + lambda: _generate_video_presentation( video_presentation_id, source_content, search_space_id, user_prompt, ) ) - loop.run_until_complete(loop.shutdown_asyncgens()) - return result except Exception as e: logger.error(f"Error generating video presentation: {e!s}") - loop.run_until_complete(_mark_video_presentation_failed(video_presentation_id)) + # Mark FAILED in a fresh loop — the previous loop is closed. + # Swallow secondary failures; the row will simply stay in + # GENERATING and be flushed by the periodic stale cleanup. 
+ try: + run_async_celery_task( + lambda: _mark_video_presentation_failed(video_presentation_id) + ) + except Exception: + logger.exception( + "Failed to mark video presentation %s as failed", + video_presentation_id, + ) return {"status": "failed", "video_presentation_id": video_presentation_id} - finally: - asyncio.set_event_loop(None) - loop.close() async def _mark_video_presentation_failed(video_presentation_id: int) -> None: @@ -150,11 +162,12 @@ async def _generate_video_presentation( base_model=base_model, quota_reserve_micros_override=app_config.QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS, usage_type="video_presentation_generation", - thread_id=video_pres.thread_id, call_details={ "video_presentation_id": video_pres.id, "title": video_pres.title, + "thread_id": video_pres.thread_id, }, + billable_session_factory=_celery_billable_session, ): graph_result = await video_presentation_graph.ainvoke( initial_state, config=graph_config @@ -175,6 +188,18 @@ async def _generate_video_presentation( "video_presentation_id": video_pres.id, "reason": "premium_quota_exhausted", } + except BillingSettlementError: + logger.exception( + "VideoPresentation %s: premium billing settlement failed", + video_pres.id, + ) + video_pres.status = VideoPresentationStatus.FAILED + await session.commit() + return { + "status": "failed", + "video_presentation_id": video_pres.id, + "reason": "billing_settlement_failed", + } # Serialize slides (parsed content + audio info merged) slides_raw = graph_result.get("slides", []) @@ -205,7 +230,14 @@ async def _generate_video_presentation( video_pres.slides = serializable_slides video_pres.scene_codes = serializable_scene_codes video_pres.status = VideoPresentationStatus.READY + logger.info( + "VideoPresentation %s: committing READY slides=%d scene_codes=%d", + video_pres.id, + len(serializable_slides), + len(serializable_scene_codes), + ) await session.commit() + logger.info("VideoPresentation %s: READY commit complete", video_pres.id) logger.info(f"Successfully generated video presentation: {video_pres.id}") diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index 31c0d7d6d..c6ac3311a 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -1506,10 +1506,10 @@ async def _stream_agent_events( if isinstance(tool_output, dict) else "Podcast" ) - if podcast_status == "processing": + if podcast_status in ("pending", "generating", "processing"): completed_items = [ f"Title: {podcast_title}", - "Audio generation started", + "Podcast generation started", "Processing in background...", ] elif podcast_status == "already_generating": @@ -1518,7 +1518,7 @@ async def _stream_agent_events( "Podcast already in progress", "Please wait for it to complete", ] - elif podcast_status == "error": + elif podcast_status in ("failed", "error"): error_msg = ( tool_output.get("error", "Unknown error") if isinstance(tool_output, dict) @@ -1528,6 +1528,11 @@ async def _stream_agent_events( f"Title: {podcast_title}", f"Error: {error_msg[:50]}", ] + elif podcast_status in ("ready", "success"): + completed_items = [ + f"Title: {podcast_title}", + "Podcast ready", + ] else: completed_items = last_active_step_items yield streaming_service.format_thinking_step( @@ -1710,20 +1715,28 @@ async def _stream_agent_events( if isinstance(tool_output, dict) else {"result": tool_output}, ) - if ( - isinstance(tool_output, dict) - and tool_output.get("status") == "success" + 
if isinstance(tool_output, dict) and tool_output.get("status") in ( + "pending", + "generating", + "processing", + ): + yield streaming_service.format_terminal_info( + f"Podcast queued: {tool_output.get('title', 'Podcast')}", + "success", + ) + elif isinstance(tool_output, dict) and tool_output.get("status") in ( + "ready", + "success", ): yield streaming_service.format_terminal_info( f"Podcast generated successfully: {tool_output.get('title', 'Podcast')}", "success", ) - else: - error_msg = ( - tool_output.get("error", "Unknown error") - if isinstance(tool_output, dict) - else "Unknown error" - ) + elif isinstance(tool_output, dict) and tool_output.get("status") in ( + "failed", + "error", + ): + error_msg = tool_output.get("error", "Unknown error") yield streaming_service.format_terminal_info( f"Podcast generation failed: {error_msg}", "error", @@ -2292,6 +2305,11 @@ async def stream_new_chat( ) _t0 = time.perf_counter() + # Image-bearing turns force the Auto-pin resolver to filter the + # candidate pool to vision-capable cfgs (and force-repin a + # text-only existing pin). For explicit selections this flag is + # a no-op — the resolver returns the user's chosen id unchanged. + _requires_image_input = bool(user_image_data_urls) try: llm_config_id = ( await resolve_or_get_pinned_llm_config_id( @@ -2300,13 +2318,29 @@ async def stream_new_chat( search_space_id=search_space_id, user_id=user_id, selected_llm_config_id=llm_config_id, + requires_image_input=_requires_image_input, ) ).resolved_llm_config_id except ValueError as pin_error: + # Auto-pin's "no vision-capable cfg" path raises a ValueError + # whose message we map to the friendly image-input SSE error + # so the user sees the same message regardless of whether + # the gate fired in Auto-mode or in the agent_config check + # below. + error_code = ( + "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT" + if _requires_image_input and "vision-capable" in str(pin_error) + else "SERVER_ERROR" + ) + error_kind = ( + "user_error" + if error_code == "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT" + else "server_error" + ) yield _emit_stream_error( message=str(pin_error), - error_kind="server_error", - error_code="SERVER_ERROR", + error_kind=error_kind, + error_code=error_code, ) yield streaming_service.format_done() return @@ -2326,6 +2360,50 @@ async def stream_new_chat( llm_config_id, ) + # Capability safety net: a turn carrying user-uploaded images + # cannot be routed to a chat config that LiteLLM's authoritative + # model map *explicitly* marks as text-only (``supports_vision`` + # set to False). The check is intentionally narrow — it only + # fires when LiteLLM is *certain* the model can't accept image + # input. Unknown / unmapped / vision-capable models pass + # through. Without this guard a known-text-only model would 404 + # at the provider with ``"No endpoints found that support image + # input"``, surfacing as an opaque ``SERVER_ERROR`` SSE chunk; + # failing here lets us return a friendly message that tells the + # user what to change. 
+ if user_image_data_urls and agent_config is not None: + from app.services.provider_capabilities import ( + is_known_text_only_chat_model, + ) + + agent_litellm_params = agent_config.litellm_params or {} + agent_base_model = ( + agent_litellm_params.get("base_model") + if isinstance(agent_litellm_params, dict) + else None + ) + if is_known_text_only_chat_model( + provider=agent_config.provider, + model_name=agent_config.model_name, + base_model=agent_base_model, + custom_provider=agent_config.custom_provider, + ): + model_label = ( + agent_config.config_name or agent_config.model_name or "model" + ) + yield _emit_stream_error( + message=( + f"The selected model ({model_label}) does not support " + "image input. Switch to a vision-capable model " + "(e.g. GPT-4o, Claude, Gemini) or remove the image " + "attachment and try again." + ), + error_kind="user_error", + error_code="MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT", + ) + yield streaming_service.format_done() + return + # Premium quota reservation for pinned premium model only. _needs_premium_quota = ( agent_config is not None and user_id and agent_config.is_premium @@ -2366,6 +2444,7 @@ async def stream_new_chat( user_id=user_id, selected_llm_config_id=0, force_repin_free=True, + requires_image_input=_requires_image_input, ) ).resolved_llm_config_id except ValueError as pin_error: @@ -2470,6 +2549,7 @@ async def stream_new_chat( user_id=user_id, selected_llm_config_id=0, exclude_config_ids={previous_config_id}, + requires_image_input=_requires_image_input, ) ).resolved_llm_config_id except ValueError as pin_error: @@ -2804,6 +2884,7 @@ async def stream_new_chat( from litellm import acompletion from app.services.llm_router_service import LLMRouterService + from app.services.provider_api_base import resolve_api_base from app.services.token_tracking_service import _turn_accumulator _turn_accumulator.set(None) @@ -2824,11 +2905,32 @@ async def stream_new_chat( model="auto", messages=messages ) else: + # Apply the same ``api_base`` cascade chat / vision / + # image-gen call sites use so we never inherit + # ``litellm.api_base`` (commonly set by + # ``AZURE_OPENAI_ENDPOINT``) when the chat config + # itself ships an empty ``api_base``. Without this + # the title-gen on an OpenRouter chat config would + # 404 against the inherited Azure endpoint — see + # ``provider_api_base`` docstring for the same + # bug repro on the image-gen / vision paths. 
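                    # Concrete instance of the cascade above (illustrative,
                    # assuming the fallback order sketched in the comment):
                    # an OpenRouter chat config whose YAML ``api_base`` is
                    # empty is expected to resolve to the OpenRouter default
                    # endpoint ("https://openrouter.ai/api/v1" in
                    # PROVIDER_DEFAULT_API_BASE) rather than inheriting
                    # ``litellm.api_base`` set via AZURE_OPENAI_ENDPOINT,
                    # which is what produced the title-generation 404
                    # described above.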
+ raw_model = getattr(llm, "model", "") or "" + provider_prefix = ( + raw_model.split("/", 1)[0] if "/" in raw_model else None + ) + provider_value = ( + agent_config.provider if agent_config is not None else None + ) + title_api_base = resolve_api_base( + provider=provider_value, + provider_prefix=provider_prefix, + config_api_base=getattr(llm, "api_base", None), + ) response = await acompletion( - model=llm.model, + model=raw_model, messages=messages, api_key=getattr(llm, "api_key", None), - api_base=getattr(llm, "api_base", None), + api_base=title_api_base, ) usage_info = None @@ -2953,6 +3055,7 @@ async def stream_new_chat( user_id=user_id, selected_llm_config_id=0, exclude_config_ids={previous_config_id}, + requires_image_input=_requires_image_input, ) ).resolved_llm_config_id diff --git a/surfsense_backend/scripts/verify_chat_image_capability.py b/surfsense_backend/scripts/verify_chat_image_capability.py new file mode 100644 index 000000000..a49d4eab2 --- /dev/null +++ b/surfsense_backend/scripts/verify_chat_image_capability.py @@ -0,0 +1,558 @@ +"""End-to-end smoke test for vision / image config wiring. + +Loads the live ``global_llm_config.yaml`` (no mocking, no fixtures) and +exercises every chat / vision / image-generation config + the OpenRouter +dynamic catalog. For each config the script: + +1. Reports the resolver classification (catalog-allow vs strict-block). +2. Optionally fires a tiny live API call against the provider: + - Chat configs: ``litellm.acompletion`` with a 1x1 PNG and the prompt + ``"reply with one word: ok"``. + - Vision configs: same, against the dedicated vision router pool. + - Image-gen configs: ``litellm.aimage_generation`` with a single tiny + prompt and ``n=1``. + - OpenRouter integration: samples one chat, one vision, one image-gen + model from the dynamically fetched catalog. + +Usage:: + + python -m scripts.verify_chat_image_capability # capability + connectivity + python -m scripts.verify_chat_image_capability --no-live # capability resolver only + +The script is meant to be runnable from the repository root or from +``surfsense_backend/`` and prints a short PASS/FAIL/SKIP summary at the +end so it's usable as a CI smoke check too. + +Live-mode caveat: each successful call costs a small amount of provider +credit (a few tokens or one tiny generated image per config). The +default size for image generation is ``1024x1024`` because Azure +GPT-image deployments reject smaller sizes; OpenRouter image-gen models +generally accept the same size. +""" + +from __future__ import annotations + +import argparse +import asyncio +import logging +import os +import sys +import time +from dataclasses import dataclass, field +from typing import Any + +# Bootstrap the surfsense_backend package on sys.path so the script runs +# from the repo root or from `surfsense_backend/` interchangeably. 
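# Illustrative invocation for a low-cost CI check, combining the
# resolver-only mode with the skip flags defined in ``_parse_args`` below:
#
#   python -m scripts.verify_chat_image_capability --no-live --skip-openrouter
#
# This runs the capability classification for the YAML chat / vision configs
# without firing any live completion or image-generation calls.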
+_HERE = os.path.dirname(os.path.abspath(__file__)) +_BACKEND_ROOT = os.path.dirname(_HERE) +if _BACKEND_ROOT not in sys.path: + sys.path.insert(0, _BACKEND_ROOT) + +import litellm # noqa: E402 + +from app.config import config # noqa: E402 +from app.services.openrouter_integration_service import ( # noqa: E402 + _OPENROUTER_DYNAMIC_MARKER, + OpenRouterIntegrationService, +) +from app.services.provider_api_base import resolve_api_base # noqa: E402 +from app.services.provider_capabilities import ( # noqa: E402 + derive_supports_image_input, + is_known_text_only_chat_model, +) + +logging.basicConfig( + level=logging.WARNING, + format="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s", +) +# Quiet down LiteLLM's verbose router/cost logs so the script output is +# scannable. +logging.getLogger("LiteLLM").setLevel(logging.ERROR) +logging.getLogger("litellm").setLevel(logging.ERROR) +logging.getLogger("httpx").setLevel(logging.ERROR) + +# 1x1 transparent PNG — used as the cheapest possible vision payload. +_TINY_PNG_B64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==" +_TINY_PNG_DATA_URL = f"data:image/png;base64,{_TINY_PNG_B64}" + + +# --------------------------------------------------------------------------- +# Result accounting +# --------------------------------------------------------------------------- + + +@dataclass +class ProbeResult: + label: str + surface: str + config_id: int | str + capability_ok: bool | None = None + capability_note: str = "" + live_ok: bool | None = None + live_note: str = "" + duration_s: float = 0.0 + + +@dataclass +class Report: + results: list[ProbeResult] = field(default_factory=list) + + def add(self, r: ProbeResult) -> None: + self.results.append(r) + + def render(self) -> int: + passed = failed = skipped = 0 + print() + print("=" * 92) + print( + f"{'Surface':<14}{'ID':>8} {'Cap':>5} {'Live':>5} {'Time':>6} Label / notes" + ) + print("-" * 92) + for r in self.results: + + def _flag(value: bool | None) -> str: + if value is None: + return "skip" + return "ok" if value else "fail" + + cap = _flag(r.capability_ok) + live = _flag(r.live_ok) + if r.capability_ok is False or r.live_ok is False: + failed += 1 + elif r.capability_ok is None and r.live_ok is None: + skipped += 1 + else: + passed += 1 + print( + f"{r.surface:<14}{r.config_id!s:>8} {cap:>5} {live:>5} " + f"{r.duration_s:>5.2f}s {r.label}" + ) + if r.capability_note: + print(f" cap: {r.capability_note}") + if r.live_note: + print(f" live: {r.live_note}") + print("-" * 92) + print( + f"Total: {passed} ok / {failed} fail / {skipped} skip " + f"(of {len(self.results)} probes)" + ) + print("=" * 92) + return failed + + +# --------------------------------------------------------------------------- +# Capability probes (no network) +# --------------------------------------------------------------------------- + + +def _probe_chat_capability(cfg: dict) -> tuple[bool, str]: + """For chat configs the catalog flag is *expected* True (vision-capable + pool). 
The probe reports both the resolver value and the strict + safety-net value to surface any drift between them.""" + litellm_params = cfg.get("litellm_params") or {} + base_model = ( + litellm_params.get("base_model") if isinstance(litellm_params, dict) else None + ) + cap = derive_supports_image_input( + provider=cfg.get("provider"), + model_name=cfg.get("model_name"), + base_model=base_model, + custom_provider=cfg.get("custom_provider"), + ) + block = is_known_text_only_chat_model( + provider=cfg.get("provider"), + model_name=cfg.get("model_name"), + base_model=base_model, + custom_provider=cfg.get("custom_provider"), + ) + note = f"derive={cap} strict_block={block}" + if not cap and not block: + # Resolver said False but strict gate is also False — that means + # OR modalities published [text] explicitly. Surface it. + note += " (OR modality says text-only)" + # We accept a True derive *or* (False derive AND False block) as + # 'capability ok' — either way, the streaming task will flow through. + ok = cap or not block + return ok, note + + +def _build_chat_model_string(cfg: dict) -> str: + if cfg.get("custom_provider"): + return f"{cfg['custom_provider']}/{cfg['model_name']}" + from app.services.provider_capabilities import _PROVIDER_PREFIX_MAP + + prefix = _PROVIDER_PREFIX_MAP.get( + (cfg.get("provider") or "").upper(), (cfg.get("provider") or "").lower() + ) + return f"{prefix}/{cfg['model_name']}" + + +# --------------------------------------------------------------------------- +# Live probes (network calls) +# --------------------------------------------------------------------------- + + +async def _live_chat_image_call(cfg: dict) -> tuple[bool, str]: + """Send a 1x1 PNG + `reply with one word: ok` to the chat config.""" + model_string = _build_chat_model_string(cfg) + api_base = resolve_api_base( + provider=cfg.get("provider"), + provider_prefix=model_string.split("/", 1)[0], + config_api_base=cfg.get("api_base") or None, + ) + kwargs: dict[str, Any] = { + "model": model_string, + "api_key": cfg.get("api_key"), + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "reply with one word: ok"}, + { + "type": "image_url", + "image_url": {"url": _TINY_PNG_DATA_URL}, + }, + ], + } + ], + "max_tokens": 16, + "timeout": 60, + } + if api_base: + kwargs["api_base"] = api_base + if cfg.get("litellm_params"): + # Strip pricing keys — they're tracking-only and confuse some + # provider validators (e.g. azure/openai reject unknown kwargs + # in strict mode). + merged = { + k: v + for k, v in dict(cfg["litellm_params"]).items() + if k + not in { + "input_cost_per_token", + "output_cost_per_token", + "input_cost_per_pixel", + "output_cost_per_pixel", + } + } + kwargs.update(merged) + try: + resp = await litellm.acompletion(**kwargs) + except Exception as exc: + return False, f"{type(exc).__name__}: {exc}" + text = resp.choices[0].message.content if resp.choices else "" + return True, f"got reply ({(text or '').strip()[:40]!r})" + + +# Gemini image models occasionally return zero-length ``data`` for the +# minimal "red dot on white" prompt (provider-side safety / empty-output +# quirk reproducible against ``google/gemini-2.5-flash-image`` even when +# the request itself succeeds). Use a more naturalistic prompt and +# retry once with a different one before giving up. +_IMAGE_GEN_PROMPTS: tuple[str, ...] 
= ( + "A simple icon of a coffee cup, flat illustration", + "A small green leaf on a white background", +) + + +async def _live_image_gen_call(cfg: dict) -> tuple[bool, str]: + """Generate one tiny image to verify the deployment is reachable.""" + from app.services.provider_capabilities import _PROVIDER_PREFIX_MAP + + if cfg.get("custom_provider"): + prefix = cfg["custom_provider"] + else: + prefix = _PROVIDER_PREFIX_MAP.get( + (cfg.get("provider") or "").upper(), (cfg.get("provider") or "").lower() + ) + model_string = f"{prefix}/{cfg['model_name']}" + api_base = resolve_api_base( + provider=cfg.get("provider"), + provider_prefix=prefix, + config_api_base=cfg.get("api_base") or None, + ) + base_kwargs: dict[str, Any] = { + "model": model_string, + "api_key": cfg.get("api_key"), + "n": 1, + "size": "1024x1024", + "timeout": 120, + } + if api_base: + base_kwargs["api_base"] = api_base + if cfg.get("api_version"): + base_kwargs["api_version"] = cfg["api_version"] + if cfg.get("litellm_params"): + base_kwargs.update( + { + k: v + for k, v in dict(cfg["litellm_params"]).items() + if k + not in { + "input_cost_per_token", + "output_cost_per_token", + "input_cost_per_pixel", + "output_cost_per_pixel", + } + } + ) + + last_note = "" + for attempt, prompt in enumerate(_IMAGE_GEN_PROMPTS, start=1): + try: + resp = await litellm.aimage_generation(prompt=prompt, **base_kwargs) + except Exception as exc: + last_note = f"{type(exc).__name__}: {exc}" + continue + data_count = len(getattr(resp, "data", None) or []) + if data_count > 0: + return True, ( + f"received {data_count} image(s) on attempt {attempt} " + f"(prompt={prompt!r})" + ) + last_note = ( + f"call ok but received 0 images on attempt {attempt} (prompt={prompt!r})" + ) + return False, last_note + + +# --------------------------------------------------------------------------- +# Probe drivers +# --------------------------------------------------------------------------- + + +def _is_or_dynamic(cfg: dict) -> bool: + return bool(cfg.get(_OPENROUTER_DYNAMIC_MARKER)) + + +async def probe_chat_configs(report: Report, *, live: bool) -> None: + print("\n[chat configs from global_llm_configs (YAML-static)]") + for cfg in config.GLOBAL_LLM_CONFIGS: + # Skip OR dynamic entries here — handled in the OR section so + # the YAML / OR split stays clear in the report. + if _is_or_dynamic(cfg): + continue + result = ProbeResult( + label=str(cfg.get("name") or cfg.get("model_name")), + surface="chat-yaml", + config_id=cfg.get("id"), + ) + cap_ok, cap_note = _probe_chat_capability(cfg) + result.capability_ok = cap_ok + result.capability_note = cap_note + if live: + t0 = time.perf_counter() + ok, note = await _live_chat_image_call(cfg) + result.live_ok = ok + result.live_note = note + result.duration_s = time.perf_counter() - t0 + report.add(result) + + +async def probe_vision_configs(report: Report, *, live: bool) -> None: + print("\n[vision configs from global_vision_llm_configs (YAML-static)]") + for cfg in config.GLOBAL_VISION_LLM_CONFIGS: + if _is_or_dynamic(cfg): + continue + result = ProbeResult( + label=str(cfg.get("name") or cfg.get("model_name")), + surface="vision", + config_id=cfg.get("id"), + ) + # For vision configs, capability is implied — they're in the + # dedicated vision pool. Run the same resolver to flag any + # surprise disagreement. 
+ cap_ok, cap_note = _probe_chat_capability(cfg) + result.capability_ok = cap_ok + result.capability_note = cap_note + if live: + t0 = time.perf_counter() + ok, note = await _live_chat_image_call(cfg) + result.live_ok = ok + result.live_note = note + result.duration_s = time.perf_counter() - t0 + report.add(result) + + +async def probe_image_gen_configs(report: Report, *, live: bool) -> None: + print( + "\n[image generation configs from global_image_generation_configs (YAML-static)]" + ) + for cfg in config.GLOBAL_IMAGE_GEN_CONFIGS: + if _is_or_dynamic(cfg): + continue + result = ProbeResult( + label=str(cfg.get("name") or cfg.get("model_name")), + surface="image-gen", + config_id=cfg.get("id"), + ) + # Image gen configs don't have a "supports_image_input" flag; + # the catalog tracks output, not input. Mark capability as None + # (skip) for the report. + if live: + t0 = time.perf_counter() + ok, note = await _live_image_gen_call(cfg) + result.live_ok = ok + result.live_note = note + result.duration_s = time.perf_counter() - t0 + report.add(result) + + +async def probe_openrouter_catalog(report: Report, *, live: bool) -> None: + """Sample one chat (vision-capable), one vision, one image-gen model + from the live OpenRouter catalogue. Doesn't iterate the full pool + (would be hundreds of probes); just validates the integration end- + to-end on a representative model from each surface.""" + print("\n[OpenRouter integration: sampled probes]") + settings = config.OPENROUTER_INTEGRATION_SETTINGS + if not settings: + report.add( + ProbeResult( + label="OpenRouter integration", + surface="openrouter", + config_id="settings", + capability_ok=None, + capability_note="openrouter_integration disabled in YAML — skipping", + live_ok=None, + ) + ) + return + + service = OpenRouterIntegrationService.get_instance() + or_chat = [ + c + for c in config.GLOBAL_LLM_CONFIGS + if c.get("provider") == "OPENROUTER" and c.get("supports_image_input") + ] + or_vision = [ + c for c in config.GLOBAL_VISION_LLM_CONFIGS if c.get("provider") == "OPENROUTER" + ] + or_image_gen = [ + c for c in config.GLOBAL_IMAGE_GEN_CONFIGS if c.get("provider") == "OPENROUTER" + ] + + # Pick one representative per provider family per surface so a single + # broken vendor (e.g. Anthropic key revoked, Google quota exceeded) + # surfaces independently of the others. Each needle matches the + # OpenRouter ``model_name`` prefix; the first match wins. + def _pick_first(pool: list[dict], needle: str) -> dict | None: + for c in pool: + if (c.get("model_name") or "").lower().startswith(needle): + return c + return None + + chat_picks = [ + ("or-chat", _pick_first(or_chat, "openai/gpt-4o")), + ("or-chat", _pick_first(or_chat, "anthropic/claude")), + ("or-chat", _pick_first(or_chat, "google/gemini-2.5-flash")), + ] + vision_picks = [ + ("or-vision", _pick_first(or_vision, "openai/gpt-4o")), + ("or-vision", _pick_first(or_vision, "anthropic/claude")), + ("or-vision", _pick_first(or_vision, "google/gemini-2.5-flash")), + ] + image_picks = [ + ("or-image", _pick_first(or_image_gen, "google/gemini-2.5-flash-image")), + # OpenRouter publishes OpenAI image models as ``openai/gpt-5-image*`` + # / ``openai/gpt-5.4-image-2`` (no ``gpt-image`` literal). Match + # the actual prefix. 
+ ("or-image", _pick_first(or_image_gen, "openai/gpt-5-image")), + ] + + print( + f" catalog: chat={len(or_chat)} vision={len(or_vision)} image_gen={len(or_image_gen)} " + f"(service initialized={service.is_initialized() if hasattr(service, 'is_initialized') else 'n/a'})" + ) + + for surface, picked in chat_picks + vision_picks + image_picks: + if not picked: + report.add( + ProbeResult( + label=f"", + surface=surface, + config_id="-", + capability_ok=None, + capability_note="no candidate found in OR catalog", + ) + ) + continue + runner = ( + _live_image_gen_call if surface == "or-image" else _live_chat_image_call + ) + result = ProbeResult( + label=str(picked.get("model_name")), + surface=surface, + config_id=picked.get("id"), + ) + if surface != "or-image": + cap_ok, cap_note = _probe_chat_capability(picked) + result.capability_ok = cap_ok + result.capability_note = cap_note + if live: + t0 = time.perf_counter() + ok, note = await runner(picked) + result.live_ok = ok + result.live_note = note + result.duration_s = time.perf_counter() - t0 + report.add(result) + + +# --------------------------------------------------------------------------- +# Entry point +# --------------------------------------------------------------------------- + + +async def main(args: argparse.Namespace) -> int: + print("Loaded global configs:") + print(f" chat: {len(config.GLOBAL_LLM_CONFIGS)} entries") + print(f" vision: {len(config.GLOBAL_VISION_LLM_CONFIGS)} entries") + print(f" image-gen: {len(config.GLOBAL_IMAGE_GEN_CONFIGS)} entries") + print(f" OR settings present: {bool(config.OPENROUTER_INTEGRATION_SETTINGS)}") + + # Initialize the OpenRouter integration so the catalog is populated + # (this is what main.py does at startup). It's idempotent. + if config.OPENROUTER_INTEGRATION_SETTINGS: + try: + from app.config import initialize_openrouter_integration + + initialize_openrouter_integration() + except Exception as exc: + print(f" WARNING: OpenRouter integration init failed: {exc}") + + print( + f"\nMode: {'LIVE (will hit providers)' if args.live else 'DRY (capability only)'}" + ) + + report = Report() + if not args.skip_chat: + await probe_chat_configs(report, live=args.live) + if not args.skip_vision: + await probe_vision_configs(report, live=args.live) + if not args.skip_image_gen: + await probe_image_gen_configs(report, live=args.live) + if not args.skip_openrouter: + await probe_openrouter_catalog(report, live=args.live) + + failed = report.render() + return 1 if failed else 0 + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--no-live", + dest="live", + action="store_false", + help="Skip live API calls — capability resolver only.", + ) + parser.set_defaults(live=True) + parser.add_argument("--skip-chat", action="store_true") + parser.add_argument("--skip-vision", action="store_true") + parser.add_argument("--skip-image-gen", action="store_true") + parser.add_argument("--skip-openrouter", action="store_true") + return parser.parse_args() + + +if __name__ == "__main__": + args = _parse_args() + sys.exit(asyncio.run(main(args))) diff --git a/surfsense_backend/tests/unit/routes/test_byok_supports_image_input.py b/surfsense_backend/tests/unit/routes/test_byok_supports_image_input.py new file mode 100644 index 000000000..c9f18d77d --- /dev/null +++ b/surfsense_backend/tests/unit/routes/test_byok_supports_image_input.py @@ -0,0 +1,110 @@ +"""Unit tests for ``supports_image_input`` derivation on BYOK chat config +endpoints (``GET 
/new-llm-configs`` list, ``GET /new-llm-configs/{id}``). + +There is no DB column for ``supports_image_input`` on +``NewLLMConfig`` — the value is resolved at the API boundary by +``derive_supports_image_input`` so the new-chat selector / streaming +task can read the same field shape regardless of source (BYOK vs YAML +vs OpenRouter dynamic). Default-allow on unknown so we don't lock the +user out of their own model choice. +""" + +from __future__ import annotations + +from datetime import UTC, datetime +from types import SimpleNamespace +from uuid import uuid4 + +import pytest + +from app.db import LiteLLMProvider +from app.routes import new_llm_config_routes + +pytestmark = pytest.mark.unit + + +def _byok_row( + *, + id_: int, + model_name: str, + base_model: str | None = None, + provider: LiteLLMProvider = LiteLLMProvider.OPENAI, + custom_provider: str | None = None, +) -> object: + """Mimic the SQLAlchemy row's attribute surface; ``model_validate`` + walks ``from_attributes=True`` so a ``SimpleNamespace`` is enough. + + ``provider`` is a real ``LiteLLMProvider`` enum value so Pydantic's + enum validator accepts it — same as the ORM row would carry.""" + return SimpleNamespace( + id=id_, + name=f"BYOK-{id_}", + description=None, + provider=provider, + custom_provider=custom_provider, + model_name=model_name, + api_key="sk-byok", + api_base=None, + litellm_params={"base_model": base_model} if base_model else None, + system_instructions="", + use_default_system_instructions=True, + citations_enabled=True, + created_at=datetime.now(tz=UTC), + search_space_id=42, + user_id=uuid4(), + ) + + +def test_serialize_byok_known_vision_model_resolves_true(): + """The catalog resolver consults LiteLLM's map for ``gpt-4o`` -> + True. The serialized row carries that value through to the + ``NewLLMConfigRead`` schema.""" + row = _byok_row(id_=1, model_name="gpt-4o") + serialized = new_llm_config_routes._serialize_byok_config(row) + + assert serialized.supports_image_input is True + assert serialized.id == 1 + assert serialized.model_name == "gpt-4o" + + +def test_serialize_byok_unknown_model_default_allows(): + """Unknown / unmapped: default-allow. The streaming-task safety net + is the actual block, and it requires LiteLLM to *explicitly* say + text-only — so a brand new BYOK model should not be pre-judged.""" + row = _byok_row( + id_=2, + model_name="brand-new-model-x9-unmapped", + provider=LiteLLMProvider.CUSTOM, + custom_provider="brand_new_proxy", + ) + serialized = new_llm_config_routes._serialize_byok_config(row) + + assert serialized.supports_image_input is True + + +def test_serialize_byok_uses_base_model_when_present(): + """Azure-style: ``model_name`` is the deployment id, ``base_model`` + inside ``litellm_params`` is the canonical sku LiteLLM knows. The + helper must consult ``base_model`` first or unrecognised deployment + ids would shadow the real capability.""" + row = _byok_row( + id_=3, + model_name="my-azure-deployment-id-no-litellm-knows-this", + base_model="gpt-4o", + provider=LiteLLMProvider.AZURE_OPENAI, + ) + serialized = new_llm_config_routes._serialize_byok_config(row) + + assert serialized.supports_image_input is True + + +def test_serialize_byok_returns_pydantic_read_model(): + """The route now returns ``NewLLMConfigRead`` (not the raw ORM) so + the schema additions are guaranteed to be present in the API + surface. 
This guards against a future regression where someone + deletes the augmentation step and falls back to ORM passthrough.""" + from app.schemas import NewLLMConfigRead + + row = _byok_row(id_=4, model_name="gpt-4o") + serialized = new_llm_config_routes._serialize_byok_config(row) + assert isinstance(serialized, NewLLMConfigRead) diff --git a/surfsense_backend/tests/unit/routes/test_global_configs_is_premium.py b/surfsense_backend/tests/unit/routes/test_global_configs_is_premium.py new file mode 100644 index 000000000..2b6c76485 --- /dev/null +++ b/surfsense_backend/tests/unit/routes/test_global_configs_is_premium.py @@ -0,0 +1,184 @@ +"""Unit tests for ``is_premium`` derivation on the global image-gen and +vision-LLM list endpoints. + +Chat globals (``GET /global-llm-configs``) already emit +``is_premium = (billing_tier == "premium")``. Image and vision did not, +which made the new-chat ``model-selector`` render the Free/Premium badge +on the Chat tab but skip it on the Image and Vision tabs (the selector +keys its badge logic off ``is_premium``). These tests pin parity: + +* YAML free entry → ``is_premium=False`` +* YAML premium entry → ``is_premium=True`` +* OpenRouter dynamic premium entry → ``is_premium=True`` +* Auto stub (always emitted when at least one config is present) + → ``is_premium=False`` +""" + +from __future__ import annotations + +import pytest + +pytestmark = pytest.mark.unit + + +_IMAGE_FIXTURE: list[dict] = [ + { + "id": -1, + "name": "DALL-E 3", + "provider": "OPENAI", + "model_name": "dall-e-3", + "api_key": "sk-test", + "billing_tier": "free", + }, + { + "id": -2, + "name": "GPT-Image 1 (premium)", + "provider": "OPENAI", + "model_name": "gpt-image-1", + "api_key": "sk-test", + "billing_tier": "premium", + }, + { + "id": -20_001, + "name": "google/gemini-2.5-flash-image (OpenRouter)", + "provider": "OPENROUTER", + "model_name": "google/gemini-2.5-flash-image", + "api_key": "sk-or-test", + "api_base": "https://openrouter.ai/api/v1", + "billing_tier": "premium", + }, +] + + +_VISION_FIXTURE: list[dict] = [ + { + "id": -1, + "name": "GPT-4o Vision", + "provider": "OPENAI", + "model_name": "gpt-4o", + "api_key": "sk-test", + "billing_tier": "free", + }, + { + "id": -2, + "name": "Claude 3.5 Sonnet (premium)", + "provider": "ANTHROPIC", + "model_name": "claude-3-5-sonnet", + "api_key": "sk-ant-test", + "billing_tier": "premium", + }, + { + "id": -30_001, + "name": "openai/gpt-4o (OpenRouter)", + "provider": "OPENROUTER", + "model_name": "openai/gpt-4o", + "api_key": "sk-or-test", + "api_base": "https://openrouter.ai/api/v1", + "billing_tier": "premium", + }, +] + + +# ============================================================================= +# Image generation +# ============================================================================= + + +@pytest.mark.asyncio +async def test_global_image_gen_configs_emit_is_premium(monkeypatch): + """Each emitted config must carry ``is_premium`` derived server-side + from ``billing_tier``. The Auto stub is always free. + """ + from app.config import config + from app.routes import image_generation_routes + + monkeypatch.setattr( + config, "GLOBAL_IMAGE_GEN_CONFIGS", _IMAGE_FIXTURE, raising=False + ) + + payload = await image_generation_routes.get_global_image_gen_configs(user=None) + + by_id = {c["id"]: c for c in payload} + + # Auto stub is always emitted when at least one global config exists, + # and it must always declare itself free (Auto-mode billing-tier + # surfacing is a separate follow-up). 
+ assert 0 in by_id, "Auto stub should be emitted when at least one config exists" + assert by_id[0]["is_premium"] is False + assert by_id[0]["billing_tier"] == "free" + + # YAML free entry — ``is_premium=False`` + assert by_id[-1]["is_premium"] is False + assert by_id[-1]["billing_tier"] == "free" + + # YAML premium entry — ``is_premium=True`` + assert by_id[-2]["is_premium"] is True + assert by_id[-2]["billing_tier"] == "premium" + + # OpenRouter dynamic premium entry — same field, same derivation + assert by_id[-20_001]["is_premium"] is True + assert by_id[-20_001]["billing_tier"] == "premium" + + # Every emitted dict (including Auto) must have the field — never missing. + for cfg in payload: + assert "is_premium" in cfg, f"is_premium missing from {cfg.get('id')}" + assert isinstance(cfg["is_premium"], bool) + + +@pytest.mark.asyncio +async def test_global_image_gen_configs_no_globals_no_auto_stub(monkeypatch): + """When there are no global configs at all, the endpoint emits an + empty list (no Auto stub) — Auto mode would have nothing to route to. + """ + from app.config import config + from app.routes import image_generation_routes + + monkeypatch.setattr(config, "GLOBAL_IMAGE_GEN_CONFIGS", [], raising=False) + payload = await image_generation_routes.get_global_image_gen_configs(user=None) + assert payload == [] + + +# ============================================================================= +# Vision LLM +# ============================================================================= + + +@pytest.mark.asyncio +async def test_global_vision_llm_configs_emit_is_premium(monkeypatch): + from app.config import config + from app.routes import vision_llm_routes + + monkeypatch.setattr( + config, "GLOBAL_VISION_LLM_CONFIGS", _VISION_FIXTURE, raising=False + ) + + payload = await vision_llm_routes.get_global_vision_llm_configs(user=None) + + by_id = {c["id"]: c for c in payload} + + assert 0 in by_id, "Auto stub should be emitted when at least one config exists" + assert by_id[0]["is_premium"] is False + assert by_id[0]["billing_tier"] == "free" + + assert by_id[-1]["is_premium"] is False + assert by_id[-1]["billing_tier"] == "free" + + assert by_id[-2]["is_premium"] is True + assert by_id[-2]["billing_tier"] == "premium" + + assert by_id[-30_001]["is_premium"] is True + assert by_id[-30_001]["billing_tier"] == "premium" + + for cfg in payload: + assert "is_premium" in cfg, f"is_premium missing from {cfg.get('id')}" + assert isinstance(cfg["is_premium"], bool) + + +@pytest.mark.asyncio +async def test_global_vision_llm_configs_no_globals_no_auto_stub(monkeypatch): + from app.config import config + from app.routes import vision_llm_routes + + monkeypatch.setattr(config, "GLOBAL_VISION_LLM_CONFIGS", [], raising=False) + payload = await vision_llm_routes.get_global_vision_llm_configs(user=None) + assert payload == [] diff --git a/surfsense_backend/tests/unit/routes/test_global_new_llm_configs_supports_image.py b/surfsense_backend/tests/unit/routes/test_global_new_llm_configs_supports_image.py new file mode 100644 index 000000000..b47d9134b --- /dev/null +++ b/surfsense_backend/tests/unit/routes/test_global_new_llm_configs_supports_image.py @@ -0,0 +1,106 @@ +"""Unit tests for ``supports_image_input`` derivation on the chat global +config endpoint (``GET /global-new-llm-configs``). + +Resolution order (matches ``new_llm_config_routes.get_global_new_llm_configs``): + +1. 
Explicit ``supports_image_input`` on the cfg dict (set by the YAML + loader for operator overrides, or by the OpenRouter integration from + ``architecture.input_modalities``) — wins. +2. ``derive_supports_image_input`` helper — default-allow on unknown + models, only False when LiteLLM / OR modalities are definitive. + +The flag is purely informational at the API boundary. The streaming +task safety net (``is_known_text_only_chat_model``) is the actual block, +and it requires LiteLLM to *explicitly* mark the model as text-only. +""" + +from __future__ import annotations + +import pytest + +pytestmark = pytest.mark.unit + + +_FIXTURE: list[dict] = [ + { + "id": -1, + "name": "GPT-4o (explicit true)", + "description": "vision-capable, explicit YAML override", + "provider": "OPENAI", + "model_name": "gpt-4o", + "api_key": "sk-test", + "billing_tier": "free", + "supports_image_input": True, + }, + { + "id": -2, + "name": "DeepSeek V3 (explicit false)", + "description": "OpenRouter dynamic — modality-derived false", + "provider": "OPENROUTER", + "model_name": "deepseek/deepseek-v3.2-exp", + "api_key": "sk-or-test", + "api_base": "https://openrouter.ai/api/v1", + "billing_tier": "free", + "supports_image_input": False, + }, + { + "id": -10_010, + "name": "Unannotated GPT-4o", + "description": "no flag set — resolver should derive True via LiteLLM", + "provider": "OPENAI", + "model_name": "gpt-4o", + "api_key": "sk-test", + "billing_tier": "free", + # supports_image_input intentionally absent + }, + { + "id": -10_011, + "name": "Unannotated unknown model", + "description": "unmapped — default-allow True", + "provider": "CUSTOM", + "custom_provider": "brand_new_proxy", + "model_name": "brand-new-model-x9", + "api_key": "sk-test", + "billing_tier": "free", + }, +] + + +@pytest.mark.asyncio +async def test_global_new_llm_configs_emit_supports_image_input(monkeypatch): + """Each emitted chat config carries ``supports_image_input`` as a + bool. Explicit values win; unannotated entries are resolved via the + helper (default-allow True).""" + from app.config import config + from app.routes import new_llm_config_routes + + monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", _FIXTURE, raising=False) + + payload = await new_llm_config_routes.get_global_new_llm_configs(user=None) + by_id = {c["id"]: c for c in payload} + + # Auto stub: optimistic True so the user can keep Auto selected with + # vision-capable deployments somewhere in the pool. + assert 0 in by_id, "Auto stub should be emitted when configs exist" + assert by_id[0]["supports_image_input"] is True + assert by_id[0]["is_auto_mode"] is True + + # Explicit True is preserved. + assert by_id[-1]["supports_image_input"] is True + + # Explicit False is preserved (the exact failure mode the safety net + # guards against — DeepSeek V3 over OpenRouter would 404 with "No + # endpoints found that support image input"). + assert by_id[-2]["supports_image_input"] is False + + # Unannotated GPT-4o: resolver consults LiteLLM, which says vision. + assert by_id[-10_010]["supports_image_input"] is True + + # Unknown / unmapped model: default-allow rather than pre-judge. 
+ assert by_id[-10_011]["supports_image_input"] is True + + for cfg in payload: + assert "supports_image_input" in cfg, ( + f"supports_image_input missing from {cfg.get('id')}" + ) + assert isinstance(cfg["supports_image_input"], bool) diff --git a/surfsense_backend/tests/unit/services/test_auto_pin_image_aware.py b/surfsense_backend/tests/unit/services/test_auto_pin_image_aware.py new file mode 100644 index 000000000..0e19b80e4 --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_auto_pin_image_aware.py @@ -0,0 +1,286 @@ +"""Image-aware extension of the Auto-pin resolver. + +When the current chat turn carries an ``image_url`` block, the pin +resolver must: + +1. Filter the candidate pool to vision-capable cfgs so a freshly + selected pin can never be text-only. +2. Treat any existing pin whose capability is False as invalid (force + re-pin), even when it would otherwise be reused as the thread's + stable model. +3. Raise ``ValueError`` (mapped to the friendly + ``MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT`` SSE error in the streaming + task) when no vision-capable cfg is available — instead of silently + pinning text-only and 404-ing at the provider. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from types import SimpleNamespace + +import pytest + +from app.services.auto_model_pin_service import ( + clear_healthy, + clear_runtime_cooldown, + resolve_or_get_pinned_llm_config_id, +) + +pytestmark = pytest.mark.unit + + +@pytest.fixture(autouse=True) +def _reset_caches(): + clear_runtime_cooldown() + clear_healthy() + yield + clear_runtime_cooldown() + clear_healthy() + + +@dataclass +class _FakeQuotaResult: + allowed: bool + + +class _FakeExecResult: + def __init__(self, thread): + self._thread = thread + + def unique(self): + return self + + def scalar_one_or_none(self): + return self._thread + + +class _FakeSession: + def __init__(self, thread): + self.thread = thread + self.commit_count = 0 + + async def execute(self, _stmt): + return _FakeExecResult(self.thread) + + async def commit(self): + self.commit_count += 1 + + +def _thread(*, pinned: int | None = None): + return SimpleNamespace(id=1, search_space_id=10, pinned_llm_config_id=pinned) + + +def _vision_cfg(id_: int, *, tier: str = "free", quality: int = 80) -> dict: + return { + "id": id_, + "provider": "OPENAI", + "model_name": f"vision-{id_}", + "api_key": "k", + "billing_tier": tier, + "supports_image_input": True, + "auto_pin_tier": "A", + "quality_score": quality, + } + + +def _text_only_cfg(id_: int, *, tier: str = "free", quality: int = 90) -> dict: + return { + "id": id_, + "provider": "OPENAI", + "model_name": f"text-{id_}", + "api_key": "k", + "billing_tier": tier, + # Higher quality than the vision cfgs — so a bug that ignores + # the image flag would surface as the resolver picking this one. 
+ "supports_image_input": False, + "auto_pin_tier": "A", + "quality_score": quality, + } + + +async def _premium_allowed(*_args, **_kwargs): + return _FakeQuotaResult(allowed=True) + + +@pytest.mark.asyncio +async def test_image_turn_filters_out_text_only_candidates(monkeypatch): + from app.config import config + + session = _FakeSession(_thread()) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [_text_only_cfg(-1), _vision_cfg(-2)], + ) + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _premium_allowed, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id=None, + selected_llm_config_id=0, + requires_image_input=True, + ) + + assert result.resolved_llm_config_id == -2 + # The thread should be pinned to the vision cfg even though the + # text-only cfg has a higher quality score. + assert session.thread.pinned_llm_config_id == -2 + + +@pytest.mark.asyncio +async def test_image_turn_force_repins_stale_text_only_pin(monkeypatch): + """An existing text-only pin must be invalidated when the next turn + requires image input. The non-image path would happily reuse it.""" + from app.config import config + + session = _FakeSession(_thread(pinned=-1)) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [_text_only_cfg(-1), _vision_cfg(-2)], + ) + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _premium_allowed, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id=None, + selected_llm_config_id=0, + requires_image_input=True, + ) + + assert result.resolved_llm_config_id == -2 + assert result.from_existing_pin is False + assert session.thread.pinned_llm_config_id == -2 + + +@pytest.mark.asyncio +async def test_image_turn_reuses_existing_vision_pin(monkeypatch): + """If the thread is already pinned to a vision-capable cfg, reuse it + — same as the non-image path. 
Image-aware filtering must not force + spurious re-pins.""" + from app.config import config + + session = _FakeSession(_thread(pinned=-2)) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [_text_only_cfg(-1), _vision_cfg(-2), _vision_cfg(-3, quality=70)], + ) + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _premium_allowed, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id=None, + selected_llm_config_id=0, + requires_image_input=True, + ) + + assert result.resolved_llm_config_id == -2 + assert result.from_existing_pin is True + + +@pytest.mark.asyncio +async def test_image_turn_with_no_vision_candidates_raises(monkeypatch): + """The friendly-error path: no vision-capable cfg in the pool -> raise + ``ValueError`` whose message contains ``vision-capable`` so the + streaming task can map it to ``MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT``.""" + from app.config import config + + session = _FakeSession(_thread()) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [_text_only_cfg(-1), _text_only_cfg(-2)], + ) + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _premium_allowed, + ) + + with pytest.raises(ValueError, match="vision-capable"): + await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id=None, + selected_llm_config_id=0, + requires_image_input=True, + ) + + +@pytest.mark.asyncio +async def test_non_image_turn_keeps_text_only_in_pool(monkeypatch): + """Regression guard: the image flag must default False and not affect + a normal text-only turn — text-only cfgs remain selectable.""" + from app.config import config + + session = _FakeSession(_thread()) + monkeypatch.setattr( + config, + "GLOBAL_LLM_CONFIGS", + [_text_only_cfg(-1)], + ) + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _premium_allowed, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id=None, + selected_llm_config_id=0, + ) + assert result.resolved_llm_config_id == -1 + + +@pytest.mark.asyncio +async def test_image_turn_unannotated_cfg_resolves_via_helper(monkeypatch): + """A YAML cfg that omits ``supports_image_input`` falls through to + ``derive_supports_image_input`` (LiteLLM-driven). 
For ``gpt-4o`` + that returns True, so the cfg should be a valid candidate.""" + from app.config import config + + session = _FakeSession(_thread()) + cfg_unannotated_vision = { + "id": -2, + "provider": "OPENAI", + "model_name": "gpt-4o", # known vision model in LiteLLM map + "api_key": "k", + "billing_tier": "free", + "auto_pin_tier": "A", + "quality_score": 80, + # NOTE: no supports_image_input key + } + monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [cfg_unannotated_vision]) + monkeypatch.setattr( + "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage", + _premium_allowed, + ) + + result = await resolve_or_get_pinned_llm_config_id( + session, + thread_id=1, + search_space_id=10, + user_id=None, + selected_llm_config_id=0, + requires_image_input=True, + ) + assert result.resolved_llm_config_id == -2 diff --git a/surfsense_backend/tests/unit/services/test_billable_call.py b/surfsense_backend/tests/unit/services/test_billable_call.py index 86de5f23d..c820724ed 100644 --- a/surfsense_backend/tests/unit/services/test_billable_call.py +++ b/surfsense_backend/tests/unit/services/test_billable_call.py @@ -15,6 +15,7 @@ vision LLM extraction: from __future__ import annotations +import asyncio import contextlib from typing import Any from uuid import uuid4 @@ -57,6 +58,9 @@ class _FakeSession: async def commit(self) -> None: self.committed = True + async def rollback(self) -> None: + pass + async def close(self) -> None: pass @@ -71,7 +75,9 @@ async def _fake_shielded_session(): _SESSIONS_USED: list[_FakeSession] = [] -def _patch_isolation_layer(monkeypatch, *, reserve_result, finalize_result=None): +def _patch_isolation_layer( + monkeypatch, *, reserve_result, finalize_result=None, finalize_exc=None +): """Wire fake reserve/finalize/release/session helpers.""" _SESSIONS_USED.clear() reserve_calls: list[dict[str, Any]] = [] @@ -91,6 +97,8 @@ def _patch_isolation_layer(monkeypatch, *, reserve_result, finalize_result=None) async def _fake_finalize( *, db_session, user_id, request_id, actual_micros, reserved_micros ): + if finalize_exc is not None: + raise finalize_exc finalize_calls.append( { "user_id": user_id, @@ -343,6 +351,125 @@ async def test_premium_uses_estimator_when_no_micros_override(monkeypatch): assert spies["reserve"][0]["reserve_micros"] == 12_345 +@pytest.mark.asyncio +async def test_premium_finalize_failure_propagates_and_releases(monkeypatch): + from app.services.billable_calls import BillingSettlementError, billable_call + + class _FinalizeError(RuntimeError): + pass + + spies = _patch_isolation_layer( + monkeypatch, + reserve_result=_FakeQuotaResult(allowed=True), + finalize_exc=_FinalizeError("db finalize failed"), + ) + user_id = uuid4() + + with pytest.raises(BillingSettlementError): + async with billable_call( + user_id=user_id, + search_space_id=42, + billing_tier="premium", + base_model="openai/gpt-image-1", + quota_reserve_micros_override=50_000, + usage_type="image_generation", + ) as acc: + acc.add( + model="openai/gpt-image-1", + prompt_tokens=0, + completion_tokens=0, + total_tokens=0, + cost_micros=40_000, + call_kind="image_generation", + ) + + assert len(spies["reserve"]) == 1 + assert len(spies["release"]) == 1 + assert spies["record"] == [] + + +@pytest.mark.asyncio +async def test_premium_audit_commit_hang_times_out_after_finalize(monkeypatch): + from app.services.billable_calls import billable_call + + spies = _patch_isolation_layer( + monkeypatch, reserve_result=_FakeQuotaResult(allowed=True) + ) + user_id = uuid4() + + class 
_HangingCommitSession(_FakeSession): + async def commit(self) -> None: + await asyncio.sleep(60) + + @contextlib.asynccontextmanager + async def _hanging_session_factory(): + s = _HangingCommitSession() + _SESSIONS_USED.append(s) + yield s + + async with billable_call( + user_id=user_id, + search_space_id=42, + billing_tier="premium", + base_model="openai/gpt-image-1", + quota_reserve_micros_override=50_000, + usage_type="image_generation", + billable_session_factory=_hanging_session_factory, + audit_timeout_seconds=0.01, + ) as acc: + acc.add( + model="openai/gpt-image-1", + prompt_tokens=0, + completion_tokens=0, + total_tokens=0, + cost_micros=40_000, + call_kind="image_generation", + ) + + assert len(spies["reserve"]) == 1 + assert len(spies["finalize"]) == 1 + assert len(spies["record"]) == 1 + assert spies["release"] == [] + + +@pytest.mark.asyncio +async def test_free_audit_failure_is_best_effort(monkeypatch): + from app.services.billable_calls import billable_call + + spies = _patch_isolation_layer( + monkeypatch, reserve_result=_FakeQuotaResult(allowed=True) + ) + + async def _failing_record(_session, **_kwargs): + raise RuntimeError("audit insert failed") + + monkeypatch.setattr( + "app.services.billable_calls.record_token_usage", + _failing_record, + raising=False, + ) + + async with billable_call( + user_id=uuid4(), + search_space_id=42, + billing_tier="free", + base_model="openai/gpt-image-1", + usage_type="image_generation", + audit_timeout_seconds=0.01, + ) as acc: + acc.add( + model="openai/gpt-image-1", + prompt_tokens=0, + completion_tokens=0, + total_tokens=0, + cost_micros=37_000, + call_kind="image_generation", + ) + + assert spies["reserve"] == [] + assert spies["finalize"] == [] + + # --------------------------------------------------------------------------- # Podcast / video-presentation usage_type coverage # --------------------------------------------------------------------------- @@ -387,7 +514,7 @@ async def test_free_podcast_path_audits_with_podcast_usage_type(monkeypatch): assert len(spies["record"]) == 1 row = spies["record"][0] assert row["usage_type"] == "podcast_generation" - assert row["thread_id"] == 99 + assert row["thread_id"] is None assert row["search_space_id"] == 42 assert row["call_details"] == {"podcast_id": 7, "title": "Test Podcast"} diff --git a/surfsense_backend/tests/unit/services/test_image_gen_api_base_defense.py b/surfsense_backend/tests/unit/services/test_image_gen_api_base_defense.py new file mode 100644 index 000000000..9d5fdb190 --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_image_gen_api_base_defense.py @@ -0,0 +1,177 @@ +"""Defense-in-depth: image-gen call sites must not let an empty +``api_base`` fall through to LiteLLM's module-global ``litellm.api_base``. + +The bug repro: an OpenRouter image-gen config ships +``api_base=""``. The pre-fix call site in +``image_generation_routes._execute_image_generation`` did +``if cfg.get("api_base"): kwargs["api_base"] = cfg["api_base"]`` which +silently dropped the empty string. LiteLLM then fell back to +``litellm.api_base`` (commonly inherited from ``AZURE_OPENAI_ENDPOINT``) +and OpenRouter's ``image_generation/transformation`` appended +``/chat/completions`` to it → 404 ``Resource not found``. + +This test pins the post-fix behaviour: with an empty ``api_base`` in +the config, the call site MUST set ``api_base`` to OpenRouter's public +URL instead of leaving it unset. 
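+
+Roughly the call-site shape these tests expect after the fix (a sketch of
+the intent, not the exact code; ``cfg`` / ``kwargs`` are illustrative
+locals, ``resolve_api_base`` is the shared resolver exercised elsewhere in
+this suite)::
+
+    api_base = resolve_api_base(
+        provider=cfg.get("provider"),
+        provider_prefix="openrouter",
+        config_api_base=cfg.get("api_base") or None,
+    )
+    if api_base:
+        kwargs["api_base"] = api_base  # never fall through to litellm.api_base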
+""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +pytestmark = pytest.mark.unit + + +@pytest.mark.asyncio +async def test_global_openrouter_image_gen_sets_api_base_when_config_empty(): + """The global-config branch (``config_id < 0``) of + ``_execute_image_generation`` must apply the resolver and pin + ``api_base`` to OpenRouter when the config ships an empty string. + """ + from app.routes import image_generation_routes + + cfg = { + "id": -20_001, + "name": "GPT Image 1 (OpenRouter)", + "provider": "OPENROUTER", + "model_name": "openai/gpt-image-1", + "api_key": "sk-or-test", + "api_base": "", # the original bug shape + "api_version": None, + "litellm_params": {}, + } + + captured: dict = {} + + async def fake_aimage_generation(**kwargs): + captured.update(kwargs) + return MagicMock(model_dump=lambda: {"data": []}, _hidden_params={}) + + image_gen = MagicMock() + image_gen.image_generation_config_id = cfg["id"] + image_gen.prompt = "test" + image_gen.n = 1 + image_gen.quality = None + image_gen.size = None + image_gen.style = None + image_gen.response_format = None + image_gen.model = None + + search_space = MagicMock() + search_space.image_generation_config_id = cfg["id"] + session = MagicMock() + + with ( + patch.object( + image_generation_routes, + "_get_global_image_gen_config", + return_value=cfg, + ), + patch.object( + image_generation_routes, + "aimage_generation", + side_effect=fake_aimage_generation, + ), + ): + await image_generation_routes._execute_image_generation( + session=session, image_gen=image_gen, search_space=search_space + ) + + # The whole point of the fix: even with empty ``api_base`` in the + # config, we forward OpenRouter's public URL so the call doesn't + # inherit an Azure endpoint. + assert captured.get("api_base") == "https://openrouter.ai/api/v1" + assert captured["model"] == "openrouter/openai/gpt-image-1" + + +@pytest.mark.asyncio +async def test_generate_image_tool_global_sets_api_base_when_config_empty(): + """Same defense at the agent tool entry point — both surfaces share + the same OpenRouter config payloads.""" + from app.agents.new_chat.tools import generate_image as gi_module + + cfg = { + "id": -20_001, + "name": "GPT Image 1 (OpenRouter)", + "provider": "OPENROUTER", + "model_name": "openai/gpt-image-1", + "api_key": "sk-or-test", + "api_base": "", + "api_version": None, + "litellm_params": {}, + } + + captured: dict = {} + + async def fake_aimage_generation(**kwargs): + captured.update(kwargs) + response = MagicMock() + response.model_dump.return_value = { + "data": [{"url": "https://example.com/x.png"}] + } + response._hidden_params = {"model": "openrouter/openai/gpt-image-1"} + return response + + search_space = MagicMock() + search_space.id = 1 + search_space.image_generation_config_id = cfg["id"] + + session_cm = AsyncMock() + session = AsyncMock() + session_cm.__aenter__.return_value = session + + scalars = MagicMock() + scalars.first.return_value = search_space + exec_result = MagicMock() + exec_result.scalars.return_value = scalars + session.execute.return_value = exec_result + session.add = MagicMock() + session.commit = AsyncMock() + session.refresh = AsyncMock() + + # ``refresh(db_image_gen)`` needs to populate ``id`` for token URL fallback. 
+ async def _refresh(obj): + obj.id = 1 + + session.refresh.side_effect = _refresh + + with ( + patch.object(gi_module, "shielded_async_session", return_value=session_cm), + patch.object(gi_module, "_get_global_image_gen_config", return_value=cfg), + patch.object( + gi_module, "aimage_generation", side_effect=fake_aimage_generation + ), + patch.object( + gi_module, "is_image_gen_auto_mode", side_effect=lambda cid: cid == 0 + ), + ): + tool = gi_module.create_generate_image_tool( + search_space_id=1, db_session=MagicMock() + ) + await tool.ainvoke({"prompt": "a cat", "n": 1}) + + assert captured.get("api_base") == "https://openrouter.ai/api/v1" + assert captured["model"] == "openrouter/openai/gpt-image-1" + + +def test_image_gen_router_deployment_sets_api_base_when_config_empty(): + """The Auto-mode router pool must also resolve ``api_base`` when an + OpenRouter config ships an empty string. The deployment dict is fed + straight to ``litellm.Router``, so a missing ``api_base`` would + leak the same way as the direct call sites. + """ + from app.services.image_gen_router_service import ImageGenRouterService + + deployment = ImageGenRouterService._config_to_deployment( + { + "model_name": "openai/gpt-image-1", + "provider": "OPENROUTER", + "api_key": "sk-or-test", + "api_base": "", + } + ) + assert deployment is not None + assert deployment["litellm_params"]["api_base"] == "https://openrouter.ai/api/v1" + assert deployment["litellm_params"]["model"] == "openrouter/openai/gpt-image-1" diff --git a/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py b/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py index b635b4fe8..88fcf2db3 100644 --- a/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py +++ b/surfsense_backend/tests/unit/services/test_openrouter_integration_service.py @@ -265,6 +265,10 @@ def test_generate_image_gen_configs_filters_by_image_output(): assert c["billing_tier"] in {"free", "premium"} assert c["provider"] == "OPENROUTER" assert c[_OPENROUTER_DYNAMIC_MARKER] is True + # Defense-in-depth: emit the OpenRouter base URL at source so a + # downstream call site that forgets ``resolve_api_base`` still + # doesn't 404 against an inherited Azure endpoint. + assert c["api_base"] == "https://openrouter.ai/api/v1" def test_generate_image_gen_configs_assigns_image_id_offset(): @@ -342,6 +346,10 @@ def test_generate_vision_llm_configs_filters_by_image_input_text_output(): assert cfg["input_cost_per_token"] == pytest.approx(5e-6) assert cfg["output_cost_per_token"] == pytest.approx(15e-6) assert cfg[_OPENROUTER_DYNAMIC_MARKER] is True + # Defense-in-depth: emit the OpenRouter base URL at source so a + # downstream call site that forgets ``resolve_api_base`` still + # doesn't inherit an Azure endpoint. + assert cfg["api_base"] == "https://openrouter.ai/api/v1" def test_generate_vision_llm_configs_drops_chat_only_filters(): diff --git a/surfsense_backend/tests/unit/services/test_provider_api_base.py b/surfsense_backend/tests/unit/services/test_provider_api_base.py new file mode 100644 index 000000000..12cd0a3d5 --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_provider_api_base.py @@ -0,0 +1,107 @@ +"""Unit tests for the shared ``api_base`` resolver. + +The cascade exists so vision and image-gen call sites can't silently +inherit ``litellm.api_base`` (commonly set by ``AZURE_OPENAI_ENDPOINT``) +when an OpenRouter / Groq / etc. config ships an empty string. 
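+
+A few illustrative calls, matching the assertions below (behaviour noted in
+the comments; the URLs themselves come from the two default maps)::
+
+    resolve_api_base(provider="OPENROUTER", provider_prefix="openrouter",
+                     config_api_base="")
+    # -> OpenRouter's public default, never litellm.api_base
+    resolve_api_base(provider="DEEPSEEK", provider_prefix="openai",
+                     config_api_base=None)
+    # -> the DeepSeek default: the provider-key map wins over the shared
+    #    "openai" prefix
+    resolve_api_base(provider="SOMETHING_NEW", provider_prefix="something_new",
+                     config_api_base=None)
+    # -> None, so LiteLLM applies its own provider default
+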
See +``provider_api_base`` module docstring for the original repro +(OpenRouter image-gen 404-ing against an Azure endpoint). +""" + +from __future__ import annotations + +import pytest + +from app.services.provider_api_base import ( + PROVIDER_DEFAULT_API_BASE, + PROVIDER_KEY_DEFAULT_API_BASE, + resolve_api_base, +) + +pytestmark = pytest.mark.unit + + +def test_config_value_wins_over_defaults(): + """A non-empty config value is always returned verbatim, even when the + provider has a default — the operator gets the last word.""" + result = resolve_api_base( + provider="OPENROUTER", + provider_prefix="openrouter", + config_api_base="https://my-openrouter-mirror.example.com/v1", + ) + assert result == "https://my-openrouter-mirror.example.com/v1" + + +def test_provider_key_default_when_config_missing(): + """``DEEPSEEK`` shares the ``openai`` LiteLLM prefix but has its own + base URL — the provider-key map must take precedence over the prefix + map so DeepSeek requests don't go to OpenAI.""" + result = resolve_api_base( + provider="DEEPSEEK", + provider_prefix="openai", + config_api_base=None, + ) + assert result == PROVIDER_KEY_DEFAULT_API_BASE["DEEPSEEK"] + + +def test_provider_prefix_default_when_no_key_default(): + result = resolve_api_base( + provider="OPENROUTER", + provider_prefix="openrouter", + config_api_base=None, + ) + assert result == PROVIDER_DEFAULT_API_BASE["openrouter"] + + +def test_unknown_provider_returns_none(): + """When neither map matches we return ``None`` so the caller can let + LiteLLM apply its own provider-integration default (Azure deployment + URL, custom-provider URL, etc.).""" + result = resolve_api_base( + provider="SOMETHING_NEW", + provider_prefix="something_new", + config_api_base=None, + ) + assert result is None + + +def test_empty_string_config_treated_as_missing(): + """The original bug: OpenRouter dynamic configs ship ``api_base=""`` + and downstream call sites use ``if cfg.get("api_base"):`` — empty + strings are falsy in Python but the cascade has to step in anyway.""" + result = resolve_api_base( + provider="OPENROUTER", + provider_prefix="openrouter", + config_api_base="", + ) + assert result == PROVIDER_DEFAULT_API_BASE["openrouter"] + + +def test_whitespace_only_config_treated_as_missing(): + """A config value of ``" "`` is a configuration mistake — treat it + as missing instead of forwarding whitespace to LiteLLM (which would + almost certainly 404).""" + result = resolve_api_base( + provider="OPENROUTER", + provider_prefix="openrouter", + config_api_base=" ", + ) + assert result == PROVIDER_DEFAULT_API_BASE["openrouter"] + + +def test_provider_case_insensitive(): + """Some call sites pass the provider lowercase (DB enum value), others + uppercase (YAML key). 
Both must resolve.""" + upper = resolve_api_base( + provider="DEEPSEEK", provider_prefix="openai", config_api_base=None + ) + lower = resolve_api_base( + provider="deepseek", provider_prefix="openai", config_api_base=None + ) + assert upper == lower == PROVIDER_KEY_DEFAULT_API_BASE["DEEPSEEK"] + + +def test_all_inputs_none_returns_none(): + assert ( + resolve_api_base(provider=None, provider_prefix=None, config_api_base=None) + is None + ) diff --git a/surfsense_backend/tests/unit/services/test_provider_capabilities.py b/surfsense_backend/tests/unit/services/test_provider_capabilities.py new file mode 100644 index 000000000..aac88977f --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_provider_capabilities.py @@ -0,0 +1,244 @@ +"""Unit tests for the shared chat-image capability resolver. + +Two resolvers, two intents: + +- ``derive_supports_image_input`` — best-effort True for the catalog and + selector. Default-allow on unknown / unmapped models. The streaming + task safety net never sees this value directly. + +- ``is_known_text_only_chat_model`` — strict opt-out for the safety net. + Returns True only when LiteLLM's model map *explicitly* sets + ``supports_vision=False``. Anything else (missing key, exception, + True) returns False so the request flows through to the provider. +""" + +from __future__ import annotations + +import pytest + +from app.services.provider_capabilities import ( + derive_supports_image_input, + is_known_text_only_chat_model, +) + +pytestmark = pytest.mark.unit + + +# --------------------------------------------------------------------------- +# derive_supports_image_input — OpenRouter modalities path (authoritative) +# --------------------------------------------------------------------------- + + +def test_or_modalities_with_image_returns_true(): + assert ( + derive_supports_image_input( + provider="OPENROUTER", + model_name="openai/gpt-4o", + openrouter_input_modalities=["text", "image"], + ) + is True + ) + + +def test_or_modalities_text_only_returns_false(): + assert ( + derive_supports_image_input( + provider="OPENROUTER", + model_name="deepseek/deepseek-v3.2-exp", + openrouter_input_modalities=["text"], + ) + is False + ) + + +def test_or_modalities_empty_list_returns_false(): + """OR explicitly publishing an empty modality list is a definitive + 'no inputs at all' signal — treat as False rather than falling back + to LiteLLM.""" + assert ( + derive_supports_image_input( + provider="OPENROUTER", + model_name="weird/empty-modalities", + openrouter_input_modalities=[], + ) + is False + ) + + +def test_or_modalities_none_falls_through_to_litellm(): + """``None`` (missing key) is *not* a definitive signal — fall through + to LiteLLM. Using ``openai/gpt-4o`` which is in LiteLLM's map.""" + assert ( + derive_supports_image_input( + provider="OPENAI", + model_name="gpt-4o", + openrouter_input_modalities=None, + ) + is True + ) + + +# --------------------------------------------------------------------------- +# derive_supports_image_input — LiteLLM model-map path +# --------------------------------------------------------------------------- + + +def test_litellm_known_vision_model_returns_true(): + assert ( + derive_supports_image_input( + provider="OPENAI", + model_name="gpt-4o", + ) + is True + ) + + +def test_litellm_base_model_wins_over_model_name(): + """Azure-style entries pass model_name=deployment_id and put the + canonical sku in litellm_params.base_model. 
The resolver must + consult base_model first or the deployment id (which LiteLLM + doesn't know) would shadow the real capability.""" + assert ( + derive_supports_image_input( + provider="AZURE_OPENAI", + model_name="my-azure-deployment-id", + base_model="gpt-4o", + ) + is True + ) + + +def test_litellm_unknown_model_default_allows(): + """Default-allow on unknown — the safety net is the actual block.""" + assert ( + derive_supports_image_input( + provider="CUSTOM", + model_name="brand-new-model-x9-unmapped", + custom_provider="brand_new_proxy", + ) + is True + ) + + +def test_litellm_known_text_only_returns_false(): + """A model that LiteLLM explicitly knows is text-only resolves to + False even via the catalog resolver. ``deepseek-chat`` (the + DeepSeek-V3 chat sku) is in the map without supports_vision and + LiteLLM's `supports_vision` returns False.""" + # Sanity: confirm the helper's negative path. We use a small model + # known not to support vision per the map. + result = derive_supports_image_input( + provider="DEEPSEEK", + model_name="deepseek-chat", + ) + # We accept either False (LiteLLM said explicit no) or True + # (default-allow if the entry isn't mapped on this version) — the + # invariant is that the resolver never *raises* on a known-text-only + # provider/model. The behaviour-binding assertion lives in + # ``test_is_known_text_only_chat_model_explicit_false`` below. + assert isinstance(result, bool) + + +# --------------------------------------------------------------------------- +# is_known_text_only_chat_model — strict opt-out semantics +# --------------------------------------------------------------------------- + + +def test_is_known_text_only_returns_false_for_vision_model(): + assert ( + is_known_text_only_chat_model( + provider="OPENAI", + model_name="gpt-4o", + ) + is False + ) + + +def test_is_known_text_only_returns_false_for_unknown_model(): + """Strict opt-out: missing from the map ≠ text-only. The safety net + must NOT fire for an unmapped model — that's the regression we're + fixing.""" + assert ( + is_known_text_only_chat_model( + provider="CUSTOM", + model_name="brand-new-model-x9-unmapped", + custom_provider="brand_new_proxy", + ) + is False + ) + + +def test_is_known_text_only_returns_false_when_lookup_raises(monkeypatch): + """LiteLLM's ``get_model_info`` raises freely on parse errors. The + helper swallows the exception and returns False so the safety net + doesn't fire on a transient lookup failure.""" + import app.services.provider_capabilities as pc + + def _raise(**_kwargs): + raise ValueError("intentional test failure") + + monkeypatch.setattr(pc.litellm, "get_model_info", _raise) + + assert ( + is_known_text_only_chat_model( + provider="OPENAI", + model_name="gpt-4o", + ) + is False + ) + + +def test_is_known_text_only_returns_true_on_explicit_false(monkeypatch): + """Stub LiteLLM's ``get_model_info`` to return an explicit False so + we exercise the opt-out path deterministically. 
Using a stub keeps + the test stable across LiteLLM map updates.""" + import app.services.provider_capabilities as pc + + def _info(**_kwargs): + return {"supports_vision": False, "max_input_tokens": 8192} + + monkeypatch.setattr(pc.litellm, "get_model_info", _info) + + assert ( + is_known_text_only_chat_model( + provider="OPENAI", + model_name="any-model", + ) + is True + ) + + +def test_is_known_text_only_returns_false_on_supports_vision_true(monkeypatch): + import app.services.provider_capabilities as pc + + def _info(**_kwargs): + return {"supports_vision": True} + + monkeypatch.setattr(pc.litellm, "get_model_info", _info) + + assert ( + is_known_text_only_chat_model( + provider="OPENAI", + model_name="any-model", + ) + is False + ) + + +def test_is_known_text_only_returns_false_on_missing_key(monkeypatch): + """A model entry without ``supports_vision`` at all is treated as + 'unknown' — strict opt-out means False.""" + import app.services.provider_capabilities as pc + + def _info(**_kwargs): + return {"max_input_tokens": 8192} # no supports_vision + + monkeypatch.setattr(pc.litellm, "get_model_info", _info) + + assert ( + is_known_text_only_chat_model( + provider="OPENAI", + model_name="any-model", + ) + is False + ) diff --git a/surfsense_backend/tests/unit/services/test_supports_image_input.py b/surfsense_backend/tests/unit/services/test_supports_image_input.py new file mode 100644 index 000000000..71fdee1c7 --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_supports_image_input.py @@ -0,0 +1,281 @@ +"""Unit tests for the chat-catalog ``supports_image_input`` capability flag. + +Capability is sourced from two places, in order of preference: + +1. ``architecture.input_modalities`` for dynamic OpenRouter chat configs + (authoritative — OpenRouter publishes per-model modalities directly). +2. LiteLLM's authoritative model map (``litellm.supports_vision``) for + YAML / BYOK configs that don't carry an explicit operator override. + +The catalog default is *True* (conservative-allow): an unknown / unmapped +model is not pre-judged. The streaming-task safety net +(``is_known_text_only_chat_model``) is the only place a False actually +blocks a request — and it requires LiteLLM to *explicitly* mark the model +as text-only. 
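+
+For reference, the (trimmed) OpenRouter entry shape the first source
+consumes, reduced to the fields these tests exercise::
+
+    {
+        "id": "openai/gpt-4o",
+        "architecture": {"input_modalities": ["text", "image"]},
+    }
+    # -> supports_image_input True; ["text"] alone resolves to False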
+""" + +from __future__ import annotations + +import pytest + +from app.services.openrouter_integration_service import ( + _OPENROUTER_DYNAMIC_MARKER, + _generate_configs, + _supports_image_input, +) + +pytestmark = pytest.mark.unit + + +_SETTINGS_BASE: dict = { + "api_key": "sk-or-test", + "id_offset": -10_000, + "rpm": 200, + "tpm": 1_000_000, + "free_rpm": 20, + "free_tpm": 100_000, + "anonymous_enabled_paid": False, + "anonymous_enabled_free": True, + "quota_reserve_tokens": 4000, +} + + +# --------------------------------------------------------------------------- +# _supports_image_input helper (OpenRouter modalities) +# --------------------------------------------------------------------------- + + +def test_supports_image_input_true_for_multimodal(): + assert ( + _supports_image_input( + { + "id": "openai/gpt-4o", + "architecture": { + "input_modalities": ["text", "image"], + "output_modalities": ["text"], + }, + } + ) + is True + ) + + +def test_supports_image_input_false_for_text_only(): + """The exact failure mode the safety net guards against — DeepSeek V3 + is a text-in/text-out model and would 404 if forwarded image_url.""" + assert ( + _supports_image_input( + { + "id": "deepseek/deepseek-v3.2-exp", + "architecture": { + "input_modalities": ["text"], + "output_modalities": ["text"], + }, + } + ) + is False + ) + + +def test_supports_image_input_false_when_modalities_missing(): + """Defensive: missing architecture is treated as text-only at the + OpenRouter helper level. The wider catalog resolver + (`derive_supports_image_input`) only consults modalities when they + are non-empty, otherwise it falls back to LiteLLM.""" + assert _supports_image_input({"id": "weird/model"}) is False + assert _supports_image_input({"id": "weird/model", "architecture": {}}) is False + assert ( + _supports_image_input( + {"id": "weird/model", "architecture": {"input_modalities": None}} + ) + is False + ) + + +# --------------------------------------------------------------------------- +# _generate_configs threads the flag onto every emitted chat config +# --------------------------------------------------------------------------- + + +def test_generate_configs_emits_supports_image_input(): + raw = [ + { + "id": "openai/gpt-4o", + "architecture": { + "input_modalities": ["text", "image"], + "output_modalities": ["text"], + }, + "supported_parameters": ["tools"], + "context_length": 200_000, + "pricing": {"prompt": "0.000005", "completion": "0.000015"}, + }, + { + "id": "deepseek/deepseek-v3.2-exp", + "architecture": { + "input_modalities": ["text"], + "output_modalities": ["text"], + }, + "supported_parameters": ["tools"], + "context_length": 200_000, + "pricing": {"prompt": "0.000003", "completion": "0.000015"}, + }, + ] + cfgs = _generate_configs(raw, dict(_SETTINGS_BASE)) + by_model = {c["model_name"]: c for c in cfgs} + + gpt = by_model["openai/gpt-4o"] + assert gpt["supports_image_input"] is True + assert gpt[_OPENROUTER_DYNAMIC_MARKER] is True + + deepseek = by_model["deepseek/deepseek-v3.2-exp"] + assert deepseek["supports_image_input"] is False + assert deepseek[_OPENROUTER_DYNAMIC_MARKER] is True + + +# --------------------------------------------------------------------------- +# YAML loader: defer to derive_supports_image_input on unannotated entries +# --------------------------------------------------------------------------- + + +def test_yaml_loader_resolves_unannotated_vision_model_to_true(tmp_path, monkeypatch): + """The regression case: an Azure GPT-5.x YAML entry without a + 
``supports_image_input`` override should resolve to True via LiteLLM's + model map (which says ``supports_vision: true``). Previously this + defaulted to False, blocking every image turn for vision-capable + YAML configs.""" + yaml_dir = tmp_path / "app" / "config" + yaml_dir.mkdir(parents=True) + (yaml_dir / "global_llm_config.yaml").write_text( + """ +global_llm_configs: + - id: -2 + name: Azure GPT-4o + provider: AZURE_OPENAI + model_name: gpt-4o + api_key: sk-test +""", + encoding="utf-8", + ) + + from app import config as config_module + + monkeypatch.setattr(config_module, "BASE_DIR", tmp_path) + + configs = config_module.load_global_llm_configs() + assert len(configs) == 1 + assert configs[0]["supports_image_input"] is True + + +def test_yaml_loader_respects_explicit_supports_image_input(tmp_path, monkeypatch): + yaml_dir = tmp_path / "app" / "config" + yaml_dir.mkdir(parents=True) + (yaml_dir / "global_llm_config.yaml").write_text( + """ +global_llm_configs: + - id: -1 + name: GPT-4o + provider: OPENAI + model_name: gpt-4o + api_key: sk-test + supports_image_input: false +""", + encoding="utf-8", + ) + + from app import config as config_module + + monkeypatch.setattr(config_module, "BASE_DIR", tmp_path) + + configs = config_module.load_global_llm_configs() + assert len(configs) == 1 + # Operator override always wins, even against LiteLLM's True. + assert configs[0]["supports_image_input"] is False + + +def test_yaml_loader_unknown_model_default_allows(tmp_path, monkeypatch): + """Unknown / unmapped model in YAML: default-allow. The streaming + safety net (which requires an explicit-False from LiteLLM) is the + only place a real block happens, so we don't lock the user out of + a freshly added third-party entry the catalog can't introspect.""" + yaml_dir = tmp_path / "app" / "config" + yaml_dir.mkdir(parents=True) + (yaml_dir / "global_llm_config.yaml").write_text( + """ +global_llm_configs: + - id: -1 + name: Some Brand New Model + provider: CUSTOM + custom_provider: brand_new_proxy + model_name: brand-new-model-x9 + api_key: sk-test +""", + encoding="utf-8", + ) + + from app import config as config_module + + monkeypatch.setattr(config_module, "BASE_DIR", tmp_path) + + configs = config_module.load_global_llm_configs() + assert len(configs) == 1 + assert configs[0]["supports_image_input"] is True + + +# --------------------------------------------------------------------------- +# AgentConfig threads the flag through both YAML and Auto / BYOK +# --------------------------------------------------------------------------- + + +def test_agent_config_from_yaml_explicit_overrides_resolver(): + from app.agents.new_chat.llm_config import AgentConfig + + cfg_text_only = AgentConfig.from_yaml_config( + { + "id": -1, + "name": "Text Only Override", + "provider": "openai", + "model_name": "gpt-4o", # Capable per LiteLLM, but operator says no. 
+ "api_key": "sk-test", + "supports_image_input": False, + } + ) + cfg_explicit_vision = AgentConfig.from_yaml_config( + { + "id": -2, + "name": "GPT-4o", + "provider": "openai", + "model_name": "gpt-4o", + "api_key": "sk-test", + "supports_image_input": True, + } + ) + assert cfg_text_only.supports_image_input is False + assert cfg_explicit_vision.supports_image_input is True + + +def test_agent_config_from_yaml_unannotated_uses_resolver(): + """Without an explicit YAML key, AgentConfig defers to the catalog + resolver — for ``gpt-4o`` LiteLLM's map says supports_vision=True.""" + from app.agents.new_chat.llm_config import AgentConfig + + cfg = AgentConfig.from_yaml_config( + { + "id": -1, + "name": "GPT-4o (no override)", + "provider": "openai", + "model_name": "gpt-4o", + "api_key": "sk-test", + } + ) + assert cfg.supports_image_input is True + + +def test_agent_config_auto_mode_supports_image_input(): + """Auto routes across the pool. We optimistically allow image input + so users can keep their selection on Auto with a vision-capable + deployment somewhere in the pool. The router's own `allowed_fails` + handles non-vision deployments via fallback.""" + from app.agents.new_chat.llm_config import AgentConfig + + auto = AgentConfig.from_auto_mode() + assert auto.supports_image_input is True diff --git a/surfsense_backend/tests/unit/services/test_vision_llm_api_base_defense.py b/surfsense_backend/tests/unit/services/test_vision_llm_api_base_defense.py new file mode 100644 index 000000000..b8ba9d80c --- /dev/null +++ b/surfsense_backend/tests/unit/services/test_vision_llm_api_base_defense.py @@ -0,0 +1,89 @@ +"""Defense-in-depth: vision-LLM resolution must not leak ``api_base`` +defaults from ``litellm.api_base`` either. + +Vision shares the same shape as image-gen — global YAML / OpenRouter +dynamic configs ship ``api_base=""`` and the pre-fix ``get_vision_llm`` +call sites would silently drop the empty string and inherit +``AZURE_OPENAI_ENDPOINT``. ``ChatLiteLLM(...)`` doesn't 404 on +construction so we test the kwargs we hand to it instead. 
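+
+The resolution rule these tests pin down, as a sketch (illustrative only;
+the real logic lives in the vision router / llm_config plumbing and may be
+shaped differently)::
+
+    api_base = (cfg.get("api_base") or "").strip() or None
+    if api_base is None and provider == "OPENROUTER":
+        # Never let an empty string fall through to litellm.api_base /
+        # AZURE_OPENAI_ENDPOINT defaults.
+        api_base = "https://openrouter.ai/api/v1"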
+""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +pytestmark = pytest.mark.unit + + +@pytest.mark.asyncio +async def test_get_vision_llm_global_openrouter_sets_api_base(): + """Global negative-ID branch: an OpenRouter vision config with + ``api_base=""`` must end up calling ``SanitizedChatLiteLLM`` with + ``api_base="https://openrouter.ai/api/v1"`` — never an empty string, + never silently absent.""" + from app.services import llm_service + + cfg = { + "id": -30_001, + "name": "GPT-4o Vision (OpenRouter)", + "provider": "OPENROUTER", + "model_name": "openai/gpt-4o", + "api_key": "sk-or-test", + "api_base": "", + "api_version": None, + "litellm_params": {}, + "billing_tier": "free", + } + + search_space = MagicMock() + search_space.id = 1 + search_space.user_id = "user-x" + search_space.vision_llm_config_id = cfg["id"] + + session = AsyncMock() + scalars = MagicMock() + scalars.first.return_value = search_space + result = MagicMock() + result.scalars.return_value = scalars + session.execute.return_value = result + + captured: dict = {} + + class FakeSanitized: + def __init__(self, **kwargs): + captured.update(kwargs) + + with ( + patch( + "app.services.vision_llm_router_service.get_global_vision_llm_config", + return_value=cfg, + ), + patch( + "app.agents.new_chat.llm_config.SanitizedChatLiteLLM", + new=FakeSanitized, + ), + ): + await llm_service.get_vision_llm(session=session, search_space_id=1) + + assert captured.get("api_base") == "https://openrouter.ai/api/v1" + assert captured["model"] == "openrouter/openai/gpt-4o" + + +def test_vision_router_deployment_sets_api_base_when_config_empty(): + """Auto-mode vision router: deployments are fed to ``litellm.Router``, + so the resolver has to apply at deployment construction time too.""" + from app.services.vision_llm_router_service import VisionLLMRouterService + + deployment = VisionLLMRouterService._config_to_deployment( + { + "model_name": "openai/gpt-4o", + "provider": "OPENROUTER", + "api_key": "sk-or-test", + "api_base": "", + } + ) + assert deployment is not None + assert deployment["litellm_params"]["api_base"] == "https://openrouter.ai/api/v1" + assert deployment["litellm_params"]["model"] == "openrouter/openai/gpt-4o" diff --git a/surfsense_backend/tests/unit/tasks/test_celery_async_runner.py b/surfsense_backend/tests/unit/tasks/test_celery_async_runner.py new file mode 100644 index 000000000..a5bb3f58a --- /dev/null +++ b/surfsense_backend/tests/unit/tasks/test_celery_async_runner.py @@ -0,0 +1,318 @@ +"""Regression tests for ``run_async_celery_task``. + +These tests pin down the production bug observed on 2026-05-02 where +the video-presentation Celery task hung at ``[billable_call] finalize`` +because the shared ``app.db.engine`` had pooled asyncpg connections +bound to a *previous* task's now-closed event loop. Reusing such a +connection on a fresh loop crashes inside ``pool_pre_ping`` with:: + + AttributeError: 'NoneType' object has no attribute 'send' + +(the proactor is None because the loop is gone) and can hang forever +inside the asyncpg ``Connection._cancel`` cleanup coroutine. + +The fix is ``run_async_celery_task``: a small helper that runs every +async celery task body inside a fresh event loop and disposes the +shared engine pool both before (defends against a previous task that +crashed) and after (releases connections we opened on this loop). 
+
+Tests here exercise the helper with a stub engine that records each
+``dispose()`` call together with the loop it ran on, so the tests can
+assert the helper disposes the pool on every fresh loop it creates.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import gc
+import sys
+from collections.abc import Iterator
+from contextlib import contextmanager
+from unittest.mock import patch
+
+import pytest
+
+pytestmark = pytest.mark.unit
+
+
+# ---------------------------------------------------------------------------
+# Stub engine that emulates the asyncpg-on-stale-loop crash
+# ---------------------------------------------------------------------------
+
+
+class _StaleLoopEngine:
+    """Tiny stand-in for ``app.db.engine`` that tracks dispose() calls.
+
+    ``dispose()`` is async (matches ``AsyncEngine.dispose``) and records
+    the running event loop id so tests can assert it ran on *each*
+    fresh loop.
+    """
+
+    def __init__(self) -> None:
+        self.dispose_loop_ids: list[int] = []
+
+    async def dispose(self) -> None:
+        loop = asyncio.get_running_loop()
+        self.dispose_loop_ids.append(id(loop))
+
+
+@contextmanager
+def _patch_shared_engine(stub: _StaleLoopEngine) -> Iterator[None]:
+    """Patch the ``from app.db import engine as shared_engine`` lookup.
+
+    The helper imports lazily inside the function body, so we have to
+    patch the attribute on the already-loaded ``app.db`` module.
+    """
+    import app.db as app_db
+
+    original = getattr(app_db, "engine", None)
+    app_db.engine = stub  # type: ignore[attr-defined]
+    try:
+        yield
+    finally:
+        if original is None:
+            # The attribute did not exist before patching; remove the stub
+            # instead of leaving it behind for later tests.
+            if hasattr(app_db, "engine"):
+                delattr(app_db, "engine")
+        else:
+            app_db.engine = original  # type: ignore[attr-defined]
+
+
+# ---------------------------------------------------------------------------
+# Tests
+# ---------------------------------------------------------------------------
+
+
+def test_runner_returns_value_and_disposes_engine_around_call() -> None:
+    """Happy path: the coroutine result is returned, and the shared
+    engine is disposed both before and after the task body runs.
+    """
+    from app.tasks.celery_tasks import run_async_celery_task
+
+    stub = _StaleLoopEngine()
+
+    async def _body() -> str:
+        # Engine should already have been disposed once before we run.
+        assert len(stub.dispose_loop_ids) == 1
+        return "ok"
+
+    with _patch_shared_engine(stub):
+        result = run_async_celery_task(_body)
+
+    assert result == "ok"
+    # Once before the body, once after (in finally).
+    assert len(stub.dispose_loop_ids) == 2
+    # Both disposes ran on the SAME (fresh) loop the task body used.
+    assert stub.dispose_loop_ids[0] == stub.dispose_loop_ids[1]
+
+
+def test_runner_creates_fresh_loop_per_invocation() -> None:
+    """Each call must spin its own loop. Without this guarantee a
+    previous task's loop would be reused and the asyncpg-stale-loop
+    crash would never be avoided.
+    """
+    import app.tasks.celery_tasks as celery_tasks_pkg
+
+    stub = _StaleLoopEngine()
+    new_loop_calls = 0
+    closed_loops: list[bool] = []
+
+    real_new_event_loop = asyncio.new_event_loop
+
+    def _counting_new_loop() -> asyncio.AbstractEventLoop:
+        nonlocal new_loop_calls
+        new_loop_calls += 1
+        loop = real_new_event_loop()
+        # Hook close() so we can verify each loop was closed properly
+        # before the next one was created.
+ original_close = loop.close + + def _tracked_close() -> None: + closed_loops.append(True) + original_close() + + loop.close = _tracked_close # type: ignore[method-assign] + return loop + + async def _body() -> None: + # Loop is alive and current at body execution time. + running = asyncio.get_running_loop() + assert not running.is_closed() + + with ( + _patch_shared_engine(stub), + patch.object(asyncio, "new_event_loop", _counting_new_loop), + ): + for _ in range(3): + celery_tasks_pkg.run_async_celery_task(_body) + + assert new_loop_calls == 3 + assert closed_loops == [True, True, True] + # Each invocation disposed twice (before + after). + assert len(stub.dispose_loop_ids) == 6 + + +def test_runner_disposes_engine_even_when_body_raises() -> None: + """Cleanup MUST run on the failure path too — otherwise stale + connections leak into the next task and cause the original hang. + """ + from app.tasks.celery_tasks import run_async_celery_task + + stub = _StaleLoopEngine() + + class _BoomError(RuntimeError): + pass + + async def _body() -> None: + raise _BoomError("kaboom") + + with _patch_shared_engine(stub), pytest.raises(_BoomError): + run_async_celery_task(_body) + + assert len(stub.dispose_loop_ids) == 2 # before + after still ran + + +def test_runner_swallows_dispose_errors() -> None: + """A flaky engine.dispose() must NEVER take down a celery task. + + Production scenario: the very first dispose (before the body runs) + might hit a partially-initialised engine; the helper logs and + moves on. The task body still runs; the result is still returned. + """ + from app.tasks.celery_tasks import run_async_celery_task + + class _AngryEngine: + def __init__(self) -> None: + self.calls = 0 + + async def dispose(self) -> None: + self.calls += 1 + raise RuntimeError("dispose() blew up") + + stub = _AngryEngine() + + async def _body() -> int: + return 42 + + with _patch_shared_engine(stub): + assert run_async_celery_task(_body) == 42 + + assert stub.calls == 2 # before + after both attempted + + +def test_runner_propagates_value_from_async_body() -> None: + """Sanity: pass-through of any pickleable celery return value.""" + from app.tasks.celery_tasks import run_async_celery_task + + stub = _StaleLoopEngine() + + async def _body() -> dict[str, object]: + return {"status": "ready", "video_presentation_id": 19} + + with _patch_shared_engine(stub): + out = run_async_celery_task(_body) + + assert out == {"status": "ready", "video_presentation_id": 19} + + +def test_video_presentation_task_uses_runner_helper() -> None: + """Defence-in-depth: confirm the celery task module imports + ``run_async_celery_task``. If a future refactor inlines a + ``loop = asyncio.new_event_loop(); ... loop.close()`` block again, + the original hang will return. + """ + # The module's task body should not contain a manual new_event_loop + # call — that's exactly what the helper exists to centralise. + import inspect + + from app.tasks.celery_tasks import video_presentation_tasks + + src = inspect.getsource(video_presentation_tasks) + assert "run_async_celery_task" in src, ( + "video_presentation_tasks.py must use run_async_celery_task; " + "manual asyncio.new_event_loop() in a celery task hangs on the " + "shared SQLAlchemy pool when reused across tasks." + ) + assert "asyncio.new_event_loop" not in src, ( + "video_presentation_tasks.py contains a raw asyncio.new_event_loop " + "call — route every async task through run_async_celery_task to " + "avoid the stale-pool hang." 
+ ) + + +def test_podcast_task_uses_runner_helper() -> None: + """Symmetric assertion for the podcast task — same root cause, same + fix, same regression risk. + """ + import inspect + + from app.tasks.celery_tasks import podcast_tasks + + src = inspect.getsource(podcast_tasks) + assert "run_async_celery_task" in src + assert "asyncio.new_event_loop" not in src + + +def test_runner_runs_shutdown_asyncgens_before_close() -> None: + """If the task body created any async generators that didn't get + fully iterated, we must still call ``loop.shutdown_asyncgens()`` + before closing — otherwise we leak event-loop bound resources + that re-emerge as ``RuntimeError: Event loop is closed`` later. + """ + from app.tasks.celery_tasks import run_async_celery_task + + stub = _StaleLoopEngine() + + async def _agen(): + try: + yield 1 + yield 2 + finally: + pass + + async def _body() -> None: + # Iterate the agen partially, then leave it dangling — exactly + # the situation shutdown_asyncgens() is designed to clean up. + async for v in _agen(): + if v == 1: + break + + with _patch_shared_engine(stub): + run_async_celery_task(_body) + + # By the time the helper returns, garbage collection + shutdown_asyncgens + # should have ensured no live async-gen references remain. We don't + # assert agen.closed directly (it depends on GC ordering); the real + # contract is "no warnings, no event-loop-closed errors". A successful + # second invocation proves the loop was cleaned up properly. + with _patch_shared_engine(stub): + run_async_celery_task(_body) + + # Force a GC pass to surface any 'coroutine was never awaited' + # warnings that would indicate the cleanup is broken. + gc.collect() + + +def test_runner_uses_proactor_loop_on_windows() -> None: + """On Windows the celery worker preselects a Proactor policy so + subprocess (ffmpeg) calls work. The helper must not silently fall + back to a Selector loop and re-break video/podcast generation. + """ + if not sys.platform.startswith("win"): + pytest.skip("Windows-specific event-loop policy assertion") + + from app.tasks.celery_tasks import run_async_celery_task + + stub = _StaleLoopEngine() + + # Mirror the policy set at the top of every Windows celery task. 
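+    # (asyncio.new_event_loop() delegates to the active policy, and the
+    # fresh-loop test above patches asyncio.new_event_loop() to count calls,
+    # so a Proactor policy set here must surface as a ProactorEventLoop
+    # inside the task body.)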
+ asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy()) + + observed: list[str] = [] + + async def _body() -> None: + observed.append(type(asyncio.get_running_loop()).__name__) + + with _patch_shared_engine(stub): + run_async_celery_task(_body) + + assert observed == ["ProactorEventLoop"] diff --git a/surfsense_backend/tests/unit/tasks/test_podcast_billing.py b/surfsense_backend/tests/unit/tasks/test_podcast_billing.py index 38d6ba2ca..699297df1 100644 --- a/surfsense_backend/tests/unit/tasks/test_podcast_billing.py +++ b/surfsense_backend/tests/unit/tasks/test_podcast_billing.py @@ -113,6 +113,19 @@ async def _denying_billable_call(**kwargs): yield SimpleNamespace() # pragma: no cover — for grammar only +@contextlib.asynccontextmanager +async def _settlement_failing_billable_call(**kwargs): + from app.services.billable_calls import BillingSettlementError + + _CALL_LOG.append(kwargs) + yield SimpleNamespace() + raise BillingSettlementError( + usage_type=kwargs.get("usage_type", "?"), + user_id=kwargs["user_id"], + cause=RuntimeError("finalize failed"), + ) + + # --------------------------------------------------------------------------- # Tests # --------------------------------------------------------------------------- @@ -187,8 +200,15 @@ async def test_billable_call_invoked_with_correct_kwargs_for_free_config(monkeyp call["quota_reserve_micros_override"] == app_config.QUOTA_DEFAULT_PODCAST_RESERVE_MICROS ) - assert call["thread_id"] == 99 - assert call["call_details"] == {"podcast_id": 7, "title": "Test Podcast"} + # Background artifact audit rows intentionally omit the TokenUsage.thread_id + # FK to avoid coupling Celery audit commits to an active chat transaction. + assert "thread_id" not in call + assert call["call_details"] == { + "podcast_id": 7, + "title": "Test Podcast", + "thread_id": 99, + } + assert callable(call["billable_session_factory"]) @pytest.mark.asyncio @@ -279,6 +299,49 @@ async def test_quota_insufficient_marks_podcast_failed_and_skips_graph(monkeypat assert graph_invoked == [] # Graph never ran on denied reservation. +@pytest.mark.asyncio +async def test_billing_settlement_failure_marks_podcast_failed(monkeypatch): + from app.db import PodcastStatus + from app.tasks.celery_tasks import podcast_tasks + + podcast = _make_podcast(podcast_id=10) + session = _FakeSession(podcast) + monkeypatch.setattr( + podcast_tasks, + "get_celery_session_maker", + lambda: _FakeSessionMaker(session), + ) + + async def _fake_resolver(sess, search_space_id, *, thread_id=None): + return uuid4(), "premium", "gpt-5.4" + + monkeypatch.setattr( + podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver + ) + monkeypatch.setattr( + podcast_tasks, "billable_call", _settlement_failing_billable_call + ) + + async def _fake_graph_invoke(state, config): + return {"podcast_transcript": [], "final_podcast_file_path": "x.wav"} + + monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke) + + result = await podcast_tasks._generate_content_podcast( + podcast_id=10, + source_content="hi", + search_space_id=555, + user_prompt=None, + ) + + assert result == { + "status": "failed", + "podcast_id": 10, + "reason": "billing_settlement_failed", + } + assert podcast.status == PodcastStatus.FAILED + + @pytest.mark.asyncio async def test_resolver_failure_marks_podcast_failed(monkeypatch): """If the resolver raises (e.g. 
search-space deleted), the task fails diff --git a/surfsense_backend/tests/unit/tasks/test_stream_new_chat_image_safety_net.py b/surfsense_backend/tests/unit/tasks/test_stream_new_chat_image_safety_net.py new file mode 100644 index 000000000..792d059b0 --- /dev/null +++ b/surfsense_backend/tests/unit/tasks/test_stream_new_chat_image_safety_net.py @@ -0,0 +1,119 @@ +"""Predicate-level test for the chat streaming safety net. + +The safety net in ``stream_new_chat`` rejects an image turn early with +a friendly ``MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT`` SSE error when the +selected model is *known* to be text-only. The earlier round of this +work used a strict opt-in flag (``supports_image_input`` defaulting to +False on every YAML entry) which blocked vision-capable Azure GPT-5.x +deployments — this is the regression we're fixing. + +The new predicate is :func:`is_known_text_only_chat_model`, which +returns True only when LiteLLM's authoritative model map *explicitly* +sets ``supports_vision=False``. Anything else (vision True, missing +key, exception) returns False so the request flows through to the +provider. + +We exercise the predicate directly here rather than driving the full +``stream_new_chat`` generator — covering the gate in isolation keeps +the test focused on the regression while the generator's wider behavior +is exercised by the integration suite. +""" + +from __future__ import annotations + +import pytest + +from app.services.provider_capabilities import is_known_text_only_chat_model + +pytestmark = pytest.mark.unit + + +def test_safety_net_does_not_fire_for_azure_gpt_4o(): + """Regression: ``azure/gpt-4o`` (and the GPT-5.x variants) is + vision-capable per LiteLLM's model map. The previous round's + blanket-False default blocked it; the new predicate must NOT mark + it text-only.""" + assert ( + is_known_text_only_chat_model( + provider="AZURE_OPENAI", + model_name="my-azure-deployment", + base_model="gpt-4o", + ) + is False + ) + + +def test_safety_net_does_not_fire_for_unknown_model(): + """Default-pass on unknown — the safety net only blocks definitive + text-only confirmations. A freshly added third-party model that + LiteLLM doesn't know about must flow through to the provider.""" + assert ( + is_known_text_only_chat_model( + provider="CUSTOM", + custom_provider="brand_new_proxy", + model_name="brand-new-model-x9", + ) + is False + ) + + +def test_safety_net_does_not_fire_when_lookup_raises(monkeypatch): + """Transient ``litellm.get_model_info`` exception ≠ block. The + helper swallows the error and treats it as 'unknown' → False.""" + import app.services.provider_capabilities as pc + + def _raise(**_kwargs): + raise RuntimeError("intentional test failure") + + monkeypatch.setattr(pc.litellm, "get_model_info", _raise) + + assert ( + is_known_text_only_chat_model( + provider="OPENAI", + model_name="gpt-4o", + ) + is False + ) + + +def test_safety_net_fires_only_on_explicit_false(monkeypatch): + """Stub LiteLLM to assert the only path that returns True is the + explicit ``supports_vision=False`` case. 
Anything else (True, + None, missing key) returns False from the predicate.""" + import app.services.provider_capabilities as pc + + def _info_explicit_false(**_kwargs): + return {"supports_vision": False, "max_input_tokens": 8192} + + monkeypatch.setattr(pc.litellm, "get_model_info", _info_explicit_false) + assert ( + is_known_text_only_chat_model( + provider="OPENAI", + model_name="text-only-stub", + ) + is True + ) + + def _info_true(**_kwargs): + return {"supports_vision": True} + + monkeypatch.setattr(pc.litellm, "get_model_info", _info_true) + assert ( + is_known_text_only_chat_model( + provider="OPENAI", + model_name="vision-stub", + ) + is False + ) + + def _info_missing(**_kwargs): + return {"max_input_tokens": 8192} + + monkeypatch.setattr(pc.litellm, "get_model_info", _info_missing) + assert ( + is_known_text_only_chat_model( + provider="OPENAI", + model_name="missing-key-stub", + ) + is False + ) diff --git a/surfsense_backend/tests/unit/tasks/test_video_presentation_billing.py b/surfsense_backend/tests/unit/tasks/test_video_presentation_billing.py index 671f57ae4..423b64ddb 100644 --- a/surfsense_backend/tests/unit/tasks/test_video_presentation_billing.py +++ b/surfsense_backend/tests/unit/tasks/test_video_presentation_billing.py @@ -105,6 +105,19 @@ async def _denying_billable_call(**kwargs): yield SimpleNamespace() # pragma: no cover +@contextlib.asynccontextmanager +async def _settlement_failing_billable_call(**kwargs): + from app.services.billable_calls import BillingSettlementError + + _CALL_LOG.append(kwargs) + yield SimpleNamespace() + raise BillingSettlementError( + usage_type=kwargs.get("usage_type", "?"), + user_id=kwargs["user_id"], + cause=RuntimeError("finalize failed"), + ) + + # --------------------------------------------------------------------------- # Tests # --------------------------------------------------------------------------- @@ -176,11 +189,15 @@ async def test_billable_call_invoked_with_correct_kwargs_for_free_config(monkeyp call["quota_reserve_micros_override"] == app_config.QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS ) - assert call["thread_id"] == 99 + # Background artifact audit rows intentionally omit the TokenUsage.thread_id + # FK to avoid coupling Celery audit commits to an active chat transaction. 
+ assert "thread_id" not in call assert call["call_details"] == { "video_presentation_id": 11, "title": "Test Presentation", + "thread_id": 99, } + assert callable(call["billable_session_factory"]) @pytest.mark.asyncio @@ -280,6 +297,57 @@ async def test_quota_insufficient_marks_video_failed_and_skips_graph(monkeypatch assert graph_invoked == [] +@pytest.mark.asyncio +async def test_billing_settlement_failure_marks_video_failed(monkeypatch): + from app.db import VideoPresentationStatus + from app.tasks.celery_tasks import video_presentation_tasks + + video = _make_video(video_id=14) + session = _FakeSession(video) + monkeypatch.setattr( + video_presentation_tasks, + "get_celery_session_maker", + lambda: _FakeSessionMaker(session), + ) + + async def _fake_resolver(sess, search_space_id, *, thread_id=None): + return uuid4(), "premium", "gpt-5.4" + + monkeypatch.setattr( + video_presentation_tasks, + "_resolve_agent_billing_for_search_space", + _fake_resolver, + ) + monkeypatch.setattr( + video_presentation_tasks, + "billable_call", + _settlement_failing_billable_call, + ) + + async def _fake_graph_invoke(state, config): + return {"slides": [], "slide_audio_results": [], "slide_scene_codes": []} + + monkeypatch.setattr( + video_presentation_tasks.video_presentation_graph, + "ainvoke", + _fake_graph_invoke, + ) + + result = await video_presentation_tasks._generate_video_presentation( + video_presentation_id=14, + source_content="content", + search_space_id=777, + user_prompt=None, + ) + + assert result == { + "status": "failed", + "video_presentation_id": 14, + "reason": "billing_settlement_failed", + } + assert video.status == VideoPresentationStatus.FAILED + + @pytest.mark.asyncio async def test_resolver_failure_marks_video_failed(monkeypatch): from app.db import VideoPresentationStatus diff --git a/surfsense_web/components/assistant-ui/assistant-message.tsx b/surfsense_web/components/assistant-ui/assistant-message.tsx index ffb0e4dc8..3b9d9a526 100644 --- a/surfsense_web/components/assistant-ui/assistant-message.tsx +++ b/surfsense_web/components/assistant-ui/assistant-message.tsx @@ -477,9 +477,7 @@ const MessageInfoDropdown: FC = () => { {counts.total_tokens.toLocaleString()} tokens - {costMicros && costMicros > 0 - ? ` · ${formatTurnCost(costMicros)}` - : ""} + {costMicros && costMicros > 0 ? ` · ${formatTurnCost(costMicros)}` : ""} ); diff --git a/surfsense_web/components/new-chat/model-selector.tsx b/surfsense_web/components/new-chat/model-selector.tsx index 1a0f8c5ba..44f3feb7a 100644 --- a/surfsense_web/components/new-chat/model-selector.tsx +++ b/surfsense_web/components/new-chat/model-selector.tsx @@ -19,6 +19,7 @@ import { import type React from "react"; import { Fragment, useCallback, useEffect, useMemo, useRef, useState } from "react"; import { toast } from "sonner"; +import { pendingUserImageDataUrlsAtom } from "@/atoms/chat/pending-user-images.atom"; import { globalImageGenConfigsAtom, imageGenConfigsAtom, @@ -461,6 +462,18 @@ export function ModelSelector({ const { data: visionUserConfigs, isLoading: visionUserLoading } = useAtomValue(visionLLMConfigsAtom); + // Pending image attachments on the composer. Used to surface an + // amber "No image" hint on chat models the catalog reports as + // non-vision (`supports_image_input=false`) when the next message + // will carry an image. The hint is purely advisory: selection, + // focus, and click handling are unaffected. 
The backend's safety + // net (`is_known_text_only_chat_model`) is the actual block, and + // it only fires when LiteLLM *explicitly* marks a model as + // text-only — so a model that's secretly capable but hasn't been + // annotated will still flow through to the provider. + const pendingUserImageUrls = useAtomValue(pendingUserImageDataUrlsAtom); + const hasPendingImages = pendingUserImageUrls.length > 0; + const isLoading = llmUserLoading || llmGlobalLoading || @@ -984,6 +997,21 @@ export function ModelSelector({ const isSelected = getSelectedId() === config.id; const isFocused = focusedIndex === index; const hasCitations = "citations_enabled" in config && !!config.citations_enabled; + // Chat-tab only: surface an amber "No image" hint when the + // composer carries images and the catalog reports the model as + // non-vision. This is purely advisory — selection is *not* + // blocked. The backend's narrow safety net + // (`is_known_text_only_chat_model`) is the source of truth for + // rejecting image turns, and it only fires when LiteLLM + // explicitly marks the model as text-only. A model surfaced as + // `supports_image_input=false` here may still be capable in + // practice (unknown / unmapped LiteLLM entry), so we let the + // user pick it and the provider response decide. + const isImageIncompatibleChatModel = + activeTab === "llm" && + hasPendingImages && + "supports_image_input" in config && + (config as Record).supports_image_input === false; return (
handleSelectItem(item)} onKeyDown={ isMobile @@ -1005,9 +1038,8 @@ export function ModelSelector({ } onMouseEnter={() => setFocusedIndex(index)} className={cn( - "group flex items-center gap-2.5 px-3 py-2 rounded-xl cursor-pointer", - "transition-all duration-150 mx-2", - "hover:bg-accent/40", + "group flex items-center gap-2.5 px-3 py-2 rounded-xl", + "transition-all duration-150 mx-2 cursor-pointer hover:bg-accent/40", isSelected && "bg-primary/6 dark:bg-primary/8", isFocused && "bg-accent/50" )} @@ -1053,6 +1085,14 @@ export function ModelSelector({ Free ) : null} + {isImageIncompatibleChatModel && ( + + No image + + )}
diff --git a/surfsense_web/components/pricing/pricing-section.tsx b/surfsense_web/components/pricing/pricing-section.tsx index 127b79167..156ef9134 100644 --- a/surfsense_web/components/pricing/pricing-section.tsx +++ b/surfsense_web/components/pricing/pricing-section.tsx @@ -250,8 +250,8 @@ function PricingFAQ() { Frequently Asked Questions

- Everything you need to know about SurfSense pages, premium credit, and billing. - Can't find what you need? Reach out at{" "} + Everything you need to know about SurfSense pages, premium credit, and billing. Can't + find what you need? Reach out at{" "} rohan@surfsense.com diff --git a/surfsense_web/components/settings/image-model-manager.tsx b/surfsense_web/components/settings/image-model-manager.tsx index ced97464e..d4afa698b 100644 --- a/surfsense_web/components/settings/image-model-manager.tsx +++ b/surfsense_web/components/settings/image-model-manager.tsx @@ -22,6 +22,7 @@ import { AlertDialogTitle, } from "@/components/ui/alert-dialog"; import { Avatar, AvatarFallback, AvatarImage } from "@/components/ui/avatar"; +import { Badge } from "@/components/ui/badge"; import { Button } from "@/components/ui/button"; import { Card, CardContent } from "@/components/ui/card"; import { Skeleton } from "@/components/ui/skeleton"; @@ -190,8 +191,7 @@ export function ImageModelManager({ searchSpaceId }: ImageModelManagerProps) { ? "model" : "models"} {" "} - available from your administrator.{" "} - {(() => { + available from your administrator. {(() => { const nonAuto = globalConfigs.filter( (g) => !("is_auto_mode" in g && g.is_auto_mode) ); @@ -214,6 +214,75 @@ export function ImageModelManager({ searchSpaceId }: ImageModelManagerProps) { )} + {/* Global Image Models — read-only cards with per-model Free/Premium + badges. Mirrors the badge palette used by the chat role selector + (`llm-role-manager.tsx`) so the meaning is consistent across + every model-configuration surface (chat / image / vision). */} + {!isLoading && + globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length > 0 && ( +

+

+ Global Image Models +

+
+ {globalConfigs + .filter((g) => !("is_auto_mode" in g && g.is_auto_mode)) + .map((cfg) => { + const billingTier = + ("billing_tier" in cfg && + typeof (cfg as { billing_tier?: string }).billing_tier === "string" && + (cfg as { billing_tier?: string }).billing_tier) || + "free"; + const isPremium = billingTier === "premium"; + return ( + + +
+
+ {getProviderIcon(cfg.provider, { className: "size-4" })} +
+
+

+ {cfg.name} +

+ {isPremium ? ( + + Premium + + ) : ( + + Free + + )} +
+
+ {cfg.description && ( +

+ {cfg.description} +

+ )} +
+ + {cfg.model_name} + +
+
+
+ ); + })} +
+
+ )} + {/* Loading Skeleton */} {isLoading && (
diff --git a/surfsense_web/components/settings/more-pages-content.tsx b/surfsense_web/components/settings/more-pages-content.tsx index 8de61b0c7..5635c3314 100644 --- a/surfsense_web/components/settings/more-pages-content.tsx +++ b/surfsense_web/components/settings/more-pages-content.tsx @@ -70,9 +70,7 @@ export function MorePagesContent() {

Get Free Pages

-

- Earn bonus pages by completing tasks -

+

Earn bonus pages by completing tasks

diff --git a/surfsense_web/components/settings/vision-model-manager.tsx b/surfsense_web/components/settings/vision-model-manager.tsx index 886d71008..34aa531fd 100644 --- a/surfsense_web/components/settings/vision-model-manager.tsx +++ b/surfsense_web/components/settings/vision-model-manager.tsx @@ -22,6 +22,7 @@ import { AlertDialogTitle, } from "@/components/ui/alert-dialog"; import { Avatar, AvatarFallback, AvatarImage } from "@/components/ui/avatar"; +import { Badge } from "@/components/ui/badge"; import { Button } from "@/components/ui/button"; import { Card, CardContent } from "@/components/ui/card"; import { Skeleton } from "@/components/ui/skeleton"; @@ -191,8 +192,7 @@ export function VisionModelManager({ searchSpaceId }: VisionModelManagerProps) { ? "model" : "models"} {" "} - available from your administrator.{" "} - {(() => { + available from your administrator. {(() => { const nonAuto = globalConfigs.filter( (g) => !("is_auto_mode" in g && g.is_auto_mode) ); @@ -215,6 +215,75 @@ export function VisionModelManager({ searchSpaceId }: VisionModelManagerProps) { )} + {/* Global Vision Models — read-only cards with per-model Free/Premium + badges. Mirrors the badge palette used by the chat role selector + (`llm-role-manager.tsx`) so the meaning is consistent across + every model-configuration surface (chat / image / vision). */} + {!isLoading && + globalConfigs.filter((g) => !("is_auto_mode" in g && g.is_auto_mode)).length > 0 && ( +
+

+ Global Vision Models +

+
+ {globalConfigs + .filter((g) => !("is_auto_mode" in g && g.is_auto_mode)) + .map((cfg) => { + const billingTier = + ("billing_tier" in cfg && + typeof (cfg as { billing_tier?: string }).billing_tier === "string" && + (cfg as { billing_tier?: string }).billing_tier) || + "free"; + const isPremium = billingTier === "premium"; + return ( + + +
+
+ {getProviderIcon(cfg.provider, { className: "size-4" })} +
+
+

+ {cfg.name} +

+ {isPremium ? ( + + Premium + + ) : ( + + Free + + )} +
+
+ {cfg.description && ( +

+ {cfg.description} +

+ )} +
+ + {cfg.model_name} + +
+
+
+ ); + })} +
+
+ )} + {isLoading && (
diff --git a/surfsense_web/components/tool-ui/generate-podcast.tsx b/surfsense_web/components/tool-ui/generate-podcast.tsx index 02f53efad..e8fff2873 100644 --- a/surfsense_web/components/tool-ui/generate-podcast.tsx +++ b/surfsense_web/components/tool-ui/generate-podcast.tsx @@ -416,9 +416,19 @@ export const GeneratePodcastToolUI = ({ return ; } - // Already generating - show simple warning, don't create another poller - // The FIRST tool call will display the podcast when ready - // (new: "generating", legacy: "already_generating") + // Pending/generating rows have a stable podcast_id, so the card can poll + // independently while the chat stream finishes. + if ( + (result.status === "pending" || + result.status === "generating" || + result.status === "processing") && + result.podcast_id + ) { + return ; + } + + // Legacy duplicate/no-ID result - show a simple warning, don't create + // another poller. The first tool call will display the podcast when ready. if (result.status === "generating" || result.status === "already_generating") { return (
@@ -432,11 +442,6 @@ export const GeneratePodcastToolUI = ({ ); } - // Pending - poll for completion (new: "pending" with podcast_id) - if (result.status === "pending" && result.podcast_id) { - return ; - } - // Ready with podcast_id (new: "ready", legacy: "success") if ((result.status === "ready" || result.status === "success") && result.podcast_id) { return ; diff --git a/surfsense_web/contexts/login-gate.tsx b/surfsense_web/contexts/login-gate.tsx index 790e5c00e..f72cb3a42 100644 --- a/surfsense_web/contexts/login-gate.tsx +++ b/surfsense_web/contexts/login-gate.tsx @@ -44,8 +44,8 @@ export function LoginGateProvider({ children }: { children: ReactNode }) { Create a free account to {feature} - Get $5 of premium credit, save chat history, upload documents, use all AI tools, - and connect 30+ integrations. + Get $5 of premium credit, save chat history, upload documents, use all AI tools, and + connect 30+ integrations. diff --git a/surfsense_web/contracts/types/new-llm-config.types.ts b/surfsense_web/contracts/types/new-llm-config.types.ts index 2d6b70eda..b52b98ae4 100644 --- a/surfsense_web/contracts/types/new-llm-config.types.ts +++ b/surfsense_web/contracts/types/new-llm-config.types.ts @@ -65,6 +65,13 @@ export const newLLMConfig = z.object({ created_at: z.string(), search_space_id: z.number(), user_id: z.string(), + + // Capability flag — derived server-side at the route boundary from + // LiteLLM's authoritative model map. There is no DB column. Default + // `true` is the conservative-allow stance for unknown / unmapped + // BYOK rows; the streaming-task safety net is the only place a + // `false` actually blocks a request. + supports_image_input: z.boolean().default(true), }); /** @@ -74,11 +81,16 @@ export const newLLMConfigPublic = newLLMConfig.omit({ api_key: true }); /** * Create NewLLMConfig + * + * `supports_image_input` is omitted because it is derived server-side + * from LiteLLM's model map at read time — there is no DB column to + * persist a client-supplied value into. */ export const createNewLLMConfigRequest = newLLMConfig.omit({ id: true, created_at: true, user_id: true, + supports_image_input: true, }); export const createNewLLMConfigResponse = newLLMConfig; @@ -114,6 +126,8 @@ export const updateNewLLMConfigRequest = z.object({ created_at: true, search_space_id: true, user_id: true, + // Derived server-side; not part of the writable surface. + supports_image_input: true, }) .partial(), }); @@ -172,6 +186,16 @@ export const globalNewLLMConfig = z.object({ seo_title: z.string().nullable().optional(), seo_description: z.string().nullable().optional(), quota_reserve_tokens: z.number().nullable().optional(), + // Capability flag — true when the model can accept image inputs. + // Resolved server-side (OpenRouter dynamic configs use the OR + // `architecture.input_modalities` field; YAML / BYOK use LiteLLM's + // authoritative `supports_vision` map). The chat selector renders + // an amber "No image" hint when this is false and there are + // pending image attachments, but does not block selection — the + // backend safety net only rejects when LiteLLM *explicitly* marks + // the model as text-only, so unknown / new models still flow + // through. Default `true` matches that conservative-allow stance. 
+ supports_image_input: z.boolean().default(true), }); export const getGlobalNewLLMConfigsResponse = z.array(globalNewLLMConfig); @@ -259,6 +283,9 @@ export const globalImageGenConfig = z.object({ is_global: z.literal(true), is_auto_mode: z.boolean().optional().default(false), billing_tier: z.string().default("free"), + // Mirrors `globalNewLLMConfig.is_premium` so the new-chat selector's + // Free/Premium badge logic lights up automatically for image-gen too. + is_premium: z.boolean().default(false), quota_reserve_micros: z.number().nullable().optional(), }); @@ -341,6 +368,9 @@ export const globalVisionLLMConfig = z.object({ is_global: z.literal(true), is_auto_mode: z.boolean().optional().default(false), billing_tier: z.string().default("free"), + // Mirrors `globalNewLLMConfig.is_premium` so the new-chat selector's + // Free/Premium badge logic lights up automatically for vision too. + is_premium: z.boolean().default(false), quota_reserve_tokens: z.number().nullable().optional(), input_cost_per_token: z.number().nullable().optional(), output_cost_per_token: z.number().nullable().optional(), diff --git a/surfsense_web/next.config.ts b/surfsense_web/next.config.ts index 5414d548d..6cfcb5187 100644 --- a/surfsense_web/next.config.ts +++ b/surfsense_web/next.config.ts @@ -18,6 +18,12 @@ const nextConfig: NextConfig = { }, images: { remotePatterns: [ + { + protocol: "http", + hostname: "localhost", + port: "8000", + pathname: "/api/v1/image-generations/**", + }, { protocol: "https", hostname: "**",