mirror of https://github.com/MODSetter/SurfSense.git
synced 2026-05-04 05:12:38 +02:00
feat: fixed vision/image provider-specific errors and fixed podcast/video streaming
Some checks are pending
Build and Push Docker Images / tag_release (push) Waiting to run
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (backend, surfsense-backend) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (web, surfsense-web) (push) Blocked by required conditions
This commit is contained in:
parent ae9d36d77f
commit 47b2994ec7
54 changed files with 4469 additions and 563 deletions
@ -0,0 +1,110 @@
"""Unit tests for ``supports_image_input`` derivation on BYOK chat config
endpoints (``GET /new-llm-configs`` list, ``GET /new-llm-configs/{id}``).

There is no DB column for ``supports_image_input`` on
``NewLLMConfig`` — the value is resolved at the API boundary by
``derive_supports_image_input`` so the new-chat selector / streaming
task can read the same field shape regardless of source (BYOK vs YAML
vs OpenRouter dynamic). Default-allow on unknown so we don't lock the
user out of their own model choice.
"""

from __future__ import annotations

from datetime import UTC, datetime
from types import SimpleNamespace
from uuid import uuid4

import pytest

from app.db import LiteLLMProvider
from app.routes import new_llm_config_routes

pytestmark = pytest.mark.unit


def _byok_row(
    *,
    id_: int,
    model_name: str,
    base_model: str | None = None,
    provider: LiteLLMProvider = LiteLLMProvider.OPENAI,
    custom_provider: str | None = None,
) -> object:
    """Mimic the SQLAlchemy row's attribute surface; ``model_validate``
    walks ``from_attributes=True`` so a ``SimpleNamespace`` is enough.

    ``provider`` is a real ``LiteLLMProvider`` enum value so Pydantic's
    enum validator accepts it — same as the ORM row would carry."""
    return SimpleNamespace(
        id=id_,
        name=f"BYOK-{id_}",
        description=None,
        provider=provider,
        custom_provider=custom_provider,
        model_name=model_name,
        api_key="sk-byok",
        api_base=None,
        litellm_params={"base_model": base_model} if base_model else None,
        system_instructions="",
        use_default_system_instructions=True,
        citations_enabled=True,
        created_at=datetime.now(tz=UTC),
        search_space_id=42,
        user_id=uuid4(),
    )


def test_serialize_byok_known_vision_model_resolves_true():
    """The catalog resolver consults LiteLLM's map for ``gpt-4o`` ->
    True. The serialized row carries that value through to the
    ``NewLLMConfigRead`` schema."""
    row = _byok_row(id_=1, model_name="gpt-4o")
    serialized = new_llm_config_routes._serialize_byok_config(row)

    assert serialized.supports_image_input is True
    assert serialized.id == 1
    assert serialized.model_name == "gpt-4o"


def test_serialize_byok_unknown_model_default_allows():
    """Unknown / unmapped: default-allow. The streaming-task safety net
    is the actual block, and it requires LiteLLM to *explicitly* say
    text-only — so a brand new BYOK model should not be pre-judged."""
    row = _byok_row(
        id_=2,
        model_name="brand-new-model-x9-unmapped",
        provider=LiteLLMProvider.CUSTOM,
        custom_provider="brand_new_proxy",
    )
    serialized = new_llm_config_routes._serialize_byok_config(row)

    assert serialized.supports_image_input is True


def test_serialize_byok_uses_base_model_when_present():
    """Azure-style: ``model_name`` is the deployment id, ``base_model``
    inside ``litellm_params`` is the canonical sku LiteLLM knows. The
    helper must consult ``base_model`` first or unrecognised deployment
    ids would shadow the real capability."""
    row = _byok_row(
        id_=3,
        model_name="my-azure-deployment-id-no-litellm-knows-this",
        base_model="gpt-4o",
        provider=LiteLLMProvider.AZURE_OPENAI,
    )
    serialized = new_llm_config_routes._serialize_byok_config(row)

    assert serialized.supports_image_input is True


def test_serialize_byok_returns_pydantic_read_model():
    """The route now returns ``NewLLMConfigRead`` (not the raw ORM) so
    the schema additions are guaranteed to be present in the API
    surface. This guards against a future regression where someone
    deletes the augmentation step and falls back to ORM passthrough."""
    from app.schemas import NewLLMConfigRead

    row = _byok_row(id_=4, model_name="gpt-4o")
    serialized = new_llm_config_routes._serialize_byok_config(row)
    assert isinstance(serialized, NewLLMConfigRead)
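Note: the serializer under test is not part of this hunk. A minimal sketch of the boundary the four tests above imply, assuming Pydantic v2 and the resolver kwargs used by the capability tests later in this commit; everything not named in the tests is hypothetical, not the shipped implementation:

from app.schemas import NewLLMConfigRead
from app.services.provider_capabilities import derive_supports_image_input


def _serialize_byok_config(row) -> NewLLMConfigRead:
    # Sketch: validate the ORM row into the read schema, then resolve the
    # capability at the API boundary, since there is no DB column for it.
    config = NewLLMConfigRead.model_validate(row, from_attributes=True)
    params = row.litellm_params or {}
    config.supports_image_input = derive_supports_image_input(
        provider=getattr(row.provider, "value", row.provider),
        model_name=row.model_name,
        base_model=params.get("base_model"),  # Azure: deployment id != sku
        custom_provider=row.custom_provider,
    )
    return config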
@ -0,0 +1,184 @@
"""Unit tests for ``is_premium`` derivation on the global image-gen and
vision-LLM list endpoints.

Chat globals (``GET /global-llm-configs``) already emit
``is_premium = (billing_tier == "premium")``. Image and vision did not,
which made the new-chat ``model-selector`` render the Free/Premium badge
on the Chat tab but skip it on the Image and Vision tabs (the selector
keys its badge logic off ``is_premium``). These tests pin parity:

* YAML free entry → ``is_premium=False``
* YAML premium entry → ``is_premium=True``
* OpenRouter dynamic premium entry → ``is_premium=True``
* Auto stub (always emitted when at least one config is present)
  → ``is_premium=False``
"""

from __future__ import annotations

import pytest

pytestmark = pytest.mark.unit


_IMAGE_FIXTURE: list[dict] = [
    {
        "id": -1,
        "name": "DALL-E 3",
        "provider": "OPENAI",
        "model_name": "dall-e-3",
        "api_key": "sk-test",
        "billing_tier": "free",
    },
    {
        "id": -2,
        "name": "GPT-Image 1 (premium)",
        "provider": "OPENAI",
        "model_name": "gpt-image-1",
        "api_key": "sk-test",
        "billing_tier": "premium",
    },
    {
        "id": -20_001,
        "name": "google/gemini-2.5-flash-image (OpenRouter)",
        "provider": "OPENROUTER",
        "model_name": "google/gemini-2.5-flash-image",
        "api_key": "sk-or-test",
        "api_base": "https://openrouter.ai/api/v1",
        "billing_tier": "premium",
    },
]


_VISION_FIXTURE: list[dict] = [
    {
        "id": -1,
        "name": "GPT-4o Vision",
        "provider": "OPENAI",
        "model_name": "gpt-4o",
        "api_key": "sk-test",
        "billing_tier": "free",
    },
    {
        "id": -2,
        "name": "Claude 3.5 Sonnet (premium)",
        "provider": "ANTHROPIC",
        "model_name": "claude-3-5-sonnet",
        "api_key": "sk-ant-test",
        "billing_tier": "premium",
    },
    {
        "id": -30_001,
        "name": "openai/gpt-4o (OpenRouter)",
        "provider": "OPENROUTER",
        "model_name": "openai/gpt-4o",
        "api_key": "sk-or-test",
        "api_base": "https://openrouter.ai/api/v1",
        "billing_tier": "premium",
    },
]


# =============================================================================
# Image generation
# =============================================================================


@pytest.mark.asyncio
async def test_global_image_gen_configs_emit_is_premium(monkeypatch):
    """Each emitted config must carry ``is_premium`` derived server-side
    from ``billing_tier``. The Auto stub is always free.
    """
    from app.config import config
    from app.routes import image_generation_routes

    monkeypatch.setattr(
        config, "GLOBAL_IMAGE_GEN_CONFIGS", _IMAGE_FIXTURE, raising=False
    )

    payload = await image_generation_routes.get_global_image_gen_configs(user=None)

    by_id = {c["id"]: c for c in payload}

    # Auto stub is always emitted when at least one global config exists,
    # and it must always declare itself free (Auto-mode billing-tier
    # surfacing is a separate follow-up).
    assert 0 in by_id, "Auto stub should be emitted when at least one config exists"
    assert by_id[0]["is_premium"] is False
    assert by_id[0]["billing_tier"] == "free"

    # YAML free entry — ``is_premium=False``
    assert by_id[-1]["is_premium"] is False
    assert by_id[-1]["billing_tier"] == "free"

    # YAML premium entry — ``is_premium=True``
    assert by_id[-2]["is_premium"] is True
    assert by_id[-2]["billing_tier"] == "premium"

    # OpenRouter dynamic premium entry — same field, same derivation
    assert by_id[-20_001]["is_premium"] is True
    assert by_id[-20_001]["billing_tier"] == "premium"

    # Every emitted dict (including Auto) must have the field — never missing.
    for cfg in payload:
        assert "is_premium" in cfg, f"is_premium missing from {cfg.get('id')}"
        assert isinstance(cfg["is_premium"], bool)


@pytest.mark.asyncio
async def test_global_image_gen_configs_no_globals_no_auto_stub(monkeypatch):
    """When there are no global configs at all, the endpoint emits an
    empty list (no Auto stub) — Auto mode would have nothing to route to.
    """
    from app.config import config
    from app.routes import image_generation_routes

    monkeypatch.setattr(config, "GLOBAL_IMAGE_GEN_CONFIGS", [], raising=False)
    payload = await image_generation_routes.get_global_image_gen_configs(user=None)
    assert payload == []


# =============================================================================
# Vision LLM
# =============================================================================


@pytest.mark.asyncio
async def test_global_vision_llm_configs_emit_is_premium(monkeypatch):
    from app.config import config
    from app.routes import vision_llm_routes

    monkeypatch.setattr(
        config, "GLOBAL_VISION_LLM_CONFIGS", _VISION_FIXTURE, raising=False
    )

    payload = await vision_llm_routes.get_global_vision_llm_configs(user=None)

    by_id = {c["id"]: c for c in payload}

    assert 0 in by_id, "Auto stub should be emitted when at least one config exists"
    assert by_id[0]["is_premium"] is False
    assert by_id[0]["billing_tier"] == "free"

    assert by_id[-1]["is_premium"] is False
    assert by_id[-1]["billing_tier"] == "free"

    assert by_id[-2]["is_premium"] is True
    assert by_id[-2]["billing_tier"] == "premium"

    assert by_id[-30_001]["is_premium"] is True
    assert by_id[-30_001]["billing_tier"] == "premium"

    for cfg in payload:
        assert "is_premium" in cfg, f"is_premium missing from {cfg.get('id')}"
        assert isinstance(cfg["is_premium"], bool)


@pytest.mark.asyncio
async def test_global_vision_llm_configs_no_globals_no_auto_stub(monkeypatch):
    from app.config import config
    from app.routes import vision_llm_routes

    monkeypatch.setattr(config, "GLOBAL_VISION_LLM_CONFIGS", [], raising=False)
    payload = await vision_llm_routes.get_global_vision_llm_configs(user=None)
    assert payload == []
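For reference, the parity these tests pin down reduces to a one-line, server-side derivation at each list endpoint. A minimal sketch, assuming the endpoints build plain dicts from the global config lists (the helper name is hypothetical):

def _with_is_premium(configs: list[dict]) -> list[dict]:
    annotated = []
    for cfg in configs:
        cfg = dict(cfg)  # copy: never mutate the module-level global list
        cfg["is_premium"] = cfg.get("billing_tier") == "premium"
        annotated.append(cfg)
    return annotated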
@ -0,0 +1,106 @@
"""Unit tests for ``supports_image_input`` derivation on the chat global
config endpoint (``GET /global-new-llm-configs``).

Resolution order (matches ``new_llm_config_routes.get_global_new_llm_configs``):

1. Explicit ``supports_image_input`` on the cfg dict (set by the YAML
   loader for operator overrides, or by the OpenRouter integration from
   ``architecture.input_modalities``) — wins.
2. ``derive_supports_image_input`` helper — default-allow on unknown
   models, only False when LiteLLM / OR modalities are definitive.

The flag is purely informational at the API boundary. The streaming
task safety net (``is_known_text_only_chat_model``) is the actual block,
and it requires LiteLLM to *explicitly* mark the model as text-only.
"""

from __future__ import annotations

import pytest

pytestmark = pytest.mark.unit


_FIXTURE: list[dict] = [
    {
        "id": -1,
        "name": "GPT-4o (explicit true)",
        "description": "vision-capable, explicit YAML override",
        "provider": "OPENAI",
        "model_name": "gpt-4o",
        "api_key": "sk-test",
        "billing_tier": "free",
        "supports_image_input": True,
    },
    {
        "id": -2,
        "name": "DeepSeek V3 (explicit false)",
        "description": "OpenRouter dynamic — modality-derived false",
        "provider": "OPENROUTER",
        "model_name": "deepseek/deepseek-v3.2-exp",
        "api_key": "sk-or-test",
        "api_base": "https://openrouter.ai/api/v1",
        "billing_tier": "free",
        "supports_image_input": False,
    },
    {
        "id": -10_010,
        "name": "Unannotated GPT-4o",
        "description": "no flag set — resolver should derive True via LiteLLM",
        "provider": "OPENAI",
        "model_name": "gpt-4o",
        "api_key": "sk-test",
        "billing_tier": "free",
        # supports_image_input intentionally absent
    },
    {
        "id": -10_011,
        "name": "Unannotated unknown model",
        "description": "unmapped — default-allow True",
        "provider": "CUSTOM",
        "custom_provider": "brand_new_proxy",
        "model_name": "brand-new-model-x9",
        "api_key": "sk-test",
        "billing_tier": "free",
    },
]


@pytest.mark.asyncio
async def test_global_new_llm_configs_emit_supports_image_input(monkeypatch):
    """Each emitted chat config carries ``supports_image_input`` as a
    bool. Explicit values win; unannotated entries are resolved via the
    helper (default-allow True)."""
    from app.config import config
    from app.routes import new_llm_config_routes

    monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", _FIXTURE, raising=False)

    payload = await new_llm_config_routes.get_global_new_llm_configs(user=None)
    by_id = {c["id"]: c for c in payload}

    # Auto stub: optimistic True so the user can keep Auto selected with
    # vision-capable deployments somewhere in the pool.
    assert 0 in by_id, "Auto stub should be emitted when configs exist"
    assert by_id[0]["supports_image_input"] is True
    assert by_id[0]["is_auto_mode"] is True

    # Explicit True is preserved.
    assert by_id[-1]["supports_image_input"] is True

    # Explicit False is preserved (the exact failure mode the safety net
    # guards against — DeepSeek V3 over OpenRouter would 404 with "No
    # endpoints found that support image input").
    assert by_id[-2]["supports_image_input"] is False

    # Unannotated GPT-4o: resolver consults LiteLLM, which says vision.
    assert by_id[-10_010]["supports_image_input"] is True

    # Unknown / unmapped model: default-allow rather than pre-judge.
    assert by_id[-10_011]["supports_image_input"] is True

    for cfg in payload:
        assert "supports_image_input" in cfg, (
            f"supports_image_input missing from {cfg.get('id')}"
        )
        assert isinstance(cfg["supports_image_input"], bool)
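The two-step resolution order in the docstring above fits in a small helper. A minimal sketch, assuming the cfg dict shape used in the fixture; the helper name is hypothetical and the real logic lives in the route:

from app.services.provider_capabilities import derive_supports_image_input


def _resolve_supports_image_input(cfg: dict) -> bool:
    explicit = cfg.get("supports_image_input")
    if isinstance(explicit, bool):
        return explicit  # step 1: operator override / OR modalities win
    return derive_supports_image_input(  # step 2: default-allow on unknown
        provider=cfg.get("provider"),
        model_name=cfg.get("model_name"),
        custom_provider=cfg.get("custom_provider"),
    )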
@ -0,0 +1,286 @@
"""Image-aware extension of the Auto-pin resolver.

When the current chat turn carries an ``image_url`` block, the pin
resolver must:

1. Filter the candidate pool to vision-capable cfgs so a freshly
   selected pin can never be text-only.
2. Treat any existing pin whose capability is False as invalid (force
   re-pin), even when it would otherwise be reused as the thread's
   stable model.
3. Raise ``ValueError`` (mapped to the friendly
   ``MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT`` SSE error in the streaming
   task) when no vision-capable cfg is available — instead of silently
   pinning text-only and 404-ing at the provider.
"""

from __future__ import annotations

from dataclasses import dataclass
from types import SimpleNamespace

import pytest

from app.services.auto_model_pin_service import (
    clear_healthy,
    clear_runtime_cooldown,
    resolve_or_get_pinned_llm_config_id,
)

pytestmark = pytest.mark.unit


@pytest.fixture(autouse=True)
def _reset_caches():
    clear_runtime_cooldown()
    clear_healthy()
    yield
    clear_runtime_cooldown()
    clear_healthy()


@dataclass
class _FakeQuotaResult:
    allowed: bool


class _FakeExecResult:
    def __init__(self, thread):
        self._thread = thread

    def unique(self):
        return self

    def scalar_one_or_none(self):
        return self._thread


class _FakeSession:
    def __init__(self, thread):
        self.thread = thread
        self.commit_count = 0

    async def execute(self, _stmt):
        return _FakeExecResult(self.thread)

    async def commit(self):
        self.commit_count += 1


def _thread(*, pinned: int | None = None):
    return SimpleNamespace(id=1, search_space_id=10, pinned_llm_config_id=pinned)


def _vision_cfg(id_: int, *, tier: str = "free", quality: int = 80) -> dict:
    return {
        "id": id_,
        "provider": "OPENAI",
        "model_name": f"vision-{id_}",
        "api_key": "k",
        "billing_tier": tier,
        "supports_image_input": True,
        "auto_pin_tier": "A",
        "quality_score": quality,
    }


def _text_only_cfg(id_: int, *, tier: str = "free", quality: int = 90) -> dict:
    return {
        "id": id_,
        "provider": "OPENAI",
        "model_name": f"text-{id_}",
        "api_key": "k",
        "billing_tier": tier,
        # Higher quality than the vision cfgs — so a bug that ignores
        # the image flag would surface as the resolver picking this one.
        "supports_image_input": False,
        "auto_pin_tier": "A",
        "quality_score": quality,
    }


async def _premium_allowed(*_args, **_kwargs):
    return _FakeQuotaResult(allowed=True)


@pytest.mark.asyncio
async def test_image_turn_filters_out_text_only_candidates(monkeypatch):
    from app.config import config

    session = _FakeSession(_thread())
    monkeypatch.setattr(
        config,
        "GLOBAL_LLM_CONFIGS",
        [_text_only_cfg(-1), _vision_cfg(-2)],
    )
    monkeypatch.setattr(
        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
        _premium_allowed,
    )

    result = await resolve_or_get_pinned_llm_config_id(
        session,
        thread_id=1,
        search_space_id=10,
        user_id=None,
        selected_llm_config_id=0,
        requires_image_input=True,
    )

    assert result.resolved_llm_config_id == -2
    # The thread should be pinned to the vision cfg even though the
    # text-only cfg has a higher quality score.
    assert session.thread.pinned_llm_config_id == -2


@pytest.mark.asyncio
async def test_image_turn_force_repins_stale_text_only_pin(monkeypatch):
    """An existing text-only pin must be invalidated when the next turn
    requires image input. The non-image path would happily reuse it."""
    from app.config import config

    session = _FakeSession(_thread(pinned=-1))
    monkeypatch.setattr(
        config,
        "GLOBAL_LLM_CONFIGS",
        [_text_only_cfg(-1), _vision_cfg(-2)],
    )
    monkeypatch.setattr(
        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
        _premium_allowed,
    )

    result = await resolve_or_get_pinned_llm_config_id(
        session,
        thread_id=1,
        search_space_id=10,
        user_id=None,
        selected_llm_config_id=0,
        requires_image_input=True,
    )

    assert result.resolved_llm_config_id == -2
    assert result.from_existing_pin is False
    assert session.thread.pinned_llm_config_id == -2


@pytest.mark.asyncio
async def test_image_turn_reuses_existing_vision_pin(monkeypatch):
    """If the thread is already pinned to a vision-capable cfg, reuse it
    — same as the non-image path. Image-aware filtering must not force
    spurious re-pins."""
    from app.config import config

    session = _FakeSession(_thread(pinned=-2))
    monkeypatch.setattr(
        config,
        "GLOBAL_LLM_CONFIGS",
        [_text_only_cfg(-1), _vision_cfg(-2), _vision_cfg(-3, quality=70)],
    )
    monkeypatch.setattr(
        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
        _premium_allowed,
    )

    result = await resolve_or_get_pinned_llm_config_id(
        session,
        thread_id=1,
        search_space_id=10,
        user_id=None,
        selected_llm_config_id=0,
        requires_image_input=True,
    )

    assert result.resolved_llm_config_id == -2
    assert result.from_existing_pin is True


@pytest.mark.asyncio
async def test_image_turn_with_no_vision_candidates_raises(monkeypatch):
    """The friendly-error path: no vision-capable cfg in the pool -> raise
    ``ValueError`` whose message contains ``vision-capable`` so the
    streaming task can map it to ``MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT``."""
    from app.config import config

    session = _FakeSession(_thread())
    monkeypatch.setattr(
        config,
        "GLOBAL_LLM_CONFIGS",
        [_text_only_cfg(-1), _text_only_cfg(-2)],
    )
    monkeypatch.setattr(
        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
        _premium_allowed,
    )

    with pytest.raises(ValueError, match="vision-capable"):
        await resolve_or_get_pinned_llm_config_id(
            session,
            thread_id=1,
            search_space_id=10,
            user_id=None,
            selected_llm_config_id=0,
            requires_image_input=True,
        )


@pytest.mark.asyncio
async def test_non_image_turn_keeps_text_only_in_pool(monkeypatch):
    """Regression guard: the image flag must default False and not affect
    a normal text-only turn — text-only cfgs remain selectable."""
    from app.config import config

    session = _FakeSession(_thread())
    monkeypatch.setattr(
        config,
        "GLOBAL_LLM_CONFIGS",
        [_text_only_cfg(-1)],
    )
    monkeypatch.setattr(
        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
        _premium_allowed,
    )

    result = await resolve_or_get_pinned_llm_config_id(
        session,
        thread_id=1,
        search_space_id=10,
        user_id=None,
        selected_llm_config_id=0,
    )
    assert result.resolved_llm_config_id == -1


@pytest.mark.asyncio
async def test_image_turn_unannotated_cfg_resolves_via_helper(monkeypatch):
    """A YAML cfg that omits ``supports_image_input`` falls through to
    ``derive_supports_image_input`` (LiteLLM-driven). For ``gpt-4o``
    that returns True, so the cfg should be a valid candidate."""
    from app.config import config

    session = _FakeSession(_thread())
    cfg_unannotated_vision = {
        "id": -2,
        "provider": "OPENAI",
        "model_name": "gpt-4o",  # known vision model in LiteLLM map
        "api_key": "k",
        "billing_tier": "free",
        "auto_pin_tier": "A",
        "quality_score": 80,
        # NOTE: no supports_image_input key
    }
    monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [cfg_unannotated_vision])
    monkeypatch.setattr(
        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
        _premium_allowed,
    )

    result = await resolve_or_get_pinned_llm_config_id(
        session,
        thread_id=1,
        search_space_id=10,
        user_id=None,
        selected_llm_config_id=0,
        requires_image_input=True,
    )
    assert result.resolved_llm_config_id == -2
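The three requirements in the module docstring above suggest a filtering step of roughly this shape inside the resolver. A minimal sketch under those assumptions; the function name is hypothetical, and only the "vision-capable" message substring is actually pinned by the tests:

def _vision_candidates(pool: list[dict], *, requires_image_input: bool) -> list[dict]:
    if not requires_image_input:
        return pool  # text turns keep text-only cfgs selectable
    vision = [c for c in pool if c.get("supports_image_input", True)]
    if not vision:
        # The streaming task maps this to MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT.
        raise ValueError("no vision-capable LLM config available for this image turn")
    return vision

An existing pin would be validated against the same predicate, which is what forces the re-pin in the stale-pin test above.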
@ -15,6 +15,7 @@ vision LLM extraction:

from __future__ import annotations

import asyncio
import contextlib
from typing import Any
from uuid import uuid4
@ -57,6 +58,9 @@ class _FakeSession:
    async def commit(self) -> None:
        self.committed = True

    async def rollback(self) -> None:
        pass

    async def close(self) -> None:
        pass
@ -71,7 +75,9 @@ async def _fake_shielded_session():
_SESSIONS_USED: list[_FakeSession] = []


-def _patch_isolation_layer(monkeypatch, *, reserve_result, finalize_result=None):
+def _patch_isolation_layer(
+    monkeypatch, *, reserve_result, finalize_result=None, finalize_exc=None
+):
    """Wire fake reserve/finalize/release/session helpers."""
    _SESSIONS_USED.clear()
    reserve_calls: list[dict[str, Any]] = []
@ -91,6 +97,8 @@ def _patch_isolation_layer(monkeypatch, *, reserve_result, finalize_result=None)
    async def _fake_finalize(
        *, db_session, user_id, request_id, actual_micros, reserved_micros
    ):
+        if finalize_exc is not None:
+            raise finalize_exc
        finalize_calls.append(
            {
                "user_id": user_id,
@ -343,6 +351,125 @@ async def test_premium_uses_estimator_when_no_micros_override(monkeypatch):
    assert spies["reserve"][0]["reserve_micros"] == 12_345


@pytest.mark.asyncio
async def test_premium_finalize_failure_propagates_and_releases(monkeypatch):
    from app.services.billable_calls import BillingSettlementError, billable_call

    class _FinalizeError(RuntimeError):
        pass

    spies = _patch_isolation_layer(
        monkeypatch,
        reserve_result=_FakeQuotaResult(allowed=True),
        finalize_exc=_FinalizeError("db finalize failed"),
    )
    user_id = uuid4()

    with pytest.raises(BillingSettlementError):
        async with billable_call(
            user_id=user_id,
            search_space_id=42,
            billing_tier="premium",
            base_model="openai/gpt-image-1",
            quota_reserve_micros_override=50_000,
            usage_type="image_generation",
        ) as acc:
            acc.add(
                model="openai/gpt-image-1",
                prompt_tokens=0,
                completion_tokens=0,
                total_tokens=0,
                cost_micros=40_000,
                call_kind="image_generation",
            )

    assert len(spies["reserve"]) == 1
    assert len(spies["release"]) == 1
    assert spies["record"] == []


@pytest.mark.asyncio
async def test_premium_audit_commit_hang_times_out_after_finalize(monkeypatch):
    from app.services.billable_calls import billable_call

    spies = _patch_isolation_layer(
        monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
    )
    user_id = uuid4()

    class _HangingCommitSession(_FakeSession):
        async def commit(self) -> None:
            await asyncio.sleep(60)

    @contextlib.asynccontextmanager
    async def _hanging_session_factory():
        s = _HangingCommitSession()
        _SESSIONS_USED.append(s)
        yield s

    async with billable_call(
        user_id=user_id,
        search_space_id=42,
        billing_tier="premium",
        base_model="openai/gpt-image-1",
        quota_reserve_micros_override=50_000,
        usage_type="image_generation",
        billable_session_factory=_hanging_session_factory,
        audit_timeout_seconds=0.01,
    ) as acc:
        acc.add(
            model="openai/gpt-image-1",
            prompt_tokens=0,
            completion_tokens=0,
            total_tokens=0,
            cost_micros=40_000,
            call_kind="image_generation",
        )

    assert len(spies["reserve"]) == 1
    assert len(spies["finalize"]) == 1
    assert len(spies["record"]) == 1
    assert spies["release"] == []


@pytest.mark.asyncio
async def test_free_audit_failure_is_best_effort(monkeypatch):
    from app.services.billable_calls import billable_call

    spies = _patch_isolation_layer(
        monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
    )

    async def _failing_record(_session, **_kwargs):
        raise RuntimeError("audit insert failed")

    monkeypatch.setattr(
        "app.services.billable_calls.record_token_usage",
        _failing_record,
        raising=False,
    )

    async with billable_call(
        user_id=uuid4(),
        search_space_id=42,
        billing_tier="free",
        base_model="openai/gpt-image-1",
        usage_type="image_generation",
        audit_timeout_seconds=0.01,
    ) as acc:
        acc.add(
            model="openai/gpt-image-1",
            prompt_tokens=0,
            completion_tokens=0,
            total_tokens=0,
            cost_micros=37_000,
            call_kind="image_generation",
        )

    assert spies["reserve"] == []
    assert spies["finalize"] == []


# ---------------------------------------------------------------------------
# Podcast / video-presentation usage_type coverage
# ---------------------------------------------------------------------------
@ -387,7 +514,7 @@ async def test_free_podcast_path_audits_with_podcast_usage_type(monkeypatch):
    assert len(spies["record"]) == 1
    row = spies["record"][0]
    assert row["usage_type"] == "podcast_generation"
-    assert row["thread_id"] == 99
+    assert row["thread_id"] is None
    assert row["search_space_id"] == 42
    assert row["call_details"] == {"podcast_id": 7, "title": "Test Podcast"}
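Taken together, the tests above pin an ordering contract: quota finalize is billing-critical and must settle (or release the reservation) first, while the audit write is best-effort and time-boxed. A minimal sketch of that contract for the premium path, with every name except ``asyncio`` and the imported ``BillingSettlementError`` assumed:

import asyncio

from app.services.billable_calls import BillingSettlementError


async def _settle_premium(finalize, release, write_audit, *, audit_timeout_seconds):
    try:
        await finalize()
    except Exception as exc:
        await release()  # the reservation must not leak on finalize failure
        raise BillingSettlementError("finalize failed") from exc
    try:
        # A hung audit commit must not wedge settlement indefinitely.
        await asyncio.wait_for(write_audit(), timeout=audit_timeout_seconds)
    except Exception:
        pass  # best-effort: the quota is already settled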
@ -0,0 +1,177 @@
"""Defense-in-depth: image-gen call sites must not let an empty
``api_base`` fall through to LiteLLM's module-global ``litellm.api_base``.

The bug repro: an OpenRouter image-gen config ships
``api_base=""``. The pre-fix call site in
``image_generation_routes._execute_image_generation`` did
``if cfg.get("api_base"): kwargs["api_base"] = cfg["api_base"]`` which
silently dropped the empty string. LiteLLM then fell back to
``litellm.api_base`` (commonly inherited from ``AZURE_OPENAI_ENDPOINT``)
and OpenRouter's ``image_generation/transformation`` appended
``/chat/completions`` to it → 404 ``Resource not found``.

This test pins the post-fix behaviour: with an empty ``api_base`` in
the config, the call site MUST set ``api_base`` to OpenRouter's public
URL instead of leaving it unset.
"""

from __future__ import annotations

from unittest.mock import AsyncMock, MagicMock, patch

import pytest

pytestmark = pytest.mark.unit


@pytest.mark.asyncio
async def test_global_openrouter_image_gen_sets_api_base_when_config_empty():
    """The global-config branch (``config_id < 0``) of
    ``_execute_image_generation`` must apply the resolver and pin
    ``api_base`` to OpenRouter when the config ships an empty string.
    """
    from app.routes import image_generation_routes

    cfg = {
        "id": -20_001,
        "name": "GPT Image 1 (OpenRouter)",
        "provider": "OPENROUTER",
        "model_name": "openai/gpt-image-1",
        "api_key": "sk-or-test",
        "api_base": "",  # the original bug shape
        "api_version": None,
        "litellm_params": {},
    }

    captured: dict = {}

    async def fake_aimage_generation(**kwargs):
        captured.update(kwargs)
        return MagicMock(model_dump=lambda: {"data": []}, _hidden_params={})

    image_gen = MagicMock()
    image_gen.image_generation_config_id = cfg["id"]
    image_gen.prompt = "test"
    image_gen.n = 1
    image_gen.quality = None
    image_gen.size = None
    image_gen.style = None
    image_gen.response_format = None
    image_gen.model = None

    search_space = MagicMock()
    search_space.image_generation_config_id = cfg["id"]
    session = MagicMock()

    with (
        patch.object(
            image_generation_routes,
            "_get_global_image_gen_config",
            return_value=cfg,
        ),
        patch.object(
            image_generation_routes,
            "aimage_generation",
            side_effect=fake_aimage_generation,
        ),
    ):
        await image_generation_routes._execute_image_generation(
            session=session, image_gen=image_gen, search_space=search_space
        )

    # The whole point of the fix: even with empty ``api_base`` in the
    # config, we forward OpenRouter's public URL so the call doesn't
    # inherit an Azure endpoint.
    assert captured.get("api_base") == "https://openrouter.ai/api/v1"
    assert captured["model"] == "openrouter/openai/gpt-image-1"


@pytest.mark.asyncio
async def test_generate_image_tool_global_sets_api_base_when_config_empty():
    """Same defense at the agent tool entry point — both surfaces share
    the same OpenRouter config payloads."""
    from app.agents.new_chat.tools import generate_image as gi_module

    cfg = {
        "id": -20_001,
        "name": "GPT Image 1 (OpenRouter)",
        "provider": "OPENROUTER",
        "model_name": "openai/gpt-image-1",
        "api_key": "sk-or-test",
        "api_base": "",
        "api_version": None,
        "litellm_params": {},
    }

    captured: dict = {}

    async def fake_aimage_generation(**kwargs):
        captured.update(kwargs)
        response = MagicMock()
        response.model_dump.return_value = {
            "data": [{"url": "https://example.com/x.png"}]
        }
        response._hidden_params = {"model": "openrouter/openai/gpt-image-1"}
        return response

    search_space = MagicMock()
    search_space.id = 1
    search_space.image_generation_config_id = cfg["id"]

    session_cm = AsyncMock()
    session = AsyncMock()
    session_cm.__aenter__.return_value = session

    scalars = MagicMock()
    scalars.first.return_value = search_space
    exec_result = MagicMock()
    exec_result.scalars.return_value = scalars
    session.execute.return_value = exec_result
    session.add = MagicMock()
    session.commit = AsyncMock()
    session.refresh = AsyncMock()

    # ``refresh(db_image_gen)`` needs to populate ``id`` for token URL fallback.
    async def _refresh(obj):
        obj.id = 1

    session.refresh.side_effect = _refresh

    with (
        patch.object(gi_module, "shielded_async_session", return_value=session_cm),
        patch.object(gi_module, "_get_global_image_gen_config", return_value=cfg),
        patch.object(
            gi_module, "aimage_generation", side_effect=fake_aimage_generation
        ),
        patch.object(
            gi_module, "is_image_gen_auto_mode", side_effect=lambda cid: cid == 0
        ),
    ):
        tool = gi_module.create_generate_image_tool(
            search_space_id=1, db_session=MagicMock()
        )
        await tool.ainvoke({"prompt": "a cat", "n": 1})

    assert captured.get("api_base") == "https://openrouter.ai/api/v1"
    assert captured["model"] == "openrouter/openai/gpt-image-1"


def test_image_gen_router_deployment_sets_api_base_when_config_empty():
    """The Auto-mode router pool must also resolve ``api_base`` when an
    OpenRouter config ships an empty string. The deployment dict is fed
    straight to ``litellm.Router``, so a missing ``api_base`` would
    leak the same way as the direct call sites.
    """
    from app.services.image_gen_router_service import ImageGenRouterService

    deployment = ImageGenRouterService._config_to_deployment(
        {
            "model_name": "openai/gpt-image-1",
            "provider": "OPENROUTER",
            "api_key": "sk-or-test",
            "api_base": "",
        }
    )
    assert deployment is not None
    assert deployment["litellm_params"]["api_base"] == "https://openrouter.ai/api/v1"
    assert deployment["litellm_params"]["model"] == "openrouter/openai/gpt-image-1"
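The pre-fix/post-fix contrast in the module docstring comes down to how the call site treats a falsy ``api_base``. A minimal sketch of the fixed pattern, assuming the shared ``resolve_api_base`` resolver tested later in this commit (kwarg names taken from those tests; the wrapper itself is hypothetical):

from app.services.provider_api_base import resolve_api_base


def _apply_api_base(kwargs: dict, cfg: dict) -> None:
    # Pre-fix code did `if cfg.get("api_base"): ...`, which silently
    # dropped "" and let LiteLLM fall back to litellm.api_base.
    api_base = resolve_api_base(
        provider=cfg.get("provider"),
        provider_prefix="openrouter",  # assumption: OpenRouter branch shown
        config_api_base=cfg.get("api_base"),
    )
    if api_base:
        kwargs["api_base"] = api_base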
@ -265,6 +265,10 @@ def test_generate_image_gen_configs_filters_by_image_output():
        assert c["billing_tier"] in {"free", "premium"}
        assert c["provider"] == "OPENROUTER"
        assert c[_OPENROUTER_DYNAMIC_MARKER] is True
        # Defense-in-depth: emit the OpenRouter base URL at source so a
        # downstream call site that forgets ``resolve_api_base`` still
        # doesn't 404 against an inherited Azure endpoint.
        assert c["api_base"] == "https://openrouter.ai/api/v1"


def test_generate_image_gen_configs_assigns_image_id_offset():
@ -342,6 +346,10 @@ def test_generate_vision_llm_configs_filters_by_image_input_text_output():
    assert cfg["input_cost_per_token"] == pytest.approx(5e-6)
    assert cfg["output_cost_per_token"] == pytest.approx(15e-6)
    assert cfg[_OPENROUTER_DYNAMIC_MARKER] is True
    # Defense-in-depth: emit the OpenRouter base URL at source so a
    # downstream call site that forgets ``resolve_api_base`` still
    # doesn't inherit an Azure endpoint.
    assert cfg["api_base"] == "https://openrouter.ai/api/v1"


def test_generate_vision_llm_configs_drops_chat_only_filters():
surfsense_backend/tests/unit/services/test_provider_api_base.py (new file, 107 lines)
@ -0,0 +1,107 @@
"""Unit tests for the shared ``api_base`` resolver.

The cascade exists so vision and image-gen call sites can't silently
inherit ``litellm.api_base`` (commonly set by ``AZURE_OPENAI_ENDPOINT``)
when an OpenRouter / Groq / etc. config ships an empty string. See
``provider_api_base`` module docstring for the original repro
(OpenRouter image-gen 404-ing against an Azure endpoint).
"""

from __future__ import annotations

import pytest

from app.services.provider_api_base import (
    PROVIDER_DEFAULT_API_BASE,
    PROVIDER_KEY_DEFAULT_API_BASE,
    resolve_api_base,
)

pytestmark = pytest.mark.unit


def test_config_value_wins_over_defaults():
    """A non-empty config value is always returned verbatim, even when the
    provider has a default — the operator gets the last word."""
    result = resolve_api_base(
        provider="OPENROUTER",
        provider_prefix="openrouter",
        config_api_base="https://my-openrouter-mirror.example.com/v1",
    )
    assert result == "https://my-openrouter-mirror.example.com/v1"


def test_provider_key_default_when_config_missing():
    """``DEEPSEEK`` shares the ``openai`` LiteLLM prefix but has its own
    base URL — the provider-key map must take precedence over the prefix
    map so DeepSeek requests don't go to OpenAI."""
    result = resolve_api_base(
        provider="DEEPSEEK",
        provider_prefix="openai",
        config_api_base=None,
    )
    assert result == PROVIDER_KEY_DEFAULT_API_BASE["DEEPSEEK"]


def test_provider_prefix_default_when_no_key_default():
    result = resolve_api_base(
        provider="OPENROUTER",
        provider_prefix="openrouter",
        config_api_base=None,
    )
    assert result == PROVIDER_DEFAULT_API_BASE["openrouter"]


def test_unknown_provider_returns_none():
    """When neither map matches we return ``None`` so the caller can let
    LiteLLM apply its own provider-integration default (Azure deployment
    URL, custom-provider URL, etc.)."""
    result = resolve_api_base(
        provider="SOMETHING_NEW",
        provider_prefix="something_new",
        config_api_base=None,
    )
    assert result is None


def test_empty_string_config_treated_as_missing():
    """The original bug: OpenRouter dynamic configs ship ``api_base=""``
    and downstream call sites use ``if cfg.get("api_base"):`` — empty
    strings are falsy in Python but the cascade has to step in anyway."""
    result = resolve_api_base(
        provider="OPENROUTER",
        provider_prefix="openrouter",
        config_api_base="",
    )
    assert result == PROVIDER_DEFAULT_API_BASE["openrouter"]


def test_whitespace_only_config_treated_as_missing():
    """A config value of ``" "`` is a configuration mistake — treat it
    as missing instead of forwarding whitespace to LiteLLM (which would
    almost certainly 404)."""
    result = resolve_api_base(
        provider="OPENROUTER",
        provider_prefix="openrouter",
        config_api_base=" ",
    )
    assert result == PROVIDER_DEFAULT_API_BASE["openrouter"]


def test_provider_case_insensitive():
    """Some call sites pass the provider lowercase (DB enum value), others
    uppercase (YAML key). Both must resolve."""
    upper = resolve_api_base(
        provider="DEEPSEEK", provider_prefix="openai", config_api_base=None
    )
    lower = resolve_api_base(
        provider="deepseek", provider_prefix="openai", config_api_base=None
    )
    assert upper == lower == PROVIDER_KEY_DEFAULT_API_BASE["DEEPSEEK"]


def test_all_inputs_none_returns_none():
    assert (
        resolve_api_base(provider=None, provider_prefix=None, config_api_base=None)
        is None
    )
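The cascade these eight tests pin down reads as three ordered lookups. A minimal sketch consistent with the assertions above; the shipped module may differ in details:

def resolve_api_base(*, provider, provider_prefix, config_api_base):
    # 1. A non-blank config value is returned verbatim: the operator's
    #    last word (whitespace-only counts as missing).
    if config_api_base and config_api_base.strip():
        return config_api_base
    # 2. Provider-key map first, so DEEPSEEK beats its shared "openai" prefix.
    if provider and provider.upper() in PROVIDER_KEY_DEFAULT_API_BASE:
        return PROVIDER_KEY_DEFAULT_API_BASE[provider.upper()]
    # 3. LiteLLM prefix map, else None so LiteLLM applies its own default.
    if provider_prefix:
        return PROVIDER_DEFAULT_API_BASE.get(provider_prefix)
    return None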
@ -0,0 +1,244 @@
"""Unit tests for the shared chat-image capability resolver.

Two resolvers, two intents:

- ``derive_supports_image_input`` — best-effort True for the catalog and
  selector. Default-allow on unknown / unmapped models. The streaming
  task safety net never sees this value directly.

- ``is_known_text_only_chat_model`` — strict opt-out for the safety net.
  Returns True only when LiteLLM's model map *explicitly* sets
  ``supports_vision=False``. Anything else (missing key, exception,
  True) returns False so the request flows through to the provider.
"""

from __future__ import annotations

import pytest

from app.services.provider_capabilities import (
    derive_supports_image_input,
    is_known_text_only_chat_model,
)

pytestmark = pytest.mark.unit


# ---------------------------------------------------------------------------
# derive_supports_image_input — OpenRouter modalities path (authoritative)
# ---------------------------------------------------------------------------


def test_or_modalities_with_image_returns_true():
    assert (
        derive_supports_image_input(
            provider="OPENROUTER",
            model_name="openai/gpt-4o",
            openrouter_input_modalities=["text", "image"],
        )
        is True
    )


def test_or_modalities_text_only_returns_false():
    assert (
        derive_supports_image_input(
            provider="OPENROUTER",
            model_name="deepseek/deepseek-v3.2-exp",
            openrouter_input_modalities=["text"],
        )
        is False
    )


def test_or_modalities_empty_list_returns_false():
    """OR explicitly publishing an empty modality list is a definitive
    'no inputs at all' signal — treat as False rather than falling back
    to LiteLLM."""
    assert (
        derive_supports_image_input(
            provider="OPENROUTER",
            model_name="weird/empty-modalities",
            openrouter_input_modalities=[],
        )
        is False
    )


def test_or_modalities_none_falls_through_to_litellm():
    """``None`` (missing key) is *not* a definitive signal — fall through
    to LiteLLM. Using ``openai/gpt-4o`` which is in LiteLLM's map."""
    assert (
        derive_supports_image_input(
            provider="OPENAI",
            model_name="gpt-4o",
            openrouter_input_modalities=None,
        )
        is True
    )


# ---------------------------------------------------------------------------
# derive_supports_image_input — LiteLLM model-map path
# ---------------------------------------------------------------------------


def test_litellm_known_vision_model_returns_true():
    assert (
        derive_supports_image_input(
            provider="OPENAI",
            model_name="gpt-4o",
        )
        is True
    )


def test_litellm_base_model_wins_over_model_name():
    """Azure-style entries pass model_name=deployment_id and put the
    canonical sku in litellm_params.base_model. The resolver must
    consult base_model first or the deployment id (which LiteLLM
    doesn't know) would shadow the real capability."""
    assert (
        derive_supports_image_input(
            provider="AZURE_OPENAI",
            model_name="my-azure-deployment-id",
            base_model="gpt-4o",
        )
        is True
    )


def test_litellm_unknown_model_default_allows():
    """Default-allow on unknown — the safety net is the actual block."""
    assert (
        derive_supports_image_input(
            provider="CUSTOM",
            model_name="brand-new-model-x9-unmapped",
            custom_provider="brand_new_proxy",
        )
        is True
    )


def test_litellm_known_text_only_returns_false():
    """A model that LiteLLM explicitly knows is text-only resolves to
    False even via the catalog resolver. ``deepseek-chat`` (the
    DeepSeek-V3 chat sku) is in the map without supports_vision and
    LiteLLM's `supports_vision` returns False."""
    # Sanity: confirm the helper's negative path. We use a small model
    # known not to support vision per the map.
    result = derive_supports_image_input(
        provider="DEEPSEEK",
        model_name="deepseek-chat",
    )
    # We accept either False (LiteLLM said explicit no) or True
    # (default-allow if the entry isn't mapped on this version) — the
    # invariant is that the resolver never *raises* on a known-text-only
    # provider/model. The behaviour-binding assertion lives in
    # ``test_is_known_text_only_chat_model_explicit_false`` below.
    assert isinstance(result, bool)


# ---------------------------------------------------------------------------
# is_known_text_only_chat_model — strict opt-out semantics
# ---------------------------------------------------------------------------


def test_is_known_text_only_returns_false_for_vision_model():
    assert (
        is_known_text_only_chat_model(
            provider="OPENAI",
            model_name="gpt-4o",
        )
        is False
    )


def test_is_known_text_only_returns_false_for_unknown_model():
    """Strict opt-out: missing from the map ≠ text-only. The safety net
    must NOT fire for an unmapped model — that's the regression we're
    fixing."""
    assert (
        is_known_text_only_chat_model(
            provider="CUSTOM",
            model_name="brand-new-model-x9-unmapped",
            custom_provider="brand_new_proxy",
        )
        is False
    )


def test_is_known_text_only_returns_false_when_lookup_raises(monkeypatch):
    """LiteLLM's ``get_model_info`` raises freely on parse errors. The
    helper swallows the exception and returns False so the safety net
    doesn't fire on a transient lookup failure."""
    import app.services.provider_capabilities as pc

    def _raise(**_kwargs):
        raise ValueError("intentional test failure")

    monkeypatch.setattr(pc.litellm, "get_model_info", _raise)

    assert (
        is_known_text_only_chat_model(
            provider="OPENAI",
            model_name="gpt-4o",
        )
        is False
    )


def test_is_known_text_only_returns_true_on_explicit_false(monkeypatch):
    """Stub LiteLLM's ``get_model_info`` to return an explicit False so
    we exercise the opt-out path deterministically. Using a stub keeps
    the test stable across LiteLLM map updates."""
    import app.services.provider_capabilities as pc

    def _info(**_kwargs):
        return {"supports_vision": False, "max_input_tokens": 8192}

    monkeypatch.setattr(pc.litellm, "get_model_info", _info)

    assert (
        is_known_text_only_chat_model(
            provider="OPENAI",
            model_name="any-model",
        )
        is True
    )


def test_is_known_text_only_returns_false_on_supports_vision_true(monkeypatch):
    import app.services.provider_capabilities as pc

    def _info(**_kwargs):
        return {"supports_vision": True}

    monkeypatch.setattr(pc.litellm, "get_model_info", _info)

    assert (
        is_known_text_only_chat_model(
            provider="OPENAI",
            model_name="any-model",
        )
        is False
    )


def test_is_known_text_only_returns_false_on_missing_key(monkeypatch):
    """A model entry without ``supports_vision`` at all is treated as
    'unknown' — strict opt-out means False."""
    import app.services.provider_capabilities as pc

    def _info(**_kwargs):
        return {"max_input_tokens": 8192}  # no supports_vision

    monkeypatch.setattr(pc.litellm, "get_model_info", _info)

    assert (
        is_known_text_only_chat_model(
            provider="OPENAI",
            model_name="any-model",
        )
        is False
    )
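The strict opt-out semantics the safety-net tests encode fit in a few lines. A minimal sketch consistent with the monkeypatched ``get_model_info`` calls above; the real helper may resolve provider prefixes more carefully:

import litellm


def is_known_text_only_chat_model(
    *, provider, model_name, base_model=None, custom_provider=None
) -> bool:
    try:
        info = litellm.get_model_info(model=base_model or model_name)
    except Exception:
        return False  # a lookup failure must never trip the safety net
    # Only an explicit False blocks; a missing key or True lets the
    # request flow through to the provider.
    return info.get("supports_vision") is False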
@ -0,0 +1,281 @@
|
|||
"""Unit tests for the chat-catalog ``supports_image_input`` capability flag.
|
||||
|
||||
Capability is sourced from two places, in order of preference:
|
||||
|
||||
1. ``architecture.input_modalities`` for dynamic OpenRouter chat configs
|
||||
(authoritative — OpenRouter publishes per-model modalities directly).
|
||||
2. LiteLLM's authoritative model map (``litellm.supports_vision``) for
|
||||
YAML / BYOK configs that don't carry an explicit operator override.
|
||||
|
||||
The catalog default is *True* (conservative-allow): an unknown / unmapped
|
||||
model is not pre-judged. The streaming-task safety net
|
||||
(``is_known_text_only_chat_model``) is the only place a False actually
|
||||
blocks a request — and it requires LiteLLM to *explicitly* mark the model
|
||||
as text-only.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.openrouter_integration_service import (
|
||||
_OPENROUTER_DYNAMIC_MARKER,
|
||||
_generate_configs,
|
||||
_supports_image_input,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
_SETTINGS_BASE: dict = {
|
||||
"api_key": "sk-or-test",
|
||||
"id_offset": -10_000,
|
||||
"rpm": 200,
|
||||
"tpm": 1_000_000,
|
||||
"free_rpm": 20,
|
||||
"free_tpm": 100_000,
|
||||
"anonymous_enabled_paid": False,
|
||||
"anonymous_enabled_free": True,
|
||||
"quota_reserve_tokens": 4000,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _supports_image_input helper (OpenRouter modalities)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_supports_image_input_true_for_multimodal():
|
||||
assert (
|
||||
_supports_image_input(
|
||||
{
|
||||
"id": "openai/gpt-4o",
|
||||
"architecture": {
|
||||
"input_modalities": ["text", "image"],
|
||||
"output_modalities": ["text"],
|
||||
},
|
||||
}
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
|
||||
def test_supports_image_input_false_for_text_only():
|
||||
"""The exact failure mode the safety net guards against — DeepSeek V3
|
||||
is a text-in/text-out model and would 404 if forwarded image_url."""
|
||||
assert (
|
||||
_supports_image_input(
|
||||
{
|
||||
"id": "deepseek/deepseek-v3.2-exp",
|
||||
"architecture": {
|
||||
"input_modalities": ["text"],
|
||||
"output_modalities": ["text"],
|
||||
},
|
||||
}
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
|
||||
def test_supports_image_input_false_when_modalities_missing():
|
||||
"""Defensive: missing architecture is treated as text-only at the
|
||||
OpenRouter helper level. The wider catalog resolver
|
||||
(`derive_supports_image_input`) only consults modalities when they
|
||||
are non-empty, otherwise it falls back to LiteLLM."""
|
||||
assert _supports_image_input({"id": "weird/model"}) is False
|
||||
assert _supports_image_input({"id": "weird/model", "architecture": {}}) is False
|
||||
assert (
|
||||
_supports_image_input(
|
||||
{"id": "weird/model", "architecture": {"input_modalities": None}}
|
||||
)
|
||||
is False
|
||||
)


# ---------------------------------------------------------------------------
# _generate_configs threads the flag onto every emitted chat config
# ---------------------------------------------------------------------------


def test_generate_configs_emits_supports_image_input():
    raw = [
        {
            "id": "openai/gpt-4o",
            "architecture": {
                "input_modalities": ["text", "image"],
                "output_modalities": ["text"],
            },
            "supported_parameters": ["tools"],
            "context_length": 200_000,
            "pricing": {"prompt": "0.000005", "completion": "0.000015"},
        },
        {
            "id": "deepseek/deepseek-v3.2-exp",
            "architecture": {
                "input_modalities": ["text"],
                "output_modalities": ["text"],
            },
            "supported_parameters": ["tools"],
            "context_length": 200_000,
            "pricing": {"prompt": "0.000003", "completion": "0.000015"},
        },
    ]
    cfgs = _generate_configs(raw, dict(_SETTINGS_BASE))
    by_model = {c["model_name"]: c for c in cfgs}

    gpt = by_model["openai/gpt-4o"]
    assert gpt["supports_image_input"] is True
    assert gpt[_OPENROUTER_DYNAMIC_MARKER] is True

    deepseek = by_model["deepseek/deepseek-v3.2-exp"]
    assert deepseek["supports_image_input"] is False
    assert deepseek[_OPENROUTER_DYNAMIC_MARKER] is True


# ---------------------------------------------------------------------------
# YAML loader: defer to derive_supports_image_input on unannotated entries
# ---------------------------------------------------------------------------


def test_yaml_loader_resolves_unannotated_vision_model_to_true(tmp_path, monkeypatch):
    """The regression case: an Azure GPT-5.x YAML entry without a
    ``supports_image_input`` override should resolve to True via LiteLLM's
    model map (which says ``supports_vision: true``). Previously this
    defaulted to False, blocking every image turn for vision-capable
    YAML configs."""
    yaml_dir = tmp_path / "app" / "config"
    yaml_dir.mkdir(parents=True)
    (yaml_dir / "global_llm_config.yaml").write_text(
        """
global_llm_configs:
  - id: -2
    name: Azure GPT-4o
    provider: AZURE_OPENAI
    model_name: gpt-4o
    api_key: sk-test
""",
        encoding="utf-8",
    )

    from app import config as config_module

    monkeypatch.setattr(config_module, "BASE_DIR", tmp_path)

    configs = config_module.load_global_llm_configs()
    assert len(configs) == 1
    assert configs[0]["supports_image_input"] is True


def test_yaml_loader_respects_explicit_supports_image_input(tmp_path, monkeypatch):
    yaml_dir = tmp_path / "app" / "config"
    yaml_dir.mkdir(parents=True)
    (yaml_dir / "global_llm_config.yaml").write_text(
        """
global_llm_configs:
  - id: -1
    name: GPT-4o
    provider: OPENAI
    model_name: gpt-4o
    api_key: sk-test
    supports_image_input: false
""",
        encoding="utf-8",
    )

    from app import config as config_module

    monkeypatch.setattr(config_module, "BASE_DIR", tmp_path)

    configs = config_module.load_global_llm_configs()
    assert len(configs) == 1
    # Operator override always wins, even against LiteLLM's True.
    assert configs[0]["supports_image_input"] is False


def test_yaml_loader_unknown_model_default_allows(tmp_path, monkeypatch):
    """Unknown / unmapped model in YAML: default-allow. The streaming
    safety net (which requires an explicit-False from LiteLLM) is the
    only place a real block happens, so we don't lock the user out of
    a freshly added third-party entry the catalog can't introspect."""
    yaml_dir = tmp_path / "app" / "config"
    yaml_dir.mkdir(parents=True)
    (yaml_dir / "global_llm_config.yaml").write_text(
        """
global_llm_configs:
  - id: -1
    name: Some Brand New Model
    provider: CUSTOM
    custom_provider: brand_new_proxy
    model_name: brand-new-model-x9
    api_key: sk-test
""",
        encoding="utf-8",
    )

    from app import config as config_module

    monkeypatch.setattr(config_module, "BASE_DIR", tmp_path)

    configs = config_module.load_global_llm_configs()
    assert len(configs) == 1
    assert configs[0]["supports_image_input"] is True


# ---------------------------------------------------------------------------
# AgentConfig threads the flag through both YAML and Auto / BYOK
# ---------------------------------------------------------------------------


def test_agent_config_from_yaml_explicit_overrides_resolver():
    from app.agents.new_chat.llm_config import AgentConfig

    cfg_text_only = AgentConfig.from_yaml_config(
        {
            "id": -1,
            "name": "Text Only Override",
            "provider": "openai",
            "model_name": "gpt-4o",  # Capable per LiteLLM, but operator says no.
            "api_key": "sk-test",
            "supports_image_input": False,
        }
    )
    cfg_explicit_vision = AgentConfig.from_yaml_config(
        {
            "id": -2,
            "name": "GPT-4o",
            "provider": "openai",
            "model_name": "gpt-4o",
            "api_key": "sk-test",
            "supports_image_input": True,
        }
    )
    assert cfg_text_only.supports_image_input is False
    assert cfg_explicit_vision.supports_image_input is True


def test_agent_config_from_yaml_unannotated_uses_resolver():
    """Without an explicit YAML key, AgentConfig defers to the catalog
    resolver — for ``gpt-4o`` LiteLLM's map says supports_vision=True."""
    from app.agents.new_chat.llm_config import AgentConfig

    cfg = AgentConfig.from_yaml_config(
        {
            "id": -1,
            "name": "GPT-4o (no override)",
            "provider": "openai",
            "model_name": "gpt-4o",
            "api_key": "sk-test",
        }
    )
    assert cfg.supports_image_input is True


def test_agent_config_auto_mode_supports_image_input():
    """Auto routes across the pool. We optimistically allow image input
    so users can keep their selection on Auto with a vision-capable
    deployment somewhere in the pool. The router's own `allowed_fails`
    handles non-vision deployments via fallback."""
    from app.agents.new_chat.llm_config import AgentConfig

    auto = AgentConfig.from_auto_mode()
    assert auto.supports_image_input is True

@@ -0,0 +1,89 @@
"""Defense-in-depth: vision-LLM resolution must not leak ``api_base``
defaults from ``litellm.api_base`` either.

Vision shares the same shape as image-gen — global YAML / OpenRouter
dynamic configs ship ``api_base=""``, and the pre-fix ``get_vision_llm``
call sites would silently drop the empty string and inherit
``AZURE_OPENAI_ENDPOINT``. ``ChatLiteLLM(...)`` doesn't 404 on
construction, so we test the kwargs we hand to it instead.
"""

from __future__ import annotations

from unittest.mock import AsyncMock, MagicMock, patch

import pytest

pytestmark = pytest.mark.unit
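

# Sketch of the normalisation these tests demand (helper name is
# hypothetical; only the OpenRouter default base URL is asserted below):
def _resolve_api_base_sketch(provider: str, api_base: str | None) -> str | None:
    if api_base:  # a non-empty explicit base always wins
        return api_base
    if provider.upper() == "OPENROUTER":
        return "https://openrouter.ai/api/v1"
    # Return None rather than "" so LiteLLM applies its own provider
    # default instead of inheriting an unrelated env endpoint.
    return None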


@pytest.mark.asyncio
async def test_get_vision_llm_global_openrouter_sets_api_base():
    """Global negative-ID branch: an OpenRouter vision config with
    ``api_base=""`` must end up calling ``SanitizedChatLiteLLM`` with
    ``api_base="https://openrouter.ai/api/v1"`` — never an empty string,
    never silently absent."""
    from app.services import llm_service

    cfg = {
        "id": -30_001,
        "name": "GPT-4o Vision (OpenRouter)",
        "provider": "OPENROUTER",
        "model_name": "openai/gpt-4o",
        "api_key": "sk-or-test",
        "api_base": "",
        "api_version": None,
        "litellm_params": {},
        "billing_tier": "free",
    }

    search_space = MagicMock()
    search_space.id = 1
    search_space.user_id = "user-x"
    search_space.vision_llm_config_id = cfg["id"]

    session = AsyncMock()
    scalars = MagicMock()
    scalars.first.return_value = search_space
    result = MagicMock()
    result.scalars.return_value = scalars
    session.execute.return_value = result

    captured: dict = {}

    class FakeSanitized:
        def __init__(self, **kwargs):
            captured.update(kwargs)

    with (
        patch(
            "app.services.vision_llm_router_service.get_global_vision_llm_config",
            return_value=cfg,
        ),
        patch(
            "app.agents.new_chat.llm_config.SanitizedChatLiteLLM",
            new=FakeSanitized,
        ),
    ):
        await llm_service.get_vision_llm(session=session, search_space_id=1)

    assert captured.get("api_base") == "https://openrouter.ai/api/v1"
    assert captured["model"] == "openrouter/openai/gpt-4o"


def test_vision_router_deployment_sets_api_base_when_config_empty():
    """Auto-mode vision router: deployments are fed to ``litellm.Router``,
    so the resolver has to apply at deployment construction time too."""
    from app.services.vision_llm_router_service import VisionLLMRouterService

    deployment = VisionLLMRouterService._config_to_deployment(
        {
            "model_name": "openai/gpt-4o",
            "provider": "OPENROUTER",
            "api_key": "sk-or-test",
            "api_base": "",
        }
    )
    assert deployment is not None
    assert deployment["litellm_params"]["api_base"] == "https://openrouter.ai/api/v1"
    assert deployment["litellm_params"]["model"] == "openrouter/openai/gpt-4o"

318 surfsense_backend/tests/unit/tasks/test_celery_async_runner.py Normal file

@@ -0,0 +1,318 @@
"""Regression tests for ``run_async_celery_task``.

These tests pin down the production bug observed on 2026-05-02 where
the video-presentation Celery task hung at ``[billable_call] finalize``
because the shared ``app.db.engine`` had pooled asyncpg connections
bound to a *previous* task's now-closed event loop. Reusing such a
connection on a fresh loop crashes inside ``pool_pre_ping`` with::

    AttributeError: 'NoneType' object has no attribute 'send'

(the proactor is None because the loop is gone) and can hang forever
inside the asyncpg ``Connection._cancel`` cleanup coroutine.

The fix is ``run_async_celery_task``: a small helper that runs every
async celery task body inside a fresh event loop and disposes the
shared engine pool both before (defends against a previous task that
crashed) and after (releases connections we opened on this loop).

Tests here exercise the helper with a stub engine that records
``dispose()`` calls and the event loop each one ran on — mirroring how
real asyncpg connections stay bound to the loop that created them.
"""

from __future__ import annotations

import asyncio
import gc
import sys
from collections.abc import Iterator
from contextlib import contextmanager
from unittest.mock import patch

import pytest

pytestmark = pytest.mark.unit
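

# A sketch of the helper under test, assuming the shape the module docstring
# describes; the real implementation lives in app.tasks.celery_tasks, and the
# quiet-dispose step here is an illustrative stand-in, not its actual code.
def _run_async_celery_task_sketch(body):
    async def _dispose_quietly() -> None:
        try:
            from app.db import engine as shared_engine

            await shared_engine.dispose()
        except Exception:
            pass  # a flaky dispose must never take down the task

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(_dispose_quietly())  # pre-clean a stale pool
        try:
            return loop.run_until_complete(body())
        finally:
            loop.run_until_complete(_dispose_quietly())  # release our own pool
    finally:
        loop.run_until_complete(loop.shutdown_asyncgens())
        asyncio.set_event_loop(None)
        loop.close()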


# ---------------------------------------------------------------------------
# Stub engine that emulates the asyncpg-on-stale-loop crash
# ---------------------------------------------------------------------------


class _StaleLoopEngine:
    """Tiny stand-in for ``app.db.engine`` that tracks dispose() calls.

    ``dispose()`` is async (matches ``AsyncEngine.dispose``) and records
    the running event loop id so tests can assert it ran on *each*
    fresh loop.
    """

    def __init__(self) -> None:
        self.dispose_loop_ids: list[int] = []

    async def dispose(self) -> None:
        loop = asyncio.get_running_loop()
        self.dispose_loop_ids.append(id(loop))


@contextmanager
def _patch_shared_engine(stub: _StaleLoopEngine) -> Iterator[None]:
    """Patch the ``from app.db import engine as shared_engine`` lookup.

    The helper imports lazily inside the function body, so we have to
    patch the attribute on the already-loaded ``app.db`` module.
    """
    import app.db as app_db

    original = getattr(app_db, "engine", None)
    app_db.engine = stub  # type: ignore[attr-defined]
    try:
        yield
    finally:
        if original is None:
            # No engine attribute existed before the patch; remove the
            # stub instead of leaving it behind for later tests.
            del app_db.engine  # type: ignore[attr-defined]
        else:
            app_db.engine = original  # type: ignore[attr-defined]


# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------


def test_runner_returns_value_and_disposes_engine_around_call() -> None:
    """Happy path: the coroutine result is returned, and the shared
    engine is disposed both before and after the task body runs.
    """
    from app.tasks.celery_tasks import run_async_celery_task

    stub = _StaleLoopEngine()

    async def _body() -> str:
        # Engine should already have been disposed once before we run.
        assert len(stub.dispose_loop_ids) == 1
        return "ok"

    with _patch_shared_engine(stub):
        result = run_async_celery_task(_body)

    assert result == "ok"
    # Once before the body, once after (in finally).
    assert len(stub.dispose_loop_ids) == 2
    # Both disposes ran on the SAME (fresh) loop the task body used.
    assert stub.dispose_loop_ids[0] == stub.dispose_loop_ids[1]


def test_runner_creates_fresh_loop_per_invocation() -> None:
    """Each call must spin its own loop. Without this guarantee a
    previous task's loop would be reused and the asyncpg-stale-loop
    crash would never be avoided.
    """
    import app.tasks.celery_tasks as celery_tasks_pkg

    stub = _StaleLoopEngine()
    new_loop_calls = 0
    closed_loops: list[bool] = []

    real_new_event_loop = asyncio.new_event_loop

    def _counting_new_loop() -> asyncio.AbstractEventLoop:
        nonlocal new_loop_calls
        new_loop_calls += 1
        loop = real_new_event_loop()
        # Hook close() so we can verify each loop was closed properly
        # before the next one was created.
        original_close = loop.close

        def _tracked_close() -> None:
            closed_loops.append(True)
            original_close()

        loop.close = _tracked_close  # type: ignore[method-assign]
        return loop

    async def _body() -> None:
        # Loop is alive and current at body execution time.
        running = asyncio.get_running_loop()
        assert not running.is_closed()

    with (
        _patch_shared_engine(stub),
        patch.object(asyncio, "new_event_loop", _counting_new_loop),
    ):
        for _ in range(3):
            celery_tasks_pkg.run_async_celery_task(_body)

    assert new_loop_calls == 3
    assert closed_loops == [True, True, True]
    # Each invocation disposed twice (before + after).
    assert len(stub.dispose_loop_ids) == 6


def test_runner_disposes_engine_even_when_body_raises() -> None:
    """Cleanup MUST run on the failure path too — otherwise stale
    connections leak into the next task and cause the original hang.
    """
    from app.tasks.celery_tasks import run_async_celery_task

    stub = _StaleLoopEngine()

    class _BoomError(RuntimeError):
        pass

    async def _body() -> None:
        raise _BoomError("kaboom")

    with _patch_shared_engine(stub), pytest.raises(_BoomError):
        run_async_celery_task(_body)

    assert len(stub.dispose_loop_ids) == 2  # before + after still ran


def test_runner_swallows_dispose_errors() -> None:
    """A flaky engine.dispose() must NEVER take down a celery task.

    Production scenario: the very first dispose (before the body runs)
    might hit a partially-initialised engine; the helper logs and
    moves on. The task body still runs; the result is still returned.
    """
    from app.tasks.celery_tasks import run_async_celery_task

    class _AngryEngine:
        def __init__(self) -> None:
            self.calls = 0

        async def dispose(self) -> None:
            self.calls += 1
            raise RuntimeError("dispose() blew up")

    stub = _AngryEngine()

    async def _body() -> int:
        return 42

    with _patch_shared_engine(stub):
        assert run_async_celery_task(_body) == 42

    assert stub.calls == 2  # before + after both attempted


def test_runner_propagates_value_from_async_body() -> None:
    """Sanity: pass-through of any pickleable celery return value."""
    from app.tasks.celery_tasks import run_async_celery_task

    stub = _StaleLoopEngine()

    async def _body() -> dict[str, object]:
        return {"status": "ready", "video_presentation_id": 19}

    with _patch_shared_engine(stub):
        out = run_async_celery_task(_body)

    assert out == {"status": "ready", "video_presentation_id": 19}


def test_video_presentation_task_uses_runner_helper() -> None:
    """Defence-in-depth: confirm the celery task module imports
    ``run_async_celery_task``. If a future refactor inlines a
    ``loop = asyncio.new_event_loop(); ... loop.close()`` block again,
    the original hang will return.
    """
    # The module's task body should not contain a manual new_event_loop
    # call — that's exactly what the helper exists to centralise.
    import inspect

    from app.tasks.celery_tasks import video_presentation_tasks

    src = inspect.getsource(video_presentation_tasks)
    assert "run_async_celery_task" in src, (
        "video_presentation_tasks.py must use run_async_celery_task; "
        "manual asyncio.new_event_loop() in a celery task hangs on the "
        "shared SQLAlchemy pool when reused across tasks."
    )
    assert "asyncio.new_event_loop" not in src, (
        "video_presentation_tasks.py contains a raw asyncio.new_event_loop "
        "call — route every async task through run_async_celery_task to "
        "avoid the stale-pool hang."
    )
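

# Intended call pattern in a task module, per the contract asserted above
# (the wrapper below is illustrative; real task registration details differ):
def _example_task_usage() -> dict:
    from app.tasks.celery_tasks import run_async_celery_task

    async def _body() -> dict:
        return {"status": "ready"}

    # Celery task functions stay synchronous; the helper owns the event loop.
    return run_async_celery_task(_body)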


def test_podcast_task_uses_runner_helper() -> None:
    """Symmetric assertion for the podcast task — same root cause, same
    fix, same regression risk.
    """
    import inspect

    from app.tasks.celery_tasks import podcast_tasks

    src = inspect.getsource(podcast_tasks)
    assert "run_async_celery_task" in src
    assert "asyncio.new_event_loop" not in src


def test_runner_runs_shutdown_asyncgens_before_close() -> None:
    """If the task body created any async generators that didn't get
    fully iterated, we must still call ``loop.shutdown_asyncgens()``
    before closing — otherwise we leak event-loop bound resources
    that re-emerge as ``RuntimeError: Event loop is closed`` later.
    """
    from app.tasks.celery_tasks import run_async_celery_task

    stub = _StaleLoopEngine()

    async def _agen():
        try:
            yield 1
            yield 2
        finally:
            pass

    async def _body() -> None:
        # Iterate the agen partially, then leave it dangling — exactly
        # the situation shutdown_asyncgens() is designed to clean up.
        async for v in _agen():
            if v == 1:
                break

    with _patch_shared_engine(stub):
        run_async_celery_task(_body)

    # By the time the helper returns, garbage collection + shutdown_asyncgens
    # should have ensured no live async-gen references remain. We don't
    # assert agen.closed directly (it depends on GC ordering); the real
    # contract is "no warnings, no event-loop-closed errors". A successful
    # second invocation proves the loop was cleaned up properly.
    with _patch_shared_engine(stub):
        run_async_celery_task(_body)

    # Force a GC pass to surface any 'coroutine was never awaited'
    # warnings that would indicate the cleanup is broken.
    gc.collect()


def test_runner_uses_proactor_loop_on_windows() -> None:
    """On Windows the celery worker preselects a Proactor policy so
    subprocess (ffmpeg) calls work. The helper must not silently fall
    back to a Selector loop and re-break video/podcast generation.
    """
    if not sys.platform.startswith("win"):
        pytest.skip("Windows-specific event-loop policy assertion")

    from app.tasks.celery_tasks import run_async_celery_task

    stub = _StaleLoopEngine()

    # Mirror the policy set at the top of every Windows celery task, and
    # restore the previous policy afterwards so later tests are unaffected.
    previous_policy = asyncio.get_event_loop_policy()
    asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())

    observed: list[str] = []

    async def _body() -> None:
        observed.append(type(asyncio.get_running_loop()).__name__)

    try:
        with _patch_shared_engine(stub):
            run_async_celery_task(_body)
    finally:
        asyncio.set_event_loop_policy(previous_policy)

    assert observed == ["ProactorEventLoop"]

@@ -113,6 +113,19 @@ async def _denying_billable_call(**kwargs):
    yield SimpleNamespace()  # pragma: no cover — for grammar only


@contextlib.asynccontextmanager
async def _settlement_failing_billable_call(**kwargs):
    from app.services.billable_calls import BillingSettlementError

    _CALL_LOG.append(kwargs)
    yield SimpleNamespace()
    raise BillingSettlementError(
        usage_type=kwargs.get("usage_type", "?"),
        user_id=kwargs["user_id"],
        cause=RuntimeError("finalize failed"),
    )


# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------

@@ -187,8 +200,15 @@ async def test_billable_call_invoked_with_correct_kwargs_for_free_config(monkeyp
        call["quota_reserve_micros_override"]
        == app_config.QUOTA_DEFAULT_PODCAST_RESERVE_MICROS
    )
-    assert call["thread_id"] == 99
-    assert call["call_details"] == {"podcast_id": 7, "title": "Test Podcast"}
+    # Background artifact audit rows intentionally omit the TokenUsage.thread_id
+    # FK to avoid coupling Celery audit commits to an active chat transaction.
+    assert "thread_id" not in call
+    assert call["call_details"] == {
+        "podcast_id": 7,
+        "title": "Test Podcast",
+        "thread_id": 99,
+    }
    assert callable(call["billable_session_factory"])


@pytest.mark.asyncio

@@ -279,6 +299,49 @@ async def test_quota_insufficient_marks_podcast_failed_and_skips_graph(monkeypat
    assert graph_invoked == []  # Graph never ran on denied reservation.


@pytest.mark.asyncio
async def test_billing_settlement_failure_marks_podcast_failed(monkeypatch):
    from app.db import PodcastStatus
    from app.tasks.celery_tasks import podcast_tasks

    podcast = _make_podcast(podcast_id=10)
    session = _FakeSession(podcast)
    monkeypatch.setattr(
        podcast_tasks,
        "get_celery_session_maker",
        lambda: _FakeSessionMaker(session),
    )

    async def _fake_resolver(sess, search_space_id, *, thread_id=None):
        return uuid4(), "premium", "gpt-5.4"

    monkeypatch.setattr(
        podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver
    )
    monkeypatch.setattr(
        podcast_tasks, "billable_call", _settlement_failing_billable_call
    )

    async def _fake_graph_invoke(state, config):
        return {"podcast_transcript": [], "final_podcast_file_path": "x.wav"}

    monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke)

    result = await podcast_tasks._generate_content_podcast(
        podcast_id=10,
        source_content="hi",
        search_space_id=555,
        user_prompt=None,
    )

    assert result == {
        "status": "failed",
        "podcast_id": 10,
        "reason": "billing_settlement_failed",
    }
    assert podcast.status == PodcastStatus.FAILED
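

# Task-side contract the test above pins down, sketched in comments (names
# match the assertions; the exact control flow inside the task is an
# assumption):
#
#     try:
#         async with billable_call(...):
#             result = await podcaster_graph.ainvoke(state, config)
#     except BillingSettlementError:
#         podcast.status = PodcastStatus.FAILED
#         return {
#             "status": "failed",
#             "podcast_id": podcast_id,
#             "reason": "billing_settlement_failed",
#         }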


@pytest.mark.asyncio
async def test_resolver_failure_marks_podcast_failed(monkeypatch):
    """If the resolver raises (e.g. search-space deleted), the task fails

@@ -0,0 +1,119 @@
"""Predicate-level test for the chat streaming safety net.

The safety net in ``stream_new_chat`` rejects an image turn early with
a friendly ``MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT`` SSE error when the
selected model is *known* to be text-only. The earlier round of this
work used a strict opt-in flag (``supports_image_input`` defaulting to
False on every YAML entry) which blocked vision-capable Azure GPT-5.x
deployments — this is the regression we're fixing.

The new predicate is :func:`is_known_text_only_chat_model`, which
returns True only when LiteLLM's authoritative model map *explicitly*
sets ``supports_vision=False``. Anything else (vision True, missing
key, exception) returns False so the request flows through to the
provider.

We exercise the predicate directly here rather than driving the full
``stream_new_chat`` generator — covering the gate in isolation keeps
the test focused on the regression while the generator's wider behavior
is exercised by the integration suite.
"""

from __future__ import annotations

import pytest

from app.services.provider_capabilities import is_known_text_only_chat_model

pytestmark = pytest.mark.unit
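

# Sketch of the predicate's decision rule (litellm.get_model_info is the
# real LiteLLM API; the wrapper logic shown is an assumption mirroring the
# tests below):
def _is_known_text_only_sketch(model: str) -> bool:
    import litellm

    try:
        info = litellm.get_model_info(model=model)
    except Exception:
        return False  # unknown model or lookup failure: never block
    return info.get("supports_vision") is False  # only an explicit False blocks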


def test_safety_net_does_not_fire_for_azure_gpt_4o():
    """Regression: ``azure/gpt-4o`` (and the GPT-5.x variants) is
    vision-capable per LiteLLM's model map. The previous round's
    blanket-False default blocked it; the new predicate must NOT mark
    it text-only."""
    assert (
        is_known_text_only_chat_model(
            provider="AZURE_OPENAI",
            model_name="my-azure-deployment",
            base_model="gpt-4o",
        )
        is False
    )


def test_safety_net_does_not_fire_for_unknown_model():
    """Default-pass on unknown — the safety net only blocks definitive
    text-only confirmations. A freshly added third-party model that
    LiteLLM doesn't know about must flow through to the provider."""
    assert (
        is_known_text_only_chat_model(
            provider="CUSTOM",
            custom_provider="brand_new_proxy",
            model_name="brand-new-model-x9",
        )
        is False
    )


def test_safety_net_does_not_fire_when_lookup_raises(monkeypatch):
    """Transient ``litellm.get_model_info`` exception ≠ block. The
    helper swallows the error and treats it as 'unknown' → False."""
    import app.services.provider_capabilities as pc

    def _raise(**_kwargs):
        raise RuntimeError("intentional test failure")

    monkeypatch.setattr(pc.litellm, "get_model_info", _raise)

    assert (
        is_known_text_only_chat_model(
            provider="OPENAI",
            model_name="gpt-4o",
        )
        is False
    )


def test_safety_net_fires_only_on_explicit_false(monkeypatch):
    """Stub LiteLLM to assert the only path that returns True is the
    explicit ``supports_vision=False`` case. Anything else (True,
    None, missing key) returns False from the predicate."""
    import app.services.provider_capabilities as pc

    def _info_explicit_false(**_kwargs):
        return {"supports_vision": False, "max_input_tokens": 8192}

    monkeypatch.setattr(pc.litellm, "get_model_info", _info_explicit_false)
    assert (
        is_known_text_only_chat_model(
            provider="OPENAI",
            model_name="text-only-stub",
        )
        is True
    )

    def _info_true(**_kwargs):
        return {"supports_vision": True}

    monkeypatch.setattr(pc.litellm, "get_model_info", _info_true)
    assert (
        is_known_text_only_chat_model(
            provider="OPENAI",
            model_name="vision-stub",
        )
        is False
    )

    def _info_missing(**_kwargs):
        return {"max_input_tokens": 8192}

    monkeypatch.setattr(pc.litellm, "get_model_info", _info_missing)
    assert (
        is_known_text_only_chat_model(
            provider="OPENAI",
            model_name="missing-key-stub",
        )
        is False
    )
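

# How the gate is meant to sit in front of the provider call, sketched
# (stream_new_chat's surrounding generator code is assumed; the predicate
# call and the SSE error code come from the module docstring above):
def _gate_sketch(has_image_parts: bool, provider: str, model_name: str) -> str | None:
    if has_image_parts and is_known_text_only_chat_model(
        provider=provider,
        model_name=model_name,
    ):
        return "MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT"  # emitted as an SSE error
    return None  # flow through to the provider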

@@ -105,6 +105,19 @@ async def _denying_billable_call(**kwargs):
    yield SimpleNamespace()  # pragma: no cover


@contextlib.asynccontextmanager
async def _settlement_failing_billable_call(**kwargs):
    from app.services.billable_calls import BillingSettlementError

    _CALL_LOG.append(kwargs)
    yield SimpleNamespace()
    raise BillingSettlementError(
        usage_type=kwargs.get("usage_type", "?"),
        user_id=kwargs["user_id"],
        cause=RuntimeError("finalize failed"),
    )


# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------

@@ -176,11 +189,15 @@ async def test_billable_call_invoked_with_correct_kwargs_for_free_config(monkeyp
        call["quota_reserve_micros_override"]
        == app_config.QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS
    )
-    assert call["thread_id"] == 99
+    # Background artifact audit rows intentionally omit the TokenUsage.thread_id
+    # FK to avoid coupling Celery audit commits to an active chat transaction.
+    assert "thread_id" not in call
+    assert call["call_details"] == {
+        "video_presentation_id": 11,
+        "title": "Test Presentation",
+        "thread_id": 99,
+    }
    assert callable(call["billable_session_factory"])


@pytest.mark.asyncio

@@ -280,6 +297,57 @@ async def test_quota_insufficient_marks_video_failed_and_skips_graph(monkeypatch
    assert graph_invoked == []


@pytest.mark.asyncio
async def test_billing_settlement_failure_marks_video_failed(monkeypatch):
    from app.db import VideoPresentationStatus
    from app.tasks.celery_tasks import video_presentation_tasks

    video = _make_video(video_id=14)
    session = _FakeSession(video)
    monkeypatch.setattr(
        video_presentation_tasks,
        "get_celery_session_maker",
        lambda: _FakeSessionMaker(session),
    )

    async def _fake_resolver(sess, search_space_id, *, thread_id=None):
        return uuid4(), "premium", "gpt-5.4"

    monkeypatch.setattr(
        video_presentation_tasks,
        "_resolve_agent_billing_for_search_space",
        _fake_resolver,
    )
    monkeypatch.setattr(
        video_presentation_tasks,
        "billable_call",
        _settlement_failing_billable_call,
    )

    async def _fake_graph_invoke(state, config):
        return {"slides": [], "slide_audio_results": [], "slide_scene_codes": []}

    monkeypatch.setattr(
        video_presentation_tasks.video_presentation_graph,
        "ainvoke",
        _fake_graph_invoke,
    )

    result = await video_presentation_tasks._generate_video_presentation(
        video_presentation_id=14,
        source_content="content",
        search_space_id=777,
        user_prompt=None,
    )

    assert result == {
        "status": "failed",
        "video_presentation_id": 14,
        "reason": "billing_settlement_failed",
    }
    assert video.status == VideoPresentationStatus.FAILED


@pytest.mark.asyncio
async def test_resolver_failure_marks_video_failed(monkeypatch):
    from app.db import VideoPresentationStatus