Merge remote-tracking branch 'upstream/dev' into fix/memory-extraction

This commit is contained in:
Anish Sarkar 2026-05-04 12:03:44 +05:30
commit b981b51ab1
176 changed files with 20407 additions and 6258 deletions

View file

@ -0,0 +1,268 @@
"""Regression tests for the compiled-agent cache.
Covers the cache primitive itself (TTL, LRU, in-flight de-duplication,
build-failure non-caching) and the cache-key signature helpers that
``create_surfsense_deep_agent`` relies on. The integration with
``create_surfsense_deep_agent`` is covered separately by the streaming
contract tests; this module focuses on the primitives so a regression
in the cache implementation is caught before it reaches the agent
factory.
"""
from __future__ import annotations
import asyncio
from dataclasses import dataclass
import pytest
from app.agents.new_chat.agent_cache import (
flags_signature,
reload_for_tests,
stable_hash,
system_prompt_hash,
tools_signature,
)
pytestmark = pytest.mark.unit
# ---------------------------------------------------------------------------
# stable_hash + signature helpers
# ---------------------------------------------------------------------------
def test_stable_hash_is_deterministic_across_calls() -> None:
    """The same argument tuple must always hash to the same value."""
    first = stable_hash("v1", 42, "thread-9", None, ["x", "y"])
    second = stable_hash("v1", 42, "thread-9", None, ["x", "y"])
    assert first == second
def test_stable_hash_changes_when_any_part_changes() -> None:
    """Perturbing any single component must rotate the hash."""
    reference = stable_hash("v1", 42, "thread-9")
    perturbed = [
        stable_hash("v1", 42, "thread-10"),
        stable_hash("v2", 42, "thread-9"),
        stable_hash("v1", 43, "thread-9"),
    ]
    assert all(variant != reference for variant in perturbed)
def test_tools_signature_keys_on_name_and_description_not_identity() -> None:
    """Identical tool surfaces must produce identical signatures.

    The cache key MUST NOT depend on object identity: every fresh request
    constructs fresh ``BaseTool`` instances, so hashing on
    ``(name, description)`` is what keeps the cache hot across requests
    with the same tool surface.
    """

    @dataclass
    class FakeTool:
        name: str
        description: str

    def sig(tools: list) -> object:
        # Fixed connector/doc-type surface so only the tool list varies.
        return tools_signature(
            tools,
            available_connectors=["NOTION"],
            available_document_types=["FILE"],
        )

    forward = [FakeTool("alpha", "does alpha"), FakeTool("beta", "does beta")]
    backward = [FakeTool("beta", "does beta"), FakeTool("alpha", "does alpha")]
    assert sig(forward) == sig(backward), "tool order must not affect the signature"

    # Adding a tool rotates the key.
    extended = [*forward, FakeTool("gamma", "does gamma")]
    assert sig(extended) != sig(forward)
def test_tools_signature_rotates_when_connector_set_changes() -> None:
    """Growing the connector set must produce a different signature."""

    @dataclass
    class FakeTool:
        name: str
        description: str

    tool_list = [FakeTool("a", "x")]
    without_slack = tools_signature(
        tool_list,
        available_connectors=["NOTION"],
        available_document_types=["FILE"],
    )
    with_slack = tools_signature(
        tool_list,
        available_connectors=["NOTION", "SLACK"],
        available_document_types=["FILE"],
    )
    assert without_slack != with_slack, "adding a connector must rotate the cache key"
def test_flags_signature_changes_when_flag_flips() -> None:
    """Flipping a single boolean field must rotate the flags signature."""

    @dataclass(frozen=True)
    class Flags:
        a: bool = True
        b: bool = False

    assert flags_signature(Flags()) != flags_signature(Flags(b=True))
def test_system_prompt_hash_is_stable_and_distinct() -> None:
    """Same prompt hashes identically; a one-character delta rotates it."""
    prompt = "You are a helpful assistant."
    nearly_same = "You are a helpful assistant!"  # one-character delta
    assert system_prompt_hash(prompt) == system_prompt_hash(prompt)
    assert system_prompt_hash(prompt) != system_prompt_hash(nearly_same)
# ---------------------------------------------------------------------------
# _AgentCache: hit / miss / TTL / LRU / coalescing / failure-not-cached
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_cache_hit_returns_same_instance_on_second_call() -> None:
    """A warm key must return the identical object without rebuilding."""
    cache = reload_for_tests(maxsize=8, ttl_seconds=60.0)
    build_count = 0

    async def builder() -> object:
        nonlocal build_count
        build_count += 1
        return object()

    first = await cache.get_or_build("k", builder=builder)
    second = await cache.get_or_build("k", builder=builder)
    assert first is second, "cache must return the SAME object across hits"
    assert build_count == 1, "builder must run exactly once"
@pytest.mark.asyncio
async def test_cache_different_keys_get_different_instances() -> None:
    """Distinct keys must never alias to the same cached value."""
    cache = reload_for_tests(maxsize=8, ttl_seconds=60.0)

    async def builder() -> object:
        return object()

    value_k1 = await cache.get_or_build("k1", builder=builder)
    value_k2 = await cache.get_or_build("k2", builder=builder)
    assert value_k1 is not value_k2
@pytest.mark.asyncio
async def test_cache_stale_entries_get_rebuilt() -> None:
    """An expired entry must be rebuilt rather than served stale."""
    # ttl=0 means every read sees the entry as immediately stale.
    cache = reload_for_tests(maxsize=8, ttl_seconds=0.0)
    build_count = 0

    async def builder() -> object:
        nonlocal build_count
        build_count += 1
        return object()

    first = await cache.get_or_build("k", builder=builder)
    second = await cache.get_or_build("k", builder=builder)
    assert first is not second, "stale entry must rebuild a fresh instance"
    assert build_count == 2
@pytest.mark.asyncio
async def test_cache_evicts_lru_when_full() -> None:
    """With ``maxsize=2`` the least-recently-used entry is the victim.

    Fix: the trailing comment promised to confirm "'b' is gone (rebuild)"
    but the test never re-read "b" — only the "a is still hot" half was
    asserted. The eviction half is now checked explicitly.
    """
    cache = reload_for_tests(maxsize=2, ttl_seconds=60.0)

    async def builder() -> object:
        return object()

    a = await cache.get_or_build("a", builder=builder)
    b = await cache.get_or_build("b", builder=builder)
    # Re-touch "a" so "b" is now the LRU victim.
    a_again = await cache.get_or_build("a", builder=builder)
    assert a_again is a
    # Inserting "c" should evict "b" (LRU), not "a".
    _ = await cache.get_or_build("c", builder=builder)
    assert cache.stats()["size"] == 2
    # Confirm "a" is still hot (no rebuild) ...
    a_hit = await cache.get_or_build("a", builder=builder)
    assert a_hit is a, "LRU must keep the most-recently-used 'a' entry"
    # ... and "b" is gone: looking it up again yields a freshly built
    # object, proving the original entry was evicted.
    b_rebuilt = await cache.get_or_build("b", builder=builder)
    assert b_rebuilt is not b, "evicted 'b' must be rebuilt on next lookup"
@pytest.mark.asyncio
async def test_cache_concurrent_misses_coalesce_to_single_build() -> None:
    """Two concurrent get_or_build calls on the same key must share one builder."""
    cache = reload_for_tests(maxsize=8, ttl_seconds=60.0)
    first_build_running = asyncio.Event()
    build_count = 0

    async def slow_builder() -> object:
        nonlocal build_count
        build_count += 1
        first_build_running.set()
        # Yield control so the second waiter can race against us.
        await asyncio.sleep(0.05)
        return object()

    racer_one = asyncio.create_task(cache.get_or_build("k", builder=slow_builder))
    # Only launch the second waiter once the first builder is mid-flight.
    await first_build_running.wait()
    racer_two = asyncio.create_task(cache.get_or_build("k", builder=slow_builder))

    value_one, value_two = await asyncio.gather(racer_one, racer_two)
    assert value_one is value_two, "coalesced waiters must observe the same value"
    assert build_count == 1, "concurrent cold misses must collapse to ONE build"
@pytest.mark.asyncio
async def test_cache_does_not_store_failed_builds() -> None:
    """A builder that raises must NOT poison the cache.

    The next caller for the same key must run the builder again rather
    than re-raise a cached exception.
    """
    cache = reload_for_tests(maxsize=8, ttl_seconds=60.0)
    call_count = 0

    async def flaky_builder() -> object:
        nonlocal call_count
        call_count += 1
        if call_count == 1:
            raise RuntimeError("transient")
        return object()

    with pytest.raises(RuntimeError, match="transient"):
        await cache.get_or_build("k", builder=flaky_builder)

    # Second call must retry — not re-raise the cached exception.
    retried = await cache.get_or_build("k", builder=flaky_builder)
    assert retried is not None
    assert call_count == 2
@pytest.mark.asyncio
async def test_cache_invalidate_drops_entry() -> None:
    """Explicit invalidation must force the next lookup to rebuild."""
    cache = reload_for_tests(maxsize=8, ttl_seconds=60.0)

    async def builder() -> object:
        return object()

    before = await cache.get_or_build("k", builder=builder)
    assert cache.invalidate("k") is True
    after = await cache.get_or_build("k", builder=builder)
    assert before is not after, "post-invalidation lookup must rebuild"
@pytest.mark.asyncio
async def test_cache_invalidate_prefix_drops_matching_entries() -> None:
    """``invalidate_prefix`` removes exactly the keys under the prefix.

    Fix: the survivor check previously asserted only ``is not None``,
    which would pass even if the ``user:2`` entry had been dropped and
    rebuilt. It now captures the original object and pins identity, so
    "must still be hot (no rebuild)" is actually verified.
    """
    cache = reload_for_tests(maxsize=16, ttl_seconds=60.0)

    async def builder() -> object:
        return object()

    await cache.get_or_build("user:1:thread:1", builder=builder)
    await cache.get_or_build("user:1:thread:2", builder=builder)
    survivor = await cache.get_or_build("user:2:thread:1", builder=builder)

    removed = cache.invalidate_prefix("user:1:")
    assert removed == 2
    assert cache.stats()["size"] == 1

    # The user:2 entry must still be hot (no rebuild).
    survivor_value = await cache.get_or_build("user:2:thread:1", builder=builder)
    assert survivor_value is survivor, "non-matching entry must survive untouched"

View file

@ -31,18 +31,45 @@ def _clear_all(monkeypatch: pytest.MonkeyPatch) -> None:
"SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE",
"SURFSENSE_ENABLE_ACTION_LOG",
"SURFSENSE_ENABLE_REVERT_ROUTE",
"SURFSENSE_ENABLE_STREAM_PARITY_V2",
"SURFSENSE_ENABLE_PLUGIN_LOADER",
"SURFSENSE_ENABLE_OTEL",
"SURFSENSE_ENABLE_AGENT_CACHE",
"SURFSENSE_ENABLE_AGENT_CACHE_SHARE_GP_SUBAGENT",
]:
monkeypatch.delenv(name, raising=False)
def test_defaults_all_off(monkeypatch: pytest.MonkeyPatch) -> None:
def test_defaults_match_shipped_agent_stack(monkeypatch: pytest.MonkeyPatch) -> None:
_clear_all(monkeypatch)
flags = reload_for_tests()
assert isinstance(flags, AgentFeatureFlags)
assert flags.disable_new_agent_stack is False
assert flags.any_new_middleware_enabled() is False
assert flags.enable_context_editing is True
assert flags.enable_compaction_v2 is True
assert flags.enable_retry_after is True
assert flags.enable_model_fallback is False
assert flags.enable_model_call_limit is True
assert flags.enable_tool_call_limit is True
assert flags.enable_tool_call_repair is True
assert flags.enable_doom_loop is True
assert flags.enable_permission is True
assert flags.enable_busy_mutex is True
assert flags.enable_llm_tool_selector is False
assert flags.enable_skills is True
assert flags.enable_specialized_subagents is True
assert flags.enable_kb_planner_runnable is True
assert flags.enable_action_log is True
assert flags.enable_revert_route is True
assert flags.enable_stream_parity_v2 is True
assert flags.enable_plugin_loader is False
assert flags.enable_otel is False
# Phase 2: agent cache is now default-on (the prerequisite tool
# ``db_session`` refactor landed). The companion gp-subagent share
# flag stays default-off pending data on cold-miss frequency.
assert flags.enable_agent_cache is True
assert flags.enable_agent_cache_share_gp_subagent is False
assert flags.any_new_middleware_enabled() is True
def test_master_kill_switch_overrides_individual_flags(
@ -100,21 +127,13 @@ def test_each_flag_can_be_set_independently(monkeypatch: pytest.MonkeyPatch) ->
"enable_kb_planner_runnable": "SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE",
"enable_action_log": "SURFSENSE_ENABLE_ACTION_LOG",
"enable_revert_route": "SURFSENSE_ENABLE_REVERT_ROUTE",
"enable_stream_parity_v2": "SURFSENSE_ENABLE_STREAM_PARITY_V2",
"enable_plugin_loader": "SURFSENSE_ENABLE_PLUGIN_LOADER",
"enable_otel": "SURFSENSE_ENABLE_OTEL",
}
# `enable_otel` is intentionally orthogonal — it does NOT count toward
# ``any_new_middleware_enabled`` because OTel is observability-only and
# ships under its own ``OTEL_EXPORTER_OTLP_ENDPOINT`` requirement.
counts_toward_middleware = {k for k in flag_to_env if k != "enable_otel"}
for attr, env_name in flag_to_env.items():
_clear_all(monkeypatch)
monkeypatch.setenv(env_name, "true")
monkeypatch.setenv(env_name, "false")
flags = reload_for_tests()
assert getattr(flags, attr) is True, f"{attr} did not flip on for {env_name}"
if attr in counts_toward_middleware:
assert flags.any_new_middleware_enabled() is True
else:
assert flags.any_new_middleware_enabled() is False
assert getattr(flags, attr) is False, f"{attr} did not flip off for {env_name}"

View file

@ -0,0 +1,344 @@
"""Tests for ``FlattenSystemMessageMiddleware``.
The middleware exists to defend against Anthropic's "Found 5 cache_control
blocks" 400 when our deepagent middleware stack stacks 5+ text blocks on
the system message and the OpenRouterAnthropic adapter redistributes
``cache_control`` across all of them. The flattening collapses every
all-text system content list to a single string before the LLM call.
"""
from __future__ import annotations
from typing import Any
from unittest.mock import MagicMock
import pytest
from langchain_core.messages import HumanMessage, SystemMessage
from app.agents.new_chat.middleware.flatten_system import (
FlattenSystemMessageMiddleware,
_flatten_text_blocks,
_flattened_request,
)
pytestmark = pytest.mark.unit
# ---------------------------------------------------------------------------
# _flatten_text_blocks — pure helper, the heart of the middleware.
# ---------------------------------------------------------------------------
class TestFlattenTextBlocks:
    """Pure-helper tests: the joining rules and every reject-to-None case."""

    def test_joins_text_blocks_with_double_newline(self) -> None:
        parts = [
            {"type": "text", "text": "<surfsense base>"},
            {"type": "text", "text": "<filesystem section>"},
            {"type": "text", "text": "<skills section>"},
        ]
        expected = "<surfsense base>\n\n<filesystem section>\n\n<skills section>"
        assert _flatten_text_blocks(parts) == expected

    def test_handles_single_text_block(self) -> None:
        assert _flatten_text_blocks([{"type": "text", "text": "only one"}]) == "only one"

    def test_handles_empty_list(self) -> None:
        assert _flatten_text_blocks([]) == ""

    def test_passes_through_bare_string_blocks(self) -> None:
        # LangChain content can mix bare strings and dict blocks.
        mixed = ["raw string", {"type": "text", "text": "dict block"}]
        assert _flatten_text_blocks(mixed) == "raw string\n\ndict block"

    def test_returns_none_for_image_block(self) -> None:
        # System messages with images are rare — but we never want to
        # silently lose the image payload by joining as text.
        with_image = [
            {"type": "text", "text": "look at this"},
            {"type": "image_url", "image_url": {"url": "data:image/png..."}},
        ]
        assert _flatten_text_blocks(with_image) is None

    def test_returns_none_for_non_dict_non_str_block(self) -> None:
        bad = [{"type": "text", "text": "hi"}, 42]  # type: ignore[list-item]
        assert _flatten_text_blocks(bad) is None

    def test_returns_none_when_text_field_missing(self) -> None:
        # A text block without its ``text`` key cannot be joined safely.
        assert _flatten_text_blocks([{"type": "text"}]) is None

    def test_returns_none_when_text_is_not_string(self) -> None:
        nested = [{"type": "text", "text": ["nested", "list"]}]
        assert _flatten_text_blocks(nested) is None

    def test_drops_cache_control_from_inner_blocks(self) -> None:
        # The whole point: existing cache_control on inner blocks is
        # discarded so LiteLLM's ``cache_control_injection_points`` can
        # re-attach exactly one breakpoint after flattening.
        annotated = [
            {"type": "text", "text": "first"},
            {
                "type": "text",
                "text": "second",
                "cache_control": {"type": "ephemeral"},
            },
        ]
        joined = _flatten_text_blocks(annotated)
        assert joined == "first\n\nsecond"
        assert "cache_control" not in joined  # type: ignore[operator]
# ---------------------------------------------------------------------------
# _flattened_request — decides when to override and when to no-op.
# ---------------------------------------------------------------------------
def _make_request(system_message: SystemMessage | None) -> Any:
"""Build a minimal ModelRequest stub. We only need .system_message
and .override(system_message=...) the middleware never touches
other fields.
"""
request = MagicMock()
request.system_message = system_message
def override(**kwargs: Any) -> Any:
new_request = MagicMock()
new_request.system_message = kwargs.get(
"system_message", request.system_message
)
new_request.messages = kwargs.get("messages", getattr(request, "messages", []))
new_request.tools = kwargs.get("tools", getattr(request, "tools", []))
return new_request
request.override = override
return request
class TestFlattenedRequest:
    """When the helper overrides the request — and when it declines."""

    def test_collapses_multi_block_system_to_string(self) -> None:
        original = SystemMessage(
            content=[
                {"type": "text", "text": "<base>"},
                {"type": "text", "text": "<todo>"},
                {"type": "text", "text": "<filesystem>"},
                {"type": "text", "text": "<skills>"},
                {"type": "text", "text": "<subagents>"},
            ]
        )
        result = _flattened_request(_make_request(original))
        assert result is not None
        assert isinstance(result.system_message, SystemMessage)
        expected = "<base>\n\n<todo>\n\n<filesystem>\n\n<skills>\n\n<subagents>"
        assert result.system_message.content == expected

    def test_no_op_for_string_content(self) -> None:
        request = _make_request(SystemMessage(content="already a string"))
        assert _flattened_request(request) is None

    def test_no_op_for_single_block_list(self) -> None:
        # One block already produces one breakpoint — no need to flatten.
        single = SystemMessage(content=[{"type": "text", "text": "single"}])
        assert _flattened_request(_make_request(single)) is None

    def test_no_op_when_system_message_missing(self) -> None:
        assert _flattened_request(_make_request(None)) is None

    def test_no_op_when_list_contains_non_text_block(self) -> None:
        with_image = SystemMessage(
            content=[
                {"type": "text", "text": "look"},
                {"type": "image_url", "image_url": {"url": "data:..."}},
            ]
        )
        assert _flattened_request(_make_request(with_image)) is None

    def test_preserves_additional_kwargs_and_metadata(self) -> None:
        # Defensive: nothing in the current chain sets these on a system
        # message, but losing them silently when something does in the
        # future would be a regression. ``name`` in particular is the only
        # ``additional_kwargs`` field that ChatLiteLLM's
        # ``_convert_message_to_dict`` propagates onto the wire.
        original = SystemMessage(
            content=[
                {"type": "text", "text": "a"},
                {"type": "text", "text": "b"},
            ],
            additional_kwargs={"name": "surfsense_system", "x": 1},
            response_metadata={"tokens": 42},
        )
        original.id = "sys-msg-1"
        result = _flattened_request(_make_request(original))
        assert result is not None
        flattened_sys = result.system_message
        assert flattened_sys.content == "a\n\nb"
        assert flattened_sys.additional_kwargs == {
            "name": "surfsense_system",
            "x": 1,
        }
        assert flattened_sys.response_metadata == {"tokens": 42}
        assert flattened_sys.id == "sys-msg-1"

    def test_idempotent_when_run_twice(self) -> None:
        original = SystemMessage(
            content=[
                {"type": "text", "text": "a"},
                {"type": "text", "text": "b"},
            ]
        )
        first_pass = _flattened_request(_make_request(original))
        assert first_pass is not None
        # Second pass on the already-flattened message must be a no-op.
        # We re-wrap in a request stub since the helper inspects
        # ``request.system_message.content``.
        rewrapped = _make_request(first_pass.system_message)
        assert _flattened_request(rewrapped) is None
# ---------------------------------------------------------------------------
# Middleware integration — verify the handler sees a flattened request.
# ---------------------------------------------------------------------------
class TestMiddlewareWrap:
    """End-to-end through the wrap hooks with a capturing handler."""

    @pytest.mark.asyncio
    async def test_async_passes_flattened_request_to_handler(self) -> None:
        request = _make_request(
            SystemMessage(
                content=[
                    {"type": "text", "text": "alpha"},
                    {"type": "text", "text": "beta"},
                ]
            )
        )
        seen: dict[str, Any] = {}

        async def handler(req: Any) -> str:
            seen["request"] = req
            return "ok"

        middleware = FlattenSystemMessageMiddleware()
        assert await middleware.awrap_model_call(request, handler) == "ok"
        assert isinstance(seen["request"].system_message, SystemMessage)
        assert seen["request"].system_message.content == "alpha\n\nbeta"

    @pytest.mark.asyncio
    async def test_async_passes_through_when_already_string(self) -> None:
        request = _make_request(SystemMessage(content="just a string"))
        seen: dict[str, Any] = {}

        async def handler(req: Any) -> str:
            seen["request"] = req
            return "ok"

        await FlattenSystemMessageMiddleware().awrap_model_call(request, handler)
        # Same request object: no override happened.
        assert seen["request"] is request

    def test_sync_passes_flattened_request_to_handler(self) -> None:
        request = _make_request(
            SystemMessage(
                content=[
                    {"type": "text", "text": "alpha"},
                    {"type": "text", "text": "beta"},
                ]
            )
        )
        seen: dict[str, Any] = {}

        def handler(req: Any) -> str:
            seen["request"] = req
            return "ok"

        middleware = FlattenSystemMessageMiddleware()
        assert middleware.wrap_model_call(request, handler) == "ok"
        assert seen["request"].system_message.content == "alpha\n\nbeta"

    def test_sync_passes_through_when_no_system_message(self) -> None:
        request = _make_request(None)
        seen: dict[str, Any] = {}

        def handler(req: Any) -> str:
            seen["request"] = req
            return "ok"

        FlattenSystemMessageMiddleware().wrap_model_call(request, handler)
        assert seen["request"] is request
# ---------------------------------------------------------------------------
# Regression guard — pin the worst-case shape that triggered the
# "Found 5" 400 in production. Confirms we collapse 5 blocks to 1 so the
# downstream cache_control_injection_points can only place 1 breakpoint
# on the system message regardless of provider redistribution quirks.
# ---------------------------------------------------------------------------
def test_regression_five_block_system_collapses_to_one_block() -> None:
    """Pin the worst-case production shape: five stacked system blocks."""
    five_blocks = SystemMessage(
        content=[
            {"type": "text", "text": "<surfsense base + BASE_AGENT_PROMPT>"},
            {"type": "text", "text": "<TodoListMiddleware section>"},
            {"type": "text", "text": "<SurfSenseFilesystemMiddleware section>"},
            {"type": "text", "text": "<SkillsMiddleware section>"},
            {"type": "text", "text": "<SubAgentMiddleware section>"},
        ]
    )
    result = _flattened_request(_make_request(five_blocks))
    assert result is not None
    assert isinstance(result.system_message.content, str)
    # The exact join doesn't matter for the cache_control accounting —
    # only that there is exactly ONE content block when LiteLLM's
    # AnthropicCacheControlHook later targets ``role: system``.
    assert "<surfsense base" in result.system_message.content
    assert "<SubAgentMiddleware" in result.system_message.content
def test_regression_human_message_not_modified() -> None:
    """Sanity: the middleware MUST only touch the system message.

    Multi-block user content is the path that carries image attachments
    and would lose its ``image_url`` block on an accidental flatten.
    """
    system = SystemMessage(
        content=[
            {"type": "text", "text": "a"},
            {"type": "text", "text": "b"},
        ]
    )
    user_turn = HumanMessage(
        content=[
            {"type": "text", "text": "look at this"},
            {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
        ]
    )
    request = _make_request(system)
    request.messages = [user_turn]
    result = _flattened_request(request)
    assert result is not None
    # System flattened to string …
    assert isinstance(result.system_message.content, str)
    # … user message is untouched (the helper does not even look at it).
    assert result.messages == [user_turn]
    assert isinstance(user_turn.content, list)
    assert len(user_turn.content) == 2

View file

@ -1,4 +1,4 @@
"""Tests for ``apply_litellm_prompt_caching`` in
r"""Tests for ``apply_litellm_prompt_caching`` in
:mod:`app.agents.new_chat.prompt_caching`.
The helper replaces the legacy ``AnthropicPromptCachingMiddleware`` (which
@ -6,9 +6,12 @@ never activated for our LiteLLM stack) with LiteLLM-native multi-provider
prompt caching. It mutates ``llm.model_kwargs`` so the kwargs flow to
``litellm.completion(...)``. The tests below pin its public contract:
1. Always sets BOTH ``role: system`` and ``index: -1`` injection points so
1. Always sets BOTH ``index: 0`` and ``index: -1`` injection points so
savings compound across multi-turn conversations on Anthropic-family
providers.
providers. ``index: 0`` is used (rather than ``role: system``) because
the deepagent stack accumulates multiple ``SystemMessage``\ s in
``state["messages"]`` and ``role: system`` would tag every one of
them, blowing past Anthropic's 4-block ``cache_control`` cap.
2. Adds ``prompt_cache_key``/``prompt_cache_retention`` only for
single-model OPENAI/DEEPSEEK/XAI configs (where OpenAI's automatic
prompt-cache surface is available).
@ -92,11 +95,28 @@ def test_sets_both_cache_control_injection_points_with_no_config() -> None:
apply_litellm_prompt_caching(llm)
points = llm.model_kwargs["cache_control_injection_points"]
assert {"location": "message", "role": "system"} in points
assert {"location": "message", "index": 0} in points
assert {"location": "message", "index": -1} in points
assert len(points) == 2
def test_does_not_inject_role_system_breakpoint() -> None:
    """Regression: deliberately AVOID ``role: system`` so we don't tag
    every SystemMessage the deepagent ``before_agent`` injectors push
    into ``state["messages"]`` (priority, tree, memory, file-intent,
    anonymous-doc). Tagging all of them overflows Anthropic's 4-block
    ``cache_control`` cap and surfaces as
    ``OpenrouterException: A maximum of 4 blocks with cache_control may
    be provided. Found N`` 400s.
    """
    llm = _FakeLLM()
    apply_litellm_prompt_caching(llm)
    injection_points = llm.model_kwargs["cache_control_injection_points"]
    offenders = [p for p in injection_points if p.get("role") == "system"]
    assert not offenders, f"Expected no role=system breakpoint, got: {injection_points}"
def test_injection_points_set_for_anthropic_config() -> None:
"""Anthropic-family configs need the marker — verify it lands."""
cfg = _make_cfg(provider="ANTHROPIC", model_name="claude-3-5-sonnet")

View file

@ -475,3 +475,190 @@ class TestKBSearchPlanSchema:
)
)
assert plan.is_recency_query is False
# ── mentioned_document_ids cross-turn drain ────────────────────────────
class TestKnowledgePriorityMentionDrain:
    """Regression tests for the cross-turn ``mentioned_document_ids`` drain.

    The compiled-agent cache reuses a single
    :class:`KnowledgePriorityMiddleware` instance across turns of the same
    thread, so ``mentioned_document_ids`` can reach the middleware via two
    paths:

    1. The constructor closure (``__init__(mentioned_document_ids=...)``)
       seeded by the cache-miss build on turn 1.
    2. ``runtime.context.mentioned_document_ids`` supplied freshly per
       turn by the streaming task.

    Without the drain fix, an empty ``runtime.context.mentioned_document_ids``
    on turn 2 would fall through to the closure (because ``[]`` is falsy in
    Python) and replay turn 1's mentions. This class pins the corrected
    behaviour: the runtime path is authoritative even when empty, and the
    closure is drained the first time the runtime path fires so no later
    turn can ever resurrect stale state.
    """

    @staticmethod
    def _make_runtime(mention_ids: list[int]):
        """Minimal runtime stub exposing only ``runtime.context.mentioned_document_ids``."""
        from types import SimpleNamespace

        return SimpleNamespace(
            context=SimpleNamespace(mentioned_document_ids=mention_ids),
        )

    @staticmethod
    def _planner_llm() -> "FakeLLM":
        # Planner returns a stable, non-recency plan so we always land in
        # the hybrid-search branch (where ``fetch_mentioned_documents`` is
        # invoked alongside the main search).
        return FakeLLM(
            json.dumps(
                {
                    "optimized_query": "follow up question",
                    "start_date": None,
                    "end_date": None,
                    "is_recency_query": False,
                }
            )
        )

    @staticmethod
    def _patch_search_layer(monkeypatch) -> list[list[int]]:
        """Stub both search entry points; return the mention-fetch call log."""
        fetch_log: list[list[int]] = []

        async def fake_fetch_mentioned_documents(*, document_ids, search_space_id):
            fetch_log.append(list(document_ids))
            return []

        async def fake_search_knowledge_base(**_kwargs):
            return []

        monkeypatch.setattr(
            "app.agents.new_chat.middleware.knowledge_search.fetch_mentioned_documents",
            fake_fetch_mentioned_documents,
        )
        monkeypatch.setattr(
            "app.agents.new_chat.middleware.knowledge_search.search_knowledge_base",
            fake_search_knowledge_base,
        )
        return fetch_log

    async def test_runtime_context_overrides_closure_and_drains_it(self, monkeypatch):
        """Turn 1 with mentions in BOTH closure and runtime context: the
        runtime path wins AND the closure is drained so a future turn
        cannot replay it.
        """
        fetch_log = self._patch_search_layer(monkeypatch)
        middleware = KnowledgeBaseSearchMiddleware(
            llm=self._planner_llm(),
            search_space_id=42,
            mentioned_document_ids=[1, 2, 3],
        )
        await middleware.abefore_agent(
            {"messages": [HumanMessage(content="what is in those docs?")]},
            runtime=self._make_runtime([1, 2, 3]),
        )
        assert fetch_log == [[1, 2, 3]], (
            "runtime.context mentions must be the source of truth on turn 1"
        )
        assert middleware.mentioned_document_ids == [], (
            "closure must be drained the first time the runtime path fires "
            "so no later turn can replay stale mentions"
        )

    async def test_empty_runtime_context_does_not_replay_closure_mentions(
        self, monkeypatch
    ):
        """Regression: turn 2 with NO mentions must not surface turn 1's
        mentions from the constructor closure.

        Before the fix, ``if ctx_mentions:`` treated an empty list as
        absent and fell through to ``elif self.mentioned_document_ids:``,
        replaying turn 1's mentions. This test pins down the corrected
        behaviour.
        """
        fetch_log = self._patch_search_layer(monkeypatch)
        # Simulate a cached middleware instance whose closure was seeded
        # by a previous turn's cache-miss build (mentions=[1,2,3]).
        middleware = KnowledgeBaseSearchMiddleware(
            llm=self._planner_llm(),
            search_space_id=42,
            mentioned_document_ids=[1, 2, 3],
        )
        # Turn 2: streaming task supplies an EMPTY mention list (no
        # mentions on this follow-up turn).
        await middleware.abefore_agent(
            {"messages": [HumanMessage(content="what about the next steps?")]},
            runtime=self._make_runtime([]),
        )
        assert fetch_log == [], (
            "fetch_mentioned_documents must NOT be called when the runtime "
            "context says there are no mentions for this turn"
        )

    async def test_legacy_path_fires_only_when_runtime_context_absent(
        self, monkeypatch
    ):
        """Backward-compat: if a caller doesn't supply runtime.context (old
        non-streaming code path), the closure-injected mentions are still
        honoured exactly once and then drained.
        """
        fetch_log = self._patch_search_layer(monkeypatch)
        middleware = KnowledgeBaseSearchMiddleware(
            llm=self._planner_llm(),
            search_space_id=42,
            mentioned_document_ids=[7, 8],
        )
        # First call: no runtime → legacy path uses the closure.
        await middleware.abefore_agent(
            {"messages": [HumanMessage(content="initial question")]},
            runtime=None,
        )
        # Second call: still no runtime — closure already drained, so no replay.
        await middleware.abefore_agent(
            {"messages": [HumanMessage(content="follow up")]},
            runtime=None,
        )
        assert fetch_log == [[7, 8]], (
            "legacy path must honour the closure exactly once and then drain it"
        )
        assert middleware.mentioned_document_ids == []

View file

@ -0,0 +1,110 @@
"""Unit tests for ``supports_image_input`` derivation on BYOK chat config
endpoints (``GET /new-llm-configs`` list, ``GET /new-llm-configs/{id}``).
There is no DB column for ``supports_image_input`` on
``NewLLMConfig``; the value is resolved at the API boundary by
``derive_supports_image_input``, so the new-chat selector / streaming
task can read the same field shape regardless of source (BYOK vs YAML
vs OpenRouter dynamic). Default-allow on unknown so we don't lock the
user out of their own model choice.
"""
from __future__ import annotations
from datetime import UTC, datetime
from types import SimpleNamespace
from uuid import uuid4
import pytest
from app.db import LiteLLMProvider
from app.routes import new_llm_config_routes
pytestmark = pytest.mark.unit
def _byok_row(
    *,
    id_: int,
    model_name: str,
    base_model: str | None = None,
    provider: LiteLLMProvider = LiteLLMProvider.OPENAI,
    custom_provider: str | None = None,
) -> object:
    """Mimic the SQLAlchemy row's attribute surface.

    ``model_validate`` walks ``from_attributes=True``, so a
    ``SimpleNamespace`` is enough. ``provider`` is a real
    ``LiteLLMProvider`` enum value so Pydantic's enum validator accepts
    it the same way it would an ORM row's.
    """
    # Only Azure-style rows carry litellm_params with a base_model.
    params = {"base_model": base_model} if base_model else None
    return SimpleNamespace(
        id=id_,
        name=f"BYOK-{id_}",
        description=None,
        provider=provider,
        custom_provider=custom_provider,
        model_name=model_name,
        api_key="sk-byok",
        api_base=None,
        litellm_params=params,
        system_instructions="",
        use_default_system_instructions=True,
        citations_enabled=True,
        created_at=datetime.now(tz=UTC),
        search_space_id=42,
        user_id=uuid4(),
    )
def test_serialize_byok_known_vision_model_resolves_true():
    """``gpt-4o`` is in LiteLLM's model map as vision-capable, so the
    serialized row must carry ``supports_image_input=True`` through to the
    ``NewLLMConfigRead`` schema."""
    byok_row = _byok_row(id_=1, model_name="gpt-4o")
    read_model = new_llm_config_routes._serialize_byok_config(byok_row)
    assert read_model.id == 1
    assert read_model.model_name == "gpt-4o"
    assert read_model.supports_image_input is True
def test_serialize_byok_unknown_model_default_allows():
    """Unmapped models default-allow. The streaming-task safety net is the
    actual block, and it only fires when LiteLLM *explicitly* marks a model
    text-only — a brand-new BYOK model must not be pre-judged here."""
    byok_row = _byok_row(
        id_=2,
        model_name="brand-new-model-x9-unmapped",
        provider=LiteLLMProvider.CUSTOM,
        custom_provider="brand_new_proxy",
    )
    read_model = new_llm_config_routes._serialize_byok_config(byok_row)
    assert read_model.supports_image_input is True
def test_serialize_byok_uses_base_model_when_present():
    """Azure-style rows: ``model_name`` is an opaque deployment id while
    ``litellm_params.base_model`` names the canonical sku LiteLLM knows.
    The helper must consult ``base_model`` first, or unrecognised deployment
    ids would mask the real capability."""
    byok_row = _byok_row(
        id_=3,
        model_name="my-azure-deployment-id-no-litellm-knows-this",
        base_model="gpt-4o",
        provider=LiteLLMProvider.AZURE_OPENAI,
    )
    read_model = new_llm_config_routes._serialize_byok_config(byok_row)
    assert read_model.supports_image_input is True
def test_serialize_byok_returns_pydantic_read_model():
    """Guard against regressing to raw-ORM passthrough: the route must emit
    a ``NewLLMConfigRead`` so the schema additions are guaranteed to exist
    in the API surface."""
    from app.schemas import NewLLMConfigRead

    read_model = new_llm_config_routes._serialize_byok_config(
        _byok_row(id_=4, model_name="gpt-4o")
    )
    assert isinstance(read_model, NewLLMConfigRead)

View file

@ -0,0 +1,184 @@
"""Unit tests for ``is_premium`` derivation on the global image-gen and
vision-LLM list endpoints.
Chat globals (``GET /global-llm-configs``) already emit
``is_premium = (billing_tier == "premium")``. Image and vision did not,
which made the new-chat ``model-selector`` render the Free/Premium badge
on the Chat tab but skip it on the Image and Vision tabs (the selector
keys its badge logic off ``is_premium``). These tests pin parity:
* YAML free entry -> ``is_premium=False``
* YAML premium entry -> ``is_premium=True``
* OpenRouter dynamic premium entry -> ``is_premium=True``
* Auto stub (always emitted when at least one config is present) ->
``is_premium=False``
"""
from __future__ import annotations
import pytest
pytestmark = pytest.mark.unit
# Global image-gen configs as the server-side loaders would present them.
# Negative ids mark operator-provided globals (positive ids are user BYOK);
# the -20_001 entry mimics an OpenRouter dynamic entry — note its distinct
# id range and api_base (assumed convention; confirm against the sync code).
_IMAGE_FIXTURE: list[dict] = [
    {
        # YAML entry on the free tier.
        "id": -1,
        "name": "DALL-E 3",
        "provider": "OPENAI",
        "model_name": "dall-e-3",
        "api_key": "sk-test",
        "billing_tier": "free",
    },
    {
        # YAML entry on the premium tier.
        "id": -2,
        "name": "GPT-Image 1 (premium)",
        "provider": "OPENAI",
        "model_name": "gpt-image-1",
        "api_key": "sk-test",
        "billing_tier": "premium",
    },
    {
        # OpenRouter dynamic premium entry.
        "id": -20_001,
        "name": "google/gemini-2.5-flash-image (OpenRouter)",
        "provider": "OPENROUTER",
        "model_name": "google/gemini-2.5-flash-image",
        "api_key": "sk-or-test",
        "api_base": "https://openrouter.ai/api/v1",
        "billing_tier": "premium",
    },
]
# Global vision-LLM configs, shaped like the image fixture above: one YAML
# free entry, one YAML premium entry, one OpenRouter dynamic premium entry
# (the -30_001 id mirrors the dynamic-range convention — confirm with the
# sync code if it changes).
_VISION_FIXTURE: list[dict] = [
    {
        # YAML entry on the free tier.
        "id": -1,
        "name": "GPT-4o Vision",
        "provider": "OPENAI",
        "model_name": "gpt-4o",
        "api_key": "sk-test",
        "billing_tier": "free",
    },
    {
        # YAML entry on the premium tier.
        "id": -2,
        "name": "Claude 3.5 Sonnet (premium)",
        "provider": "ANTHROPIC",
        "model_name": "claude-3-5-sonnet",
        "api_key": "sk-ant-test",
        "billing_tier": "premium",
    },
    {
        # OpenRouter dynamic premium entry.
        "id": -30_001,
        "name": "openai/gpt-4o (OpenRouter)",
        "provider": "OPENROUTER",
        "model_name": "openai/gpt-4o",
        "api_key": "sk-or-test",
        "api_base": "https://openrouter.ai/api/v1",
        "billing_tier": "premium",
    },
]
# =============================================================================
# Image generation
# =============================================================================
@pytest.mark.asyncio
async def test_global_image_gen_configs_emit_is_premium(monkeypatch):
    """Every emitted config carries a server-derived ``is_premium`` bool
    (from ``billing_tier``); the Auto stub is always free."""
    from app.config import config
    from app.routes import image_generation_routes

    monkeypatch.setattr(
        config, "GLOBAL_IMAGE_GEN_CONFIGS", _IMAGE_FIXTURE, raising=False
    )
    configs = await image_generation_routes.get_global_image_gen_configs(user=None)
    indexed = {cfg["id"]: cfg for cfg in configs}

    # Auto stub: emitted whenever at least one global config exists, and it
    # always declares itself free (Auto-mode billing-tier surfacing is a
    # separate follow-up).
    assert 0 in indexed, "Auto stub should be emitted when at least one config exists"

    # (id, expected is_premium, expected billing_tier): Auto stub, YAML
    # free, YAML premium, OpenRouter dynamic premium.
    for cfg_id, premium, tier in [
        (0, False, "free"),
        (-1, False, "free"),
        (-2, True, "premium"),
        (-20_001, True, "premium"),
    ]:
        assert indexed[cfg_id]["is_premium"] is premium
        assert indexed[cfg_id]["billing_tier"] == tier

    # The field must be present — as a real bool — on every emitted dict.
    for cfg in configs:
        assert "is_premium" in cfg, f"is_premium missing from {cfg.get('id')}"
        assert isinstance(cfg["is_premium"], bool)
@pytest.mark.asyncio
async def test_global_image_gen_configs_no_globals_no_auto_stub(monkeypatch):
    """With no global configs at all the endpoint emits an empty list —
    no Auto stub, since Auto would have nothing to route to."""
    from app.config import config
    from app.routes import image_generation_routes

    monkeypatch.setattr(config, "GLOBAL_IMAGE_GEN_CONFIGS", [], raising=False)
    configs = await image_generation_routes.get_global_image_gen_configs(user=None)
    assert configs == []
# =============================================================================
# Vision LLM
# =============================================================================
@pytest.mark.asyncio
async def test_global_vision_llm_configs_emit_is_premium(monkeypatch):
    """Vision globals must mirror chat's ``is_premium`` derivation,
    Auto stub included."""
    from app.config import config
    from app.routes import vision_llm_routes

    monkeypatch.setattr(
        config, "GLOBAL_VISION_LLM_CONFIGS", _VISION_FIXTURE, raising=False
    )
    configs = await vision_llm_routes.get_global_vision_llm_configs(user=None)
    indexed = {cfg["id"]: cfg for cfg in configs}

    assert 0 in indexed, "Auto stub should be emitted when at least one config exists"
    # (id, expected is_premium, expected billing_tier): Auto stub, YAML
    # free, YAML premium, OpenRouter dynamic premium.
    for cfg_id, premium, tier in [
        (0, False, "free"),
        (-1, False, "free"),
        (-2, True, "premium"),
        (-30_001, True, "premium"),
    ]:
        assert indexed[cfg_id]["is_premium"] is premium
        assert indexed[cfg_id]["billing_tier"] == tier

    # Present — and a real bool — on every emitted dict.
    for cfg in configs:
        assert "is_premium" in cfg, f"is_premium missing from {cfg.get('id')}"
        assert isinstance(cfg["is_premium"], bool)
@pytest.mark.asyncio
async def test_global_vision_llm_configs_no_globals_no_auto_stub(monkeypatch):
    """No vision globals → empty payload, no Auto stub."""
    from app.config import config
    from app.routes import vision_llm_routes

    monkeypatch.setattr(config, "GLOBAL_VISION_LLM_CONFIGS", [], raising=False)
    configs = await vision_llm_routes.get_global_vision_llm_configs(user=None)
    assert configs == []

View file

@ -0,0 +1,106 @@
"""Unit tests for ``supports_image_input`` derivation on the chat global
config endpoint (``GET /global-new-llm-configs``).
Resolution order (matches ``new_llm_config_routes.get_global_new_llm_configs``):
1. Explicit ``supports_image_input`` on the cfg dict (set by the YAML
loader for operator overrides, or by the OpenRouter integration from
``architecture.input_modalities``) wins.
2. ``derive_supports_image_input`` helper default-allow on unknown
models, only False when LiteLLM / OR modalities are definitive.
The flag is purely informational at the API boundary. The streaming
task safety net (``is_known_text_only_chat_model``) is the actual block,
and it requires LiteLLM to *explicitly* mark the model as text-only.
"""
from __future__ import annotations
import pytest
pytestmark = pytest.mark.unit
# Chat-global fixtures covering the four resolution branches the endpoint
# must handle: explicit True, explicit False, unannotated-but-known
# (LiteLLM resolves), and unannotated-unknown (default-allow).
_FIXTURE: list[dict] = [
    {
        # Branch 1: explicit True set by the YAML loader — must be preserved.
        "id": -1,
        "name": "GPT-4o (explicit true)",
        "description": "vision-capable, explicit YAML override",
        "provider": "OPENAI",
        "model_name": "gpt-4o",
        "api_key": "sk-test",
        "billing_tier": "free",
        "supports_image_input": True,
    },
    {
        # Branch 2: explicit False (OpenRouter modality-derived) — must be
        # preserved verbatim.
        "id": -2,
        "name": "DeepSeek V3 (explicit false)",
        "description": "OpenRouter dynamic — modality-derived false",
        "provider": "OPENROUTER",
        "model_name": "deepseek/deepseek-v3.2-exp",
        "api_key": "sk-or-test",
        "api_base": "https://openrouter.ai/api/v1",
        "billing_tier": "free",
        "supports_image_input": False,
    },
    {
        # Branch 3: no flag, known model — resolver should derive via LiteLLM.
        "id": -10_010,
        "name": "Unannotated GPT-4o",
        "description": "no flag set — resolver should derive True via LiteLLM",
        "provider": "OPENAI",
        "model_name": "gpt-4o",
        "api_key": "sk-test",
        "billing_tier": "free",
        # supports_image_input intentionally absent
    },
    {
        # Branch 4: no flag, unmapped model — default-allow.
        "id": -10_011,
        "name": "Unannotated unknown model",
        "description": "unmapped — default-allow True",
        "provider": "CUSTOM",
        "custom_provider": "brand_new_proxy",
        "model_name": "brand-new-model-x9",
        "api_key": "sk-test",
        "billing_tier": "free",
    },
]
@pytest.mark.asyncio
async def test_global_new_llm_configs_emit_supports_image_input(monkeypatch):
    """Explicit ``supports_image_input`` values win; unannotated entries
    are resolved via the helper (default-allow True); the field is always a
    bool on every emitted config."""
    from app.config import config
    from app.routes import new_llm_config_routes

    monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", _FIXTURE, raising=False)
    configs = await new_llm_config_routes.get_global_new_llm_configs(user=None)
    indexed = {cfg["id"]: cfg for cfg in configs}

    # Auto stub: optimistic True so the user can keep Auto selected while
    # vision-capable deployments exist somewhere in the pool.
    assert 0 in indexed, "Auto stub should be emitted when configs exist"
    assert indexed[0]["is_auto_mode"] is True
    assert indexed[0]["supports_image_input"] is True

    # Explicit True preserved; explicit False preserved (DeepSeek V3 over
    # OpenRouter is the exact "No endpoints found that support image input"
    # 404 the safety net guards against); unannotated gpt-4o resolved via
    # LiteLLM; unknown model default-allowed rather than pre-judged.
    for cfg_id, expected in [(-1, True), (-2, False), (-10_010, True), (-10_011, True)]:
        assert indexed[cfg_id]["supports_image_input"] is expected

    for cfg in configs:
        assert "supports_image_input" in cfg, (
            f"supports_image_input missing from {cfg.get('id')}"
        )
        assert isinstance(cfg["supports_image_input"], bool)

View file

@ -0,0 +1,138 @@
"""Unit tests for the image-generation route's billing-resolution helper.
End-to-end "POST /image-generations returns 402" coverage requires the
integration harness (real DB, real auth) and lives in
``tests/integration/document_upload/`` alongside the other quota tests.
This unit test focuses on the new ``_resolve_billing_for_image_gen``
helper which:
* Returns ``free`` for Auto mode, even when premium configs exist
(Auto-mode billing-tier surfacing is a follow-up).
* Returns ``free`` for user-owned BYOK configs (positive IDs).
* Returns the global config's ``billing_tier`` for negative IDs.
* Honours the per-config ``quota_reserve_micros`` override when present.
"""
from __future__ import annotations
from types import SimpleNamespace
import pytest
pytestmark = pytest.mark.unit
@pytest.mark.asyncio
async def test_resolve_billing_for_auto_mode(monkeypatch):
    """Auto mode (config id 0) is always billed free with the default
    reserve, even when premium globals exist."""
    from app.routes import image_generation_routes
    from app.services.billable_calls import DEFAULT_IMAGE_RESERVE_MICROS

    space = SimpleNamespace(image_generation_config_id=None)
    tier, model_label, reserve_micros = (
        await image_generation_routes._resolve_billing_for_image_gen(
            session=None,  # Not consumed on this code path.
            config_id=0,  # IMAGE_GEN_AUTO_MODE_ID
            search_space=space,
        )
    )
    assert tier == "free"
    assert model_label == "auto"
    assert reserve_micros == DEFAULT_IMAGE_RESERVE_MICROS
@pytest.mark.asyncio
async def test_resolve_billing_for_premium_global_config(monkeypatch):
    """Negative ids resolve through the global catalog: the config's
    ``billing_tier`` wins, ``quota_reserve_micros`` overrides the default
    reserve, and OpenRouter models come back provider-prefixed."""
    from app.config import config
    from app.routes import image_generation_routes
    from app.services.billable_calls import DEFAULT_IMAGE_RESERVE_MICROS

    monkeypatch.setattr(
        config,
        "GLOBAL_IMAGE_GEN_CONFIGS",
        [
            {
                "id": -1,
                "provider": "OPENAI",
                "model_name": "gpt-image-1",
                "billing_tier": "premium",
                "quota_reserve_micros": 75_000,
            },
            {
                "id": -2,
                "provider": "OPENROUTER",
                "model_name": "google/gemini-2.5-flash-image",
                "billing_tier": "free",
            },
        ],
        raising=False,
    )
    space = SimpleNamespace(image_generation_config_id=None)

    # Premium entry with an explicit reserve override.
    tier, model_label, reserve_micros = (
        await image_generation_routes._resolve_billing_for_image_gen(
            session=None, config_id=-1, search_space=space
        )
    )
    assert tier == "premium"
    assert model_label == "openai/gpt-image-1"
    assert reserve_micros == 75_000

    # Free entry without an override falls back to the default reserve.
    tier, model_label, reserve_micros = (
        await image_generation_routes._resolve_billing_for_image_gen(
            session=None, config_id=-2, search_space=space
        )
    )
    assert tier == "free"
    assert "google/gemini-2.5-flash-image" in model_label
    assert reserve_micros == DEFAULT_IMAGE_RESERVE_MICROS
@pytest.mark.asyncio
async def test_resolve_billing_for_user_owned_byok_is_free():
    """Positive ids are user-owned BYOK configs: the user pays the provider
    directly, so our side always bills free."""
    from app.routes import image_generation_routes
    from app.services.billable_calls import DEFAULT_IMAGE_RESERVE_MICROS

    space = SimpleNamespace(image_generation_config_id=None)
    tier, model_label, reserve_micros = (
        await image_generation_routes._resolve_billing_for_image_gen(
            session=None, config_id=42, search_space=space
        )
    )
    assert tier == "free"
    assert model_label == "user_byok"
    assert reserve_micros == DEFAULT_IMAGE_RESERVE_MICROS
@pytest.mark.asyncio
async def test_resolve_billing_falls_back_to_search_space_default(monkeypatch):
    """With no explicit ``image_generation_config_id`` on the request, the
    helper consults the search space's pinned default — a space pinned to a
    premium global must still be quota-gated."""
    from app.config import config
    from app.routes import image_generation_routes

    monkeypatch.setattr(
        config,
        "GLOBAL_IMAGE_GEN_CONFIGS",
        [
            {
                "id": -7,
                "provider": "OPENAI",
                "model_name": "gpt-image-1",
                "billing_tier": "premium",
            }
        ],
        raising=False,
    )
    space = SimpleNamespace(image_generation_config_id=-7)
    tier, model_label, _reserve = (
        await image_generation_routes._resolve_billing_for_image_gen(
            session=None, config_id=None, search_space=space
        )
    )
    assert tier == "premium"
    assert model_label == "openai/gpt-image-1"

View file

@ -0,0 +1,436 @@
"""Unit tests for ``_resolve_agent_billing_for_search_space``.
Validates the resolver used by Celery podcast/video tasks to compute
``(owner_user_id, billing_tier, base_model)`` from a search space and its
agent LLM config. The resolver mirrors chat's billing-resolution pattern at
``stream_new_chat.py:2294-2351`` and is the single integration point that
prevents Auto-mode podcast/video from leaking premium credit.
Coverage:
* Auto mode + ``thread_id`` set, pin resolves to a negative-id premium
global -> returns ``("premium", <base_model>)``.
* Auto mode + ``thread_id`` set, pin resolves to a negative-id free
global -> returns ``("free", <base_model>)``.
* Auto mode + ``thread_id`` set, pin resolves to a positive-id BYOK config
-> always ``"free"``.
* Auto mode + ``thread_id=None`` -> fallback to ``("free", "auto")`` without
hitting the pin service.
* Negative id (no Auto) uses ``get_global_llm_config``'s
``billing_tier``.
* Positive id (user BYOK) always ``"free"``.
* Search space not found raises ``ValueError``.
* ``agent_llm_id`` is None raises ``ValueError``.
"""
from __future__ import annotations
from dataclasses import dataclass
from types import SimpleNamespace
from uuid import UUID, uuid4
import pytest
pytestmark = pytest.mark.unit
# ---------------------------------------------------------------------------
# Fakes
# ---------------------------------------------------------------------------
class _FakeExecResult:
def __init__(self, obj):
self._obj = obj
def scalars(self):
return self
def first(self):
return self._obj
class _FakeSession:
    """Minimal AsyncSession stand-in.

    ``responses`` queues the objects returned by successive ``execute()``
    calls, in order. The resolver issues at most two queries (the search
    space, then an optional ``NewLLMConfig`` lookup), so two queued items
    cover the whole matrix; an exhausted queue yields ``None``.
    """

    def __init__(self, responses: list):
        self._responses = list(responses)

    async def execute(self, _stmt):
        next_obj = self._responses.pop(0) if self._responses else None
        return _FakeExecResult(next_obj)

    async def commit(self) -> None:
        pass
@dataclass
class _FakePinResolution:
    # Shape-compatible stand-in for the pin service's resolution result;
    # these tests read only the fields below.
    resolved_llm_config_id: int  # concrete config id the pin resolved to
    resolved_tier: str = "premium"  # billing tier reported by the pin service
    from_existing_pin: bool = False  # True when a prior pin was reused
def _make_search_space(*, agent_llm_id: int | None, user_id: UUID) -> SimpleNamespace:
return SimpleNamespace(
id=42,
agent_llm_id=agent_llm_id,
user_id=user_id,
)
def _make_byok_config(
*, id_: int, base_model: str | None = None, model_name: str = "gpt-byok"
) -> SimpleNamespace:
return SimpleNamespace(
id=id_,
model_name=model_name,
litellm_params={"base_model": base_model} if base_model else {},
)
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_auto_mode_with_thread_id_resolves_to_premium_global(monkeypatch):
    """Auto + thread: the pin lands on a negative-id premium global, so the
    resolver reports ``("premium", <base_model>)``."""
    from app.services.billable_calls import _resolve_agent_billing_for_search_space

    owner_id = uuid4()
    session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=owner_id)])

    async def _fake_resolve_pin(
        sess,
        *,
        thread_id,
        search_space_id,
        user_id,
        selected_llm_config_id,
        force_repin_free=False,
    ):
        assert selected_llm_config_id == 0
        assert thread_id == 99
        return _FakePinResolution(resolved_llm_config_id=-1, resolved_tier="premium")

    def _fake_get_global(cfg_id):
        if cfg_id != -1:
            return None
        return {
            "id": -1,
            "model_name": "gpt-5.4",
            "billing_tier": "premium",
            "litellm_params": {"base_model": "gpt-5.4"},
        }

    # The resolver imports these lazily — patch the *defining* modules so
    # the late imports resolve to our fakes.
    import app.services.auto_model_pin_service as pin_module
    import app.services.llm_service as llm_module

    monkeypatch.setattr(
        pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin
    )
    monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global)

    owner, tier, base_model = await _resolve_agent_billing_for_search_space(
        session, search_space_id=42, thread_id=99
    )
    assert owner == owner_id
    assert tier == "premium"
    assert base_model == "gpt-5.4"
@pytest.mark.asyncio
async def test_auto_mode_with_thread_id_resolves_to_free_global(monkeypatch):
    """Auto + thread: a pin on a negative-id *free* global yields
    ``("free", <base_model>)`` — the graceful-degradation path the pin
    service takes for out-of-credit users."""
    from app.services.billable_calls import _resolve_agent_billing_for_search_space

    owner_id = uuid4()
    session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=owner_id)])

    async def _fake_resolve_pin(
        sess,
        *,
        thread_id,
        search_space_id,
        user_id,
        selected_llm_config_id,
        force_repin_free=False,
    ):
        return _FakePinResolution(resolved_llm_config_id=-3, resolved_tier="free")

    def _fake_get_global(cfg_id):
        if cfg_id != -3:
            return None
        return {
            "id": -3,
            "model_name": "openrouter/free-model",
            "billing_tier": "free",
            "litellm_params": {"base_model": "openrouter/free-model"},
        }

    import app.services.auto_model_pin_service as pin_module
    import app.services.llm_service as llm_module

    monkeypatch.setattr(
        pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin
    )
    monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global)

    owner, tier, base_model = await _resolve_agent_billing_for_search_space(
        session, search_space_id=42, thread_id=99
    )
    assert owner == owner_id
    assert tier == "free"
    assert base_model == "openrouter/free-model"
@pytest.mark.asyncio
async def test_auto_mode_with_thread_id_resolves_to_byok_is_free(monkeypatch):
    """Auto + thread: a pin on a positive-id BYOK config is always free
    (mirrors ``AgentConfig.from_new_llm_config``)."""
    from app.services.billable_calls import _resolve_agent_billing_for_search_space

    owner_id = uuid4()
    byok_cfg = _make_byok_config(
        id_=17, base_model="anthropic/claude-3-haiku", model_name="my-claude"
    )
    session = _FakeSession(
        [_make_search_space(agent_llm_id=0, user_id=owner_id), byok_cfg]
    )

    async def _fake_resolve_pin(
        sess,
        *,
        thread_id,
        search_space_id,
        user_id,
        selected_llm_config_id,
        force_repin_free=False,
    ):
        return _FakePinResolution(resolved_llm_config_id=17, resolved_tier="free")

    import app.services.auto_model_pin_service as pin_module

    monkeypatch.setattr(
        pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin
    )

    owner, tier, base_model = await _resolve_agent_billing_for_search_space(
        session, search_space_id=42, thread_id=99
    )
    assert owner == owner_id
    assert tier == "free"
    assert base_model == "anthropic/claude-3-haiku"
@pytest.mark.asyncio
async def test_auto_mode_without_thread_id_falls_back_to_free():
    """Auto with no thread: ``("free", "auto")`` without ever touching the
    pin service — forward-compat for thread-less direct-API entrypoints."""
    from app.services.billable_calls import _resolve_agent_billing_for_search_space

    owner_id = uuid4()
    session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=owner_id)])
    owner, tier, base_model = await _resolve_agent_billing_for_search_space(
        session, search_space_id=42, thread_id=None
    )
    assert (owner, tier, base_model) == (owner_id, "free", "auto")
@pytest.mark.asyncio
async def test_auto_mode_pin_failure_falls_back_to_free(monkeypatch):
    """A ``ValueError`` from the pin service (thread missing / mismatched
    search space) degrades to free instead of failing the whole task."""
    from app.services.billable_calls import _resolve_agent_billing_for_search_space

    owner_id = uuid4()
    session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=owner_id)])

    async def _exploding_pin(*_args, **_kwargs):
        raise ValueError("thread missing")

    import app.services.auto_model_pin_service as pin_module

    monkeypatch.setattr(
        pin_module, "resolve_or_get_pinned_llm_config_id", _exploding_pin
    )

    owner, tier, base_model = await _resolve_agent_billing_for_search_space(
        session, search_space_id=42, thread_id=99
    )
    assert (owner, tier, base_model) == (owner_id, "free", "auto")
@pytest.mark.asyncio
async def test_negative_id_premium_global_returns_premium(monkeypatch):
    """Explicit negative ``agent_llm_id``: ``get_global_llm_config`` is
    consulted and its ``billing_tier`` is returned as-is."""
    from app.services.billable_calls import _resolve_agent_billing_for_search_space

    owner_id = uuid4()
    session = _FakeSession([_make_search_space(agent_llm_id=-1, user_id=owner_id)])

    def _fake_get_global(cfg_id):
        return {
            "id": cfg_id,
            "model_name": "gpt-5.4",
            "billing_tier": "premium",
            "litellm_params": {"base_model": "gpt-5.4"},
        }

    import app.services.llm_service as llm_module

    monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global)

    owner, tier, base_model = await _resolve_agent_billing_for_search_space(
        session, search_space_id=42, thread_id=99
    )
    assert (owner, tier, base_model) == (owner_id, "premium", "gpt-5.4")
@pytest.mark.asyncio
async def test_negative_id_free_global_returns_free(monkeypatch):
    """Explicit negative ``agent_llm_id`` pointing at a free global stays
    free — the tier comes straight from the catalog entry."""
    from app.services.billable_calls import _resolve_agent_billing_for_search_space

    owner_id = uuid4()
    session = _FakeSession([_make_search_space(agent_llm_id=-2, user_id=owner_id)])

    def _fake_get_global(cfg_id):
        return {
            "id": cfg_id,
            "model_name": "openrouter/some-free",
            "billing_tier": "free",
            "litellm_params": {"base_model": "openrouter/some-free"},
        }

    import app.services.llm_service as llm_module

    monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global)

    owner, tier, base_model = await _resolve_agent_billing_for_search_space(
        session, search_space_id=42, thread_id=None
    )
    assert (owner, tier, base_model) == (owner_id, "free", "openrouter/some-free")
@pytest.mark.asyncio
async def test_negative_id_missing_base_model_falls_back_to_model_name(monkeypatch):
    """With no ``litellm_params.base_model`` on the global config, the
    resolver falls back to ``model_name`` — matching chat's behavior."""
    from app.services.billable_calls import _resolve_agent_billing_for_search_space

    session = _FakeSession(
        [_make_search_space(agent_llm_id=-5, user_id=uuid4())]
    )

    def _fake_get_global(cfg_id):
        # Deliberately no ``litellm_params`` key.
        return {
            "id": cfg_id,
            "model_name": "fallback-model",
            "billing_tier": "premium",
        }

    import app.services.llm_service as llm_module

    monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global)

    _owner, tier, base_model = await _resolve_agent_billing_for_search_space(
        session, search_space_id=42
    )
    assert (tier, base_model) == ("premium", "fallback-model")
@pytest.mark.asyncio
async def test_positive_id_byok_is_always_free():
    """Positive ``agent_llm_id`` → user-owned BYOK row → free, whatever the
    underlying provider tier."""
    from app.services.billable_calls import _resolve_agent_billing_for_search_space

    owner_id = uuid4()
    session = _FakeSession(
        [
            _make_search_space(agent_llm_id=23, user_id=owner_id),
            _make_byok_config(id_=23, base_model="anthropic/claude-3.5-sonnet"),
        ]
    )
    owner, tier, base_model = await _resolve_agent_billing_for_search_space(
        session, search_space_id=42
    )
    assert owner == owner_id
    assert tier == "free"
    assert base_model == "anthropic/claude-3.5-sonnet"
@pytest.mark.asyncio
async def test_positive_id_byok_missing_returns_free_with_empty_base_model():
    """A dangling BYOK pointer (config row deleted but still referenced by
    the search space) resolves free with an empty ``base_model`` — the
    premium billing path is simply skipped, no debit, no crash."""
    from app.services.billable_calls import _resolve_agent_billing_for_search_space

    owner_id = uuid4()
    session = _FakeSession([_make_search_space(agent_llm_id=99, user_id=owner_id)])
    owner, tier, base_model = await _resolve_agent_billing_for_search_space(
        session, search_space_id=42
    )
    assert (owner, tier, base_model) == (owner_id, "free", "")
@pytest.mark.asyncio
async def test_search_space_not_found_raises_value_error():
    """An unknown search-space id is a hard error, not a silent free pass."""
    from app.services.billable_calls import _resolve_agent_billing_for_search_space

    with pytest.raises(ValueError, match="Search space"):
        await _resolve_agent_billing_for_search_space(
            _FakeSession([None]), search_space_id=999
        )
@pytest.mark.asyncio
async def test_agent_llm_id_none_raises_value_error():
    """A search space with no agent LLM configured cannot be billed."""
    from app.services.billable_calls import _resolve_agent_billing_for_search_space

    session = _FakeSession(
        [_make_search_space(agent_llm_id=None, user_id=uuid4())]
    )
    with pytest.raises(ValueError, match="agent_llm_id"):
        await _resolve_agent_billing_for_search_space(session, search_space_id=42)

View file

@ -101,11 +101,116 @@ async def test_auto_first_turn_pins_one_model(monkeypatch):
user_id="00000000-0000-0000-0000-000000000001",
selected_llm_config_id=0,
)
assert result.resolved_llm_config_id in {-1, -2}
assert result.resolved_llm_config_id == -1
assert session.thread.pinned_llm_config_id == result.resolved_llm_config_id
assert session.commit_count == 1
@pytest.mark.asyncio
async def test_premium_eligible_auto_prefers_premium_over_free(monkeypatch):
    """A premium-eligible user's first Auto turn must pin the premium
    global even when a free global carries a much higher quality score."""
    from app.config import config

    session = _FakeSession(_thread())
    monkeypatch.setattr(
        config,
        "GLOBAL_LLM_CONFIGS",
        [
            {
                "id": -2,
                "provider": "OPENAI",
                "model_name": "gpt-free",
                "api_key": "k1",
                "billing_tier": "free",
                "quality_score": 100,
            },
            {
                "id": -1,
                "provider": "OPENAI",
                "model_name": "gpt-prem",
                "api_key": "k2",
                "billing_tier": "premium",
                "quality_score": 10,
            },
        ],
    )

    async def _quota_allows(*_args, **_kwargs):
        return _FakeQuotaResult(allowed=True)

    monkeypatch.setattr(
        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
        _quota_allows,
    )

    resolution = await resolve_or_get_pinned_llm_config_id(
        session,
        thread_id=1,
        search_space_id=10,
        user_id="00000000-0000-0000-0000-000000000001",
        selected_llm_config_id=0,
    )
    assert resolution.resolved_llm_config_id == -1
    assert resolution.resolved_tier == "premium"
@pytest.mark.asyncio
async def test_premium_eligible_auto_prefers_azure_gpt_5_4(monkeypatch):
    """Within the premium pool, the tier-A Azure ``gpt-5.4`` deployment
    must win over a higher-quality tier-A sibling and over the tier-B
    OpenRouter mirror of the same model."""
    from app.config import config

    session = _FakeSession(_thread())
    monkeypatch.setattr(
        config,
        "GLOBAL_LLM_CONFIGS",
        [
            {
                "id": -1,
                "provider": "AZURE_OPENAI",
                "model_name": "gpt-5.1",
                "api_key": "k1",
                "billing_tier": "premium",
                "auto_pin_tier": "A",
                "quality_score": 100,
            },
            {
                "id": -2,
                "provider": "AZURE_OPENAI",
                "model_name": "gpt-5.4",
                "api_key": "k2",
                "billing_tier": "premium",
                "auto_pin_tier": "A",
                "quality_score": 10,
            },
            {
                "id": -3,
                "provider": "OPENROUTER",
                "model_name": "openai/gpt-5.4",
                "api_key": "k3",
                "billing_tier": "premium",
                "auto_pin_tier": "B",
                "quality_score": 100,
            },
        ],
    )

    async def _quota_allows(*_args, **_kwargs):
        return _FakeQuotaResult(allowed=True)

    monkeypatch.setattr(
        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
        _quota_allows,
    )

    resolution = await resolve_or_get_pinned_llm_config_id(
        session,
        thread_id=1,
        search_space_id=10,
        user_id="00000000-0000-0000-0000-000000000001",
        selected_llm_config_id=0,
    )
    assert resolution.resolved_llm_config_id == -2
    assert resolution.resolved_tier == "premium"
@pytest.mark.asyncio
async def test_next_turn_reuses_existing_pin(monkeypatch):
from app.config import config
@ -361,12 +466,12 @@ async def test_invalid_pinned_config_repairs_with_new_pin(monkeypatch):
],
)
async def _allowed(*_args, **_kwargs):
return _FakeQuotaResult(allowed=True)
async def _blocked(*_args, **_kwargs):
return _FakeQuotaResult(allowed=False)
monkeypatch.setattr(
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
_allowed,
_blocked,
)
result = await resolve_or_get_pinned_llm_config_id(

View file

@ -0,0 +1,286 @@
"""Image-aware extension of the Auto-pin resolver.
When the current chat turn carries an ``image_url`` block, the pin
resolver must:
1. Filter the candidate pool to vision-capable cfgs so a freshly
selected pin can never be text-only.
2. Treat any existing pin whose capability is False as invalid (force
re-pin), even when it would otherwise be reused as the thread's
stable model.
3. Raise ``ValueError`` (mapped to the friendly
``MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT`` SSE error in the streaming
task) when no vision-capable cfg is available instead of silently
pinning text-only and 404-ing at the provider.
"""
from __future__ import annotations
from dataclasses import dataclass
from types import SimpleNamespace
import pytest
from app.services.auto_model_pin_service import (
clear_healthy,
clear_runtime_cooldown,
resolve_or_get_pinned_llm_config_id,
)
pytestmark = pytest.mark.unit
@pytest.fixture(autouse=True)
def _reset_caches():
    """Guarantee every test starts and finishes with clean resolver caches."""

    def _clear() -> None:
        clear_runtime_cooldown()
        clear_healthy()

    _clear()
    yield
    _clear()
@dataclass
class _FakeQuotaResult:
    # Stand-in for the TokenQuotaService usage result; the resolver in these
    # tests only reads the ``allowed`` flag.
    allowed: bool
class _FakeExecResult:
def __init__(self, thread):
self._thread = thread
def unique(self):
return self
def scalar_one_or_none(self):
return self._thread
class _FakeSession:
    """Async-session stub: every query yields the configured thread row
    and ``commit()`` calls are counted for assertions."""

    def __init__(self, thread):
        self.thread = thread
        self.commit_count = 0

    async def execute(self, _stmt):
        # The resolver issues a single thread lookup; always hand back
        # the one row this session was constructed with.
        return _FakeExecResult(self.thread)

    async def commit(self):
        self.commit_count += 1
def _thread(*, pinned: int | None = None):
return SimpleNamespace(id=1, search_space_id=10, pinned_llm_config_id=pinned)
def _vision_cfg(id_: int, *, tier: str = "free", quality: int = 80) -> dict:
return {
"id": id_,
"provider": "OPENAI",
"model_name": f"vision-{id_}",
"api_key": "k",
"billing_tier": tier,
"supports_image_input": True,
"auto_pin_tier": "A",
"quality_score": quality,
}
def _text_only_cfg(id_: int, *, tier: str = "free", quality: int = 90) -> dict:
return {
"id": id_,
"provider": "OPENAI",
"model_name": f"text-{id_}",
"api_key": "k",
"billing_tier": tier,
# Higher quality than the vision cfgs — so a bug that ignores
# the image flag would surface as the resolver picking this one.
"supports_image_input": False,
"auto_pin_tier": "A",
"quality_score": quality,
}
# Shared monkeypatch target: premium quota checks always pass.
async def _premium_allowed(*_args, **_kwargs):
    return _FakeQuotaResult(allowed=True)
@pytest.mark.asyncio
async def test_image_turn_filters_out_text_only_candidates(monkeypatch):
    """A fresh pin on an image turn must come from the vision-capable pool."""
    from app.config import config

    monkeypatch.setattr(
        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
        _premium_allowed,
    )
    monkeypatch.setattr(
        config,
        "GLOBAL_LLM_CONFIGS",
        [_text_only_cfg(-1), _vision_cfg(-2)],
    )

    session = _FakeSession(_thread())
    result = await resolve_or_get_pinned_llm_config_id(
        session,
        thread_id=1,
        search_space_id=10,
        user_id=None,
        selected_llm_config_id=0,
        requires_image_input=True,
    )

    assert result.resolved_llm_config_id == -2
    # The thread should be pinned to the vision cfg even though the
    # text-only cfg has a higher quality score.
    assert session.thread.pinned_llm_config_id == -2
@pytest.mark.asyncio
async def test_image_turn_force_repins_stale_text_only_pin(monkeypatch):
    """An existing text-only pin must be invalidated when the next turn
    requires image input. The non-image path would happily reuse it."""
    from app.config import config

    monkeypatch.setattr(
        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
        _premium_allowed,
    )
    monkeypatch.setattr(
        config,
        "GLOBAL_LLM_CONFIGS",
        [_text_only_cfg(-1), _vision_cfg(-2)],
    )

    # Thread arrives already pinned to the text-only cfg.
    session = _FakeSession(_thread(pinned=-1))
    result = await resolve_or_get_pinned_llm_config_id(
        session,
        thread_id=1,
        search_space_id=10,
        user_id=None,
        selected_llm_config_id=0,
        requires_image_input=True,
    )

    assert result.resolved_llm_config_id == -2
    assert result.from_existing_pin is False
    assert session.thread.pinned_llm_config_id == -2
@pytest.mark.asyncio
async def test_image_turn_reuses_existing_vision_pin(monkeypatch):
    """If the thread is already pinned to a vision-capable cfg, reuse it —
    same as the non-image path. Image-aware filtering must not force
    spurious re-pins."""
    from app.config import config

    monkeypatch.setattr(
        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
        _premium_allowed,
    )
    monkeypatch.setattr(
        config,
        "GLOBAL_LLM_CONFIGS",
        [_text_only_cfg(-1), _vision_cfg(-2), _vision_cfg(-3, quality=70)],
    )

    session = _FakeSession(_thread(pinned=-2))
    result = await resolve_or_get_pinned_llm_config_id(
        session,
        thread_id=1,
        search_space_id=10,
        user_id=None,
        selected_llm_config_id=0,
        requires_image_input=True,
    )

    assert result.resolved_llm_config_id == -2
    assert result.from_existing_pin is True
@pytest.mark.asyncio
async def test_image_turn_with_no_vision_candidates_raises(monkeypatch):
    """The friendly-error path: no vision-capable cfg in the pool -> raise
    ``ValueError`` whose message contains ``vision-capable`` so the
    streaming task can map it to ``MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT``."""
    from app.config import config

    monkeypatch.setattr(
        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
        _premium_allowed,
    )
    monkeypatch.setattr(
        config,
        "GLOBAL_LLM_CONFIGS",
        [_text_only_cfg(-1), _text_only_cfg(-2)],
    )

    with pytest.raises(ValueError, match="vision-capable"):
        await resolve_or_get_pinned_llm_config_id(
            _FakeSession(_thread()),
            thread_id=1,
            search_space_id=10,
            user_id=None,
            selected_llm_config_id=0,
            requires_image_input=True,
        )
@pytest.mark.asyncio
async def test_non_image_turn_keeps_text_only_in_pool(monkeypatch):
    """Regression guard: the image flag must default False and not affect
    a normal text-only turn — text-only cfgs remain selectable."""
    from app.config import config

    monkeypatch.setattr(
        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
        _premium_allowed,
    )
    monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [_text_only_cfg(-1)])

    result = await resolve_or_get_pinned_llm_config_id(
        _FakeSession(_thread()),
        thread_id=1,
        search_space_id=10,
        user_id=None,
        selected_llm_config_id=0,
    )
    assert result.resolved_llm_config_id == -1
@pytest.mark.asyncio
async def test_image_turn_unannotated_cfg_resolves_via_helper(monkeypatch):
    """A YAML cfg that omits ``supports_image_input`` falls through to
    ``derive_supports_image_input`` (LiteLLM-driven). For ``gpt-4o``
    that returns True, so the cfg should be a valid candidate."""
    from app.config import config

    # NOTE: deliberately no ``supports_image_input`` key on this cfg.
    cfg_unannotated_vision = {
        "id": -2,
        "provider": "OPENAI",
        "model_name": "gpt-4o",  # known vision model in LiteLLM map
        "api_key": "k",
        "billing_tier": "free",
        "auto_pin_tier": "A",
        "quality_score": 80,
    }
    monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [cfg_unannotated_vision])
    monkeypatch.setattr(
        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
        _premium_allowed,
    )

    result = await resolve_or_get_pinned_llm_config_id(
        _FakeSession(_thread()),
        thread_id=1,
        search_space_id=10,
        user_id=None,
        selected_llm_config_id=0,
        requires_image_input=True,
    )
    assert result.resolved_llm_config_id == -2

View file

@ -0,0 +1,559 @@
"""Unit tests for the ``billable_call`` async context manager.
Covers the per-call premium-credit lifecycle for image generation and
vision LLM extraction:
* Free configs bypass reserve/finalize but still write an audit row.
* Premium reserve denial raises ``QuotaInsufficientError`` (HTTP 402 in the
route layer).
* Successful premium calls reserve, yield the accumulator, then finalize
with the LiteLLM-reported actual cost and write an audit row.
* Failed premium calls release the reservation so credit isn't leaked.
* All quota DB ops happen inside their OWN ``shielded_async_session``,
isolating them from the caller's transaction (issue A).
"""
from __future__ import annotations
import asyncio
import contextlib
from typing import Any
from uuid import uuid4
import pytest
pytestmark = pytest.mark.unit
# ---------------------------------------------------------------------------
# Fakes
# ---------------------------------------------------------------------------
class _FakeQuotaResult:
def __init__(
self,
*,
allowed: bool,
used: int = 0,
limit: int = 5_000_000,
remaining: int = 5_000_000,
) -> None:
self.allowed = allowed
self.used = used
self.limit = limit
self.remaining = remaining
class _FakeSession:
"""Minimal AsyncSession stub — record commits for assertion."""
def __init__(self) -> None:
self.committed = False
self.added: list[Any] = []
def add(self, obj: Any) -> None:
self.added.append(obj)
async def commit(self) -> None:
self.committed = True
async def rollback(self) -> None:
pass
async def close(self) -> None:
pass
# Every session opened via the fake factory is recorded here so tests can
# assert session isolation (one session per quota/audit operation).
_SESSIONS_USED: list[_FakeSession] = []


@contextlib.asynccontextmanager
async def _fake_shielded_session():
    """Yield a fresh fake session and record it for isolation assertions."""
    session = _FakeSession()
    _SESSIONS_USED.append(session)
    yield session
def _patch_isolation_layer(
    monkeypatch, *, reserve_result, finalize_result=None, finalize_exc=None
):
    """Wire fake reserve/finalize/release/session helpers.

    Returns a dict of spy lists (``reserve``/``finalize``/``release``/
    ``record``) that accumulate the kwargs each stubbed operation was
    called with, so tests can assert the billing lifecycle.
    """
    _SESSIONS_USED.clear()
    reserve_calls: list[dict[str, Any]] = []
    finalize_calls: list[dict[str, Any]] = []
    release_calls: list[dict[str, Any]] = []

    async def _fake_reserve(*, db_session, user_id, request_id, reserve_micros):
        reserve_calls.append(
            {
                "user_id": user_id,
                "reserve_micros": reserve_micros,
                "request_id": request_id,
            }
        )
        return reserve_result

    async def _fake_finalize(
        *, db_session, user_id, request_id, actual_micros, reserved_micros
    ):
        # ``finalize_exc`` lets tests simulate a settlement failure before
        # the spy records anything.
        if finalize_exc is not None:
            raise finalize_exc
        finalize_calls.append(
            {
                "user_id": user_id,
                "actual_micros": actual_micros,
                "reserved_micros": reserved_micros,
            }
        )
        return finalize_result or _FakeQuotaResult(allowed=True)

    async def _fake_release(*, db_session, user_id, reserved_micros):
        release_calls.append({"user_id": user_id, "reserved_micros": reserved_micros})

    record_calls: list[dict[str, Any]] = []

    async def _fake_record(session, **kwargs):
        record_calls.append(kwargs)
        return object()

    # raising=False: don't fail patching if the attribute is absent in a
    # given build of the module.
    monkeypatch.setattr(
        "app.services.billable_calls.TokenQuotaService.premium_reserve",
        _fake_reserve,
        raising=False,
    )
    monkeypatch.setattr(
        "app.services.billable_calls.TokenQuotaService.premium_finalize",
        _fake_finalize,
        raising=False,
    )
    monkeypatch.setattr(
        "app.services.billable_calls.TokenQuotaService.premium_release",
        _fake_release,
        raising=False,
    )
    monkeypatch.setattr(
        "app.services.billable_calls.shielded_async_session",
        _fake_shielded_session,
        raising=False,
    )
    monkeypatch.setattr(
        "app.services.billable_calls.record_token_usage",
        _fake_record,
        raising=False,
    )
    return {
        "reserve": reserve_calls,
        "finalize": finalize_calls,
        "release": release_calls,
        "record": record_calls,
    }
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_free_path_skips_reserve_but_writes_audit_row(monkeypatch):
    """Free tier: no reserve/finalize/release, but exactly one audit row."""
    from app.services.billable_calls import billable_call

    spies = _patch_isolation_layer(
        monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
    )
    async with billable_call(
        user_id=uuid4(),
        search_space_id=42,
        billing_tier="free",
        base_model="openai/gpt-image-1",
        usage_type="image_generation",
    ) as acc:
        # Simulate a captured cost — in production the LiteLLM callback
        # feeds the accumulator; here we add it manually.
        acc.add(
            model="openai/gpt-image-1",
            prompt_tokens=0,
            completion_tokens=0,
            total_tokens=0,
            cost_micros=37_000,
            call_kind="image_generation",
        )

    for key in ("reserve", "finalize", "release"):
        assert spies[key] == []
    # Free still audits.
    audit_rows = spies["record"]
    assert len(audit_rows) == 1
    assert audit_rows[0]["usage_type"] == "image_generation"
    assert audit_rows[0]["cost_micros"] == 37_000
@pytest.mark.asyncio
async def test_premium_reserve_denied_raises_quota_insufficient(monkeypatch):
    """A denied premium reservation must abort before the body ever runs."""
    from app.services.billable_calls import (
        QuotaInsufficientError,
        billable_call,
    )

    spies = _patch_isolation_layer(
        monkeypatch,
        reserve_result=_FakeQuotaResult(
            allowed=False, used=5_000_000, limit=5_000_000, remaining=0
        ),
    )
    with pytest.raises(QuotaInsufficientError) as exc_info:
        async with billable_call(
            user_id=uuid4(),
            search_space_id=42,
            billing_tier="premium",
            base_model="openai/gpt-image-1",
            quota_reserve_micros_override=50_000,
            usage_type="image_generation",
        ):
            pytest.fail("body should not run when reserve is denied")

    err = exc_info.value
    assert err.usage_type == "image_generation"
    assert (err.used_micros, err.limit_micros, err.remaining_micros) == (
        5_000_000,
        5_000_000,
        0,
    )
    # Reserve was attempted, but no finalize/release on a denied reserve
    # — the reservation never actually held credit.
    assert len(spies["reserve"]) == 1
    assert spies["finalize"] == []
    assert spies["release"] == []
    # Denied premium calls do NOT create an audit row (no work happened).
    assert spies["record"] == []
@pytest.mark.asyncio
async def test_premium_success_finalizes_with_actual_cost(monkeypatch):
    """Happy premium path: reserve → run body → finalize with actual cost,
    plus an audit row and per-operation session isolation."""
    from app.services.billable_calls import billable_call

    spies = _patch_isolation_layer(
        monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
    )
    user_id = uuid4()
    async with billable_call(
        user_id=user_id,
        search_space_id=42,
        billing_tier="premium",
        base_model="openai/gpt-image-1",
        quota_reserve_micros_override=50_000,
        usage_type="image_generation",
    ) as acc:
        # LiteLLM callback would normally fill this — simulate $0.04 image.
        acc.add(
            model="openai/gpt-image-1",
            prompt_tokens=0,
            completion_tokens=0,
            total_tokens=0,
            cost_micros=40_000,
            call_kind="image_generation",
        )
    assert len(spies["reserve"]) == 1
    assert spies["reserve"][0]["reserve_micros"] == 50_000
    assert len(spies["finalize"]) == 1
    assert spies["finalize"][0]["actual_micros"] == 40_000
    assert spies["finalize"][0]["reserved_micros"] == 50_000
    assert spies["release"] == []
    # And audit row written with the actual debited cost.
    assert spies["record"][0]["cost_micros"] == 40_000
    # Each quota op opened its OWN session — proves session isolation.
    assert len(_SESSIONS_USED) >= 3
    # The stubbed reserve/finalize never commit on our fake sessions, but
    # the audit path does — so at least one recorded session committed.
    # (The original looped over _SESSIONS_USED with a no-op body; that
    # dead loop is removed.)
    assert any(s.committed for s in _SESSIONS_USED)
@pytest.mark.asyncio
async def test_premium_failure_releases_reservation(monkeypatch):
    """A provider error inside the body must release the held reservation."""
    from app.services.billable_calls import billable_call

    class _ProviderError(Exception):
        pass

    spies = _patch_isolation_layer(
        monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
    )
    with pytest.raises(_ProviderError):
        async with billable_call(
            user_id=uuid4(),
            search_space_id=42,
            billing_tier="premium",
            base_model="openai/gpt-image-1",
            quota_reserve_micros_override=50_000,
            usage_type="image_generation",
        ):
            raise _ProviderError("OpenRouter 503")

    assert len(spies["reserve"]) == 1
    assert spies["finalize"] == []
    # Failure path: release the held reservation.
    assert len(spies["release"]) == 1
    assert spies["release"][0]["reserved_micros"] == 50_000
@pytest.mark.asyncio
async def test_premium_uses_estimator_when_no_micros_override(monkeypatch):
    """When ``quota_reserve_micros_override`` is None we fall back to
    ``estimate_call_reserve_micros(base_model, quota_reserve_tokens)``.
    Vision LLM calls take this path (token-priced models).
    """
    from app.services.billable_calls import billable_call

    spies = _patch_isolation_layer(
        monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
    )
    estimator_calls: list[dict[str, Any]] = []

    def _fake_estimate(*, base_model, quota_reserve_tokens):
        estimator_calls.append(
            {"base_model": base_model, "quota_reserve_tokens": quota_reserve_tokens}
        )
        return 12_345

    monkeypatch.setattr(
        "app.services.billable_calls.estimate_call_reserve_micros",
        _fake_estimate,
        raising=False,
    )
    async with billable_call(
        user_id=uuid4(),
        search_space_id=1,
        billing_tier="premium",
        base_model="openai/gpt-4o",
        quota_reserve_tokens=4000,
        usage_type="vision_extraction",
    ):
        pass

    assert estimator_calls == [
        {"base_model": "openai/gpt-4o", "quota_reserve_tokens": 4000}
    ]
    assert spies["reserve"][0]["reserve_micros"] == 12_345
@pytest.mark.asyncio
async def test_premium_finalize_failure_propagates_and_releases(monkeypatch):
    """A finalize crash surfaces as ``BillingSettlementError``; the held
    reservation is released and no audit row is written."""
    from app.services.billable_calls import BillingSettlementError, billable_call

    class _FinalizeError(RuntimeError):
        pass

    spies = _patch_isolation_layer(
        monkeypatch,
        reserve_result=_FakeQuotaResult(allowed=True),
        finalize_exc=_FinalizeError("db finalize failed"),
    )
    with pytest.raises(BillingSettlementError):
        async with billable_call(
            user_id=uuid4(),
            search_space_id=42,
            billing_tier="premium",
            base_model="openai/gpt-image-1",
            quota_reserve_micros_override=50_000,
            usage_type="image_generation",
        ) as acc:
            acc.add(
                model="openai/gpt-image-1",
                prompt_tokens=0,
                completion_tokens=0,
                total_tokens=0,
                cost_micros=40_000,
                call_kind="image_generation",
            )

    assert len(spies["reserve"]) == 1
    assert len(spies["release"]) == 1
    assert spies["record"] == []
@pytest.mark.asyncio
async def test_premium_audit_commit_hang_times_out_after_finalize(monkeypatch):
    """A hung audit-session commit must not wedge the billing context.

    With ``audit_timeout_seconds`` set, the context still completes
    reserve → finalize → record even though the injected session's
    ``commit()`` sleeps far longer than the timeout.
    """
    from app.services.billable_calls import billable_call
    spies = _patch_isolation_layer(
        monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
    )
    user_id = uuid4()

    class _HangingCommitSession(_FakeSession):
        # Simulates a DB commit that never returns within the timeout.
        async def commit(self) -> None:
            await asyncio.sleep(60)

    @contextlib.asynccontextmanager
    async def _hanging_session_factory():
        s = _HangingCommitSession()
        _SESSIONS_USED.append(s)
        yield s

    async with billable_call(
        user_id=user_id,
        search_space_id=42,
        billing_tier="premium",
        base_model="openai/gpt-image-1",
        quota_reserve_micros_override=50_000,
        usage_type="image_generation",
        billable_session_factory=_hanging_session_factory,
        audit_timeout_seconds=0.01,
    ) as acc:
        acc.add(
            model="openai/gpt-image-1",
            prompt_tokens=0,
            completion_tokens=0,
            total_tokens=0,
            cost_micros=40_000,
            call_kind="image_generation",
        )
    assert len(spies["reserve"]) == 1
    assert len(spies["finalize"]) == 1
    assert len(spies["record"]) == 1
    assert spies["release"] == []
@pytest.mark.asyncio
async def test_free_audit_failure_is_best_effort(monkeypatch):
    """A crashing audit insert on the free path must not raise — auditing
    is best-effort and never blocks the caller."""
    from app.services.billable_calls import billable_call

    spies = _patch_isolation_layer(
        monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
    )

    async def _failing_record(_session, **_kwargs):
        raise RuntimeError("audit insert failed")

    monkeypatch.setattr(
        "app.services.billable_calls.record_token_usage",
        _failing_record,
        raising=False,
    )
    # The context must exit cleanly despite the audit failure.
    async with billable_call(
        user_id=uuid4(),
        search_space_id=42,
        billing_tier="free",
        base_model="openai/gpt-image-1",
        usage_type="image_generation",
        audit_timeout_seconds=0.01,
    ) as acc:
        acc.add(
            model="openai/gpt-image-1",
            prompt_tokens=0,
            completion_tokens=0,
            total_tokens=0,
            cost_micros=37_000,
            call_kind="image_generation",
        )
    assert spies["reserve"] == []
    assert spies["finalize"] == []
# ---------------------------------------------------------------------------
# Podcast / video-presentation usage_type coverage
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_free_podcast_path_audits_with_podcast_usage_type(monkeypatch):
    """Free podcast configs must skip reserve/finalize but still emit a
    ``TokenUsage`` row tagged ``usage_type='podcast_generation'`` so we
    have full audit coverage of free-tier agent runs."""
    from app.services.billable_calls import billable_call

    spies = _patch_isolation_layer(
        monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
    )
    async with billable_call(
        user_id=uuid4(),
        search_space_id=42,
        billing_tier="free",
        base_model="openrouter/some-free-model",
        quota_reserve_micros_override=200_000,
        usage_type="podcast_generation",
        thread_id=99,
        call_details={"podcast_id": 7, "title": "Test Podcast"},
    ) as acc:
        # Two transcript LLM calls aggregated into one accumulator.
        acc.add(
            model="openrouter/some-free-model",
            prompt_tokens=1500,
            completion_tokens=8000,
            total_tokens=9500,
            cost_micros=0,
            call_kind="chat",
        )

    for key in ("reserve", "finalize", "release"):
        assert spies[key] == []
    assert len(spies["record"]) == 1
    row = spies["record"][0]
    assert row["usage_type"] == "podcast_generation"
    assert row["thread_id"] is None
    assert row["search_space_id"] == 42
    assert row["call_details"] == {"podcast_id": 7, "title": "Test Podcast"}
@pytest.mark.asyncio
async def test_premium_video_denial_raises_quota_insufficient(monkeypatch):
    """Premium video-presentation runs that hit a denied reservation must
    raise ``QuotaInsufficientError`` *before* the graph runs and must not
    emit an audit row (no work happened)."""
    from app.services.billable_calls import (
        QuotaInsufficientError,
        billable_call,
    )

    spies = _patch_isolation_layer(
        monkeypatch,
        reserve_result=_FakeQuotaResult(
            allowed=False, used=4_500_000, limit=5_000_000, remaining=500_000
        ),
    )
    with pytest.raises(QuotaInsufficientError) as exc_info:
        async with billable_call(
            user_id=uuid4(),
            search_space_id=42,
            billing_tier="premium",
            base_model="gpt-5.4",
            quota_reserve_micros_override=1_000_000,
            usage_type="video_presentation_generation",
            thread_id=99,
            call_details={"video_presentation_id": 12, "title": "Test Video"},
        ):
            pytest.fail("body should not run when reserve is denied")

    err = exc_info.value
    assert err.usage_type == "video_presentation_generation"
    assert err.remaining_micros == 500_000
    assert spies["reserve"][0]["reserve_micros"] == 1_000_000
    assert spies["finalize"] == []
    assert spies["release"] == []
    assert spies["record"] == []

View file

@ -0,0 +1,177 @@
"""Defense-in-depth: image-gen call sites must not let an empty
``api_base`` fall through to LiteLLM's module-global ``litellm.api_base``.
The bug repro: an OpenRouter image-gen config ships
``api_base=""``. The pre-fix call site in
``image_generation_routes._execute_image_generation`` did
``if cfg.get("api_base"): kwargs["api_base"] = cfg["api_base"]`` which
silently dropped the empty string. LiteLLM then fell back to
``litellm.api_base`` (commonly inherited from ``AZURE_OPENAI_ENDPOINT``)
and OpenRouter's ``image_generation/transformation`` appended
``/chat/completions`` to it 404 ``Resource not found``.
This test pins the post-fix behaviour: with an empty ``api_base`` in
the config, the call site MUST set ``api_base`` to OpenRouter's public
URL instead of leaving it unset.
"""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
pytestmark = pytest.mark.unit
@pytest.mark.asyncio
async def test_global_openrouter_image_gen_sets_api_base_when_config_empty():
    """The global-config branch (``config_id < 0``) of
    ``_execute_image_generation`` must apply the resolver and pin
    ``api_base`` to OpenRouter when the config ships an empty string.
    """
    from app.routes import image_generation_routes
    cfg = {
        "id": -20_001,
        "name": "GPT Image 1 (OpenRouter)",
        "provider": "OPENROUTER",
        "model_name": "openai/gpt-image-1",
        "api_key": "sk-or-test",
        "api_base": "",  # the original bug shape
        "api_version": None,
        "litellm_params": {},
    }
    captured: dict = {}

    # Capture the kwargs that would have reached LiteLLM.
    async def fake_aimage_generation(**kwargs):
        captured.update(kwargs)
        return MagicMock(model_dump=lambda: {"data": []}, _hidden_params={})

    image_gen = MagicMock()
    image_gen.image_generation_config_id = cfg["id"]
    image_gen.prompt = "test"
    image_gen.n = 1
    image_gen.quality = None
    image_gen.size = None
    image_gen.style = None
    image_gen.response_format = None
    image_gen.model = None
    search_space = MagicMock()
    search_space.image_generation_config_id = cfg["id"]
    session = MagicMock()
    with (
        patch.object(
            image_generation_routes,
            "_get_global_image_gen_config",
            return_value=cfg,
        ),
        patch.object(
            image_generation_routes,
            "aimage_generation",
            side_effect=fake_aimage_generation,
        ),
    ):
        await image_generation_routes._execute_image_generation(
            session=session, image_gen=image_gen, search_space=search_space
        )
    # The whole point of the fix: even with empty ``api_base`` in the
    # config, we forward OpenRouter's public URL so the call doesn't
    # inherit an Azure endpoint.
    assert captured.get("api_base") == "https://openrouter.ai/api/v1"
    assert captured["model"] == "openrouter/openai/gpt-image-1"
@pytest.mark.asyncio
async def test_generate_image_tool_global_sets_api_base_when_config_empty():
    """Same defense at the agent tool entry point — both surfaces share
    the same OpenRouter config payloads."""
    from app.agents.new_chat.tools import generate_image as gi_module
    cfg = {
        "id": -20_001,
        "name": "GPT Image 1 (OpenRouter)",
        "provider": "OPENROUTER",
        "model_name": "openai/gpt-image-1",
        "api_key": "sk-or-test",
        "api_base": "",
        "api_version": None,
        "litellm_params": {},
    }
    captured: dict = {}

    # Capture the kwargs that would have reached LiteLLM.
    async def fake_aimage_generation(**kwargs):
        captured.update(kwargs)
        response = MagicMock()
        response.model_dump.return_value = {
            "data": [{"url": "https://example.com/x.png"}]
        }
        response._hidden_params = {"model": "openrouter/openai/gpt-image-1"}
        return response

    search_space = MagicMock()
    search_space.id = 1
    search_space.image_generation_config_id = cfg["id"]
    # Async-session plumbing: the query result chain is wired so that
    # ``execute(...).scalars().first()`` yields the search space above.
    session_cm = AsyncMock()
    session = AsyncMock()
    session_cm.__aenter__.return_value = session
    scalars = MagicMock()
    scalars.first.return_value = search_space
    exec_result = MagicMock()
    exec_result.scalars.return_value = scalars
    session.execute.return_value = exec_result
    session.add = MagicMock()
    session.commit = AsyncMock()
    session.refresh = AsyncMock()
    # ``refresh(db_image_gen)`` needs to populate ``id`` for token URL fallback.
    async def _refresh(obj):
        obj.id = 1
    session.refresh.side_effect = _refresh
    with (
        patch.object(gi_module, "shielded_async_session", return_value=session_cm),
        patch.object(gi_module, "_get_global_image_gen_config", return_value=cfg),
        patch.object(
            gi_module, "aimage_generation", side_effect=fake_aimage_generation
        ),
        patch.object(
            gi_module, "is_image_gen_auto_mode", side_effect=lambda cid: cid == 0
        ),
    ):
        tool = gi_module.create_generate_image_tool(
            search_space_id=1, db_session=MagicMock()
        )
        await tool.ainvoke({"prompt": "a cat", "n": 1})
    assert captured.get("api_base") == "https://openrouter.ai/api/v1"
    assert captured["model"] == "openrouter/openai/gpt-image-1"
def test_image_gen_router_deployment_sets_api_base_when_config_empty():
    """The Auto-mode router pool must also resolve ``api_base`` when an
    OpenRouter config ships an empty string. The deployment dict is fed
    straight to ``litellm.Router``, so a missing ``api_base`` would
    leak the same way as the direct call sites.
    """
    from app.services.image_gen_router_service import ImageGenRouterService

    cfg = {
        "model_name": "openai/gpt-image-1",
        "provider": "OPENROUTER",
        "api_key": "sk-or-test",
        "api_base": "",
    }
    deployment = ImageGenRouterService._config_to_deployment(cfg)

    assert deployment is not None
    params = deployment["litellm_params"]
    assert params["api_base"] == "https://openrouter.ai/api/v1"
    assert params["model"] == "openrouter/openai/gpt-image-1"

View file

@ -214,3 +214,167 @@ def test_generate_configs_drops_non_text_and_non_tool_models():
assert "openai/gpt-4o" in model_names
assert "openai/dall-e" not in model_names
assert "openai/completion-only" not in model_names
# ---------------------------------------------------------------------------
# _generate_image_gen_configs / _generate_vision_llm_configs
# ---------------------------------------------------------------------------
def test_generate_image_gen_configs_filters_by_image_output():
    """Only models whose ``output_modalities`` contain ``image`` are emitted.

    Tool-calling and context filters are intentionally NOT applied — image
    generation has nothing to do with tool calls and context windows.
    """
    from app.services.openrouter_integration_service import (
        _generate_image_gen_configs,
    )

    def _model(model_id, modalities, context, prompt, completion):
        # Minimal OpenRouter catalogue entry for this filter.
        return {
            "id": model_id,
            "architecture": {"output_modalities": modalities},
            "context_length": context,
            "pricing": {"prompt": prompt, "completion": completion},
        }

    raw = [
        # Pure image-gen model (small context, no tools — should still emit).
        _model("openai/gpt-image-1", ["image"], 4_000, "0", "0"),
        # Multi-modal: text+image output (should still emit).
        _model(
            "google/gemini-2.5-flash-image",
            ["text", "image"],
            1_000_000,
            "0.000001",
            "0.000004",
        ),
        # Pure text model — must NOT emit.
        _model("openai/gpt-4o", ["text"], 128_000, "0.000005", "0.000015"),
    ]

    cfgs = _generate_image_gen_configs(raw, dict(_SETTINGS_BASE))
    model_names = {c["model_name"] for c in cfgs}
    assert "openai/gpt-image-1" in model_names
    assert "google/gemini-2.5-flash-image" in model_names
    assert "openai/gpt-4o" not in model_names

    # Each config must carry ``billing_tier`` for routing in image_generation_routes.
    for c in cfgs:
        assert c["billing_tier"] in {"free", "premium"}
        assert c["provider"] == "OPENROUTER"
        assert c[_OPENROUTER_DYNAMIC_MARKER] is True
        # Defense-in-depth: emit the OpenRouter base URL at source so a
        # downstream call site that forgets ``resolve_api_base`` still
        # doesn't 404 against an inherited Azure endpoint.
        assert c["api_base"] == "https://openrouter.ai/api/v1"
def test_generate_image_gen_configs_assigns_image_id_offset():
    """Image configs use a different id_offset (-20000) so their negative
    IDs don't collide with chat configs (-10000) or vision configs (-30000).
    """
    from app.services.openrouter_integration_service import (
        _generate_image_gen_configs,
    )

    raw = [
        {
            "id": "openai/gpt-image-1",
            "architecture": {"output_modalities": ["image"]},
            "context_length": 4_000,
            "pricing": {"prompt": "0", "completion": "0"},
        }
    ]
    # Don't pass image_id_offset → use the module default (-20000).
    cfgs = _generate_image_gen_configs(raw, dict(_SETTINGS_BASE))
    for cfg in cfgs:
        # Equivalent to id < -20_000 + 1 and id > -29_000_000.
        assert -29_000_000 < cfg["id"] <= -20_000
def test_generate_vision_llm_configs_filters_by_image_input_text_output():
    """Vision LLMs must accept image input AND emit text — pure image-gen
    (no text out) and text-only (no image in) models are excluded.
    """
    from app.services.openrouter_integration_service import (
        _generate_vision_llm_configs,
    )

    def _model(model_id, inputs, outputs, context, prompt, completion):
        # Minimal OpenRouter catalogue entry for the modality filter.
        return {
            "id": model_id,
            "architecture": {
                "input_modalities": inputs,
                "output_modalities": outputs,
            },
            "context_length": context,
            "pricing": {"prompt": prompt, "completion": completion},
        }

    raw = [
        # GPT-4o: vision LLM (image in, text out) — must emit.
        _model(
            "openai/gpt-4o",
            ["text", "image"],
            ["text"],
            128_000,
            "0.000005",
            "0.000015",
        ),
        # Pure image generator — image *output*, no text out. Must NOT emit.
        _model("openai/gpt-image-1", ["text"], ["image"], 4_000, "0", "0"),
        # Pure text model (no image in). Must NOT emit.
        _model(
            "anthropic/claude-3-haiku",
            ["text"],
            ["text"],
            200_000,
            "0.000001",
            "0.000005",
        ),
    ]

    cfgs = _generate_vision_llm_configs(raw, dict(_SETTINGS_BASE))
    assert {c["model_name"] for c in cfgs} == {"openai/gpt-4o"}

    cfg = cfgs[0]
    assert cfg["billing_tier"] == "premium"
    # Pricing carried inline so pricing_registration can register vision
    # under ``openrouter/openai/gpt-4o`` even if the chat catalogue cache
    # is cleared.
    assert cfg["input_cost_per_token"] == pytest.approx(5e-6)
    assert cfg["output_cost_per_token"] == pytest.approx(15e-6)
    assert cfg[_OPENROUTER_DYNAMIC_MARKER] is True
    # Defense-in-depth: emit the OpenRouter base URL at source so a
    # downstream call site that forgets ``resolve_api_base`` still
    # doesn't inherit an Azure endpoint.
    assert cfg["api_base"] == "https://openrouter.ai/api/v1"
def test_generate_vision_llm_configs_drops_chat_only_filters():
    """A small-context vision model that doesn't advertise tool calling is
    still a valid vision LLM for "describe this image" prompts. The chat
    filters (``supports_tool_calling``, ``has_sufficient_context``) must
    NOT be applied to vision emission.
    """
    from app.services.openrouter_integration_service import (
        _generate_vision_llm_configs,
    )

    tiny_vision = {
        "id": "tiny/vision-mini",
        "architecture": {
            "input_modalities": ["text", "image"],
            "output_modalities": ["text"],
        },
        "supported_parameters": [],  # no tools
        "context_length": 4_000,  # well below MIN_CONTEXT_LENGTH
        "pricing": {"prompt": "0.0000001", "completion": "0.0000005"},
    }

    cfgs = _generate_vision_llm_configs([tiny_vision], dict(_SETTINGS_BASE))
    assert len(cfgs) == 1
    assert cfgs[0]["model_name"] == "tiny/vision-mini"

View file

@ -0,0 +1,447 @@
"""Pricing registration unit tests.
The pricing-registration module is what makes ``response_cost`` populate
correctly for OpenRouter dynamic models and operator-defined Azure
deployments both of which LiteLLM doesn't natively know about. The tests
exercise:
* The alias generators emit every shape that LiteLLM's cost-callback might
use (``openrouter/X`` and bare ``X``; YAML-defined ``base_model``,
``provider/base_model``, ``provider/model_name``, plus the special
``azure_openai`` ``azure`` normalisation).
* ``register_pricing_from_global_configs`` calls ``litellm.register_model``
with the right alias set and pricing values per provider.
* Configs without a resolvable pair of cost values are skipped never
registered as zero, since that would override pricing LiteLLM might
already know natively.
"""
from __future__ import annotations
from typing import Any
import pytest
pytestmark = pytest.mark.unit
# ---------------------------------------------------------------------------
# Alias generators
# ---------------------------------------------------------------------------


def test_openrouter_alias_set_includes_prefixed_and_bare():
    """The OpenRouter generator emits the prefixed key, then the bare one."""
    from app.services.pricing_registration import _alias_set_for_openrouter

    model_id = "anthropic/claude-3-5-sonnet"
    assert _alias_set_for_openrouter(model_id) == [
        f"openrouter/{model_id}",
        model_id,
    ]


def test_openrouter_alias_set_dedupes():
    """An already-prefixed model id must not yield duplicate aliases.

    With ``openrouter/foo`` the bare and prefixed variants compute to the
    same string, so at minimum the alias list must be free of repeats —
    otherwise the same key would be re-registered twice.
    """
    from app.services.pricing_registration import _alias_set_for_openrouter

    aliases = _alias_set_for_openrouter("openrouter/foo")
    assert len(set(aliases)) == len(aliases)


def test_yaml_alias_set_for_azure_openai_normalises_to_azure():
    """``azure_openai`` must also register under ``azure/<name>``.

    The LiteLLM Router resolves deployments using provider ``azure``, so
    pricing registered only under ``azure_openai/...`` would be invisible
    to that path.
    """
    from app.services.pricing_registration import _alias_set_for_yaml

    aliases = _alias_set_for_yaml(
        provider="AZURE_OPENAI",
        model_name="gpt-5.4",
        base_model="gpt-5.4",
    )
    for expected in ("gpt-5.4", "azure_openai/gpt-5.4", "azure/gpt-5.4"):
        assert expected in aliases


def test_yaml_alias_set_distinguishes_model_name_and_base_model():
    """Deployment label and canonical sku must both appear.

    When ``model_name`` differs from ``base_model`` (operator labelled a
    deployment) either string may surface in callbacks depending on the
    call path, so the alias set carries both, bare and prefixed.
    """
    from app.services.pricing_registration import _alias_set_for_yaml

    aliases = _alias_set_for_yaml(
        provider="OPENAI",
        model_name="my-deployment-label",
        base_model="gpt-4o",
    )
    expected = {
        "gpt-4o",
        "openai/gpt-4o",
        "my-deployment-label",
        "openai/my-deployment-label",
    }
    assert expected.issubset(set(aliases))


def test_yaml_alias_set_omits_provider_prefix_when_provider_blank():
    """A blank provider produces no ``provider/...`` aliases at all."""
    from app.services.pricing_registration import _alias_set_for_yaml

    aliases = _alias_set_for_yaml(provider="", model_name="foo", base_model="bar")
    assert {"bar", "foo"}.issubset(set(aliases))
    assert not any("/" in alias for alias in aliases)
# ---------------------------------------------------------------------------
# register_pricing_from_global_configs
# ---------------------------------------------------------------------------
class _RegistrationSpy:
"""Captures the dicts passed to ``litellm.register_model``.
Many calls may go through; we just record them all and let tests assert
against the union.
"""
def __init__(self) -> None:
self.calls: list[dict[str, Any]] = []
def __call__(self, payload: dict[str, Any]) -> None:
self.calls.append(payload)
@property
def all_keys(self) -> set[str]:
keys: set[str] = set()
for payload in self.calls:
keys.update(payload.keys())
return keys
def _patch_register(monkeypatch: pytest.MonkeyPatch) -> _RegistrationSpy:
    """Swap ``litellm.register_model`` for a recording spy and return it."""
    spy = _RegistrationSpy()
    monkeypatch.setattr(
        "app.services.pricing_registration.litellm.register_model",
        spy,
        raising=False,
    )
    return spy


def _patch_openrouter_pricing(
    monkeypatch: pytest.MonkeyPatch, mapping: dict[str, dict[str, str]]
) -> None:
    """Pretend the OpenRouter integration is initialised with ``mapping``."""

    class _FrozenPricing:
        def get_raw_pricing(self) -> dict[str, dict[str, str]]:
            return mapping

    class _FrozenService:
        @classmethod
        def is_initialized(cls) -> bool:
            return True

        @classmethod
        def get_instance(cls) -> _FrozenPricing:
            return _FrozenPricing()

    monkeypatch.setattr(
        "app.services.openrouter_integration_service.OpenRouterIntegrationService",
        _FrozenService,
        raising=False,
    )
def test_openrouter_models_register_under_aliases(monkeypatch):
    """A cached OpenRouter model registers under ``openrouter/X`` and bare ``X``.

    Costs are float-converted from OpenRouter's raw string pricing.
    """
    from app.config import config
    from app.services.pricing_registration import register_pricing_from_global_configs

    spy = _patch_register(monkeypatch)
    _patch_openrouter_pricing(
        monkeypatch,
        {
            "anthropic/claude-3-5-sonnet": {
                "prompt": "0.000003",
                "completion": "0.000015",
            }
        },
    )
    monkeypatch.setattr(
        config,
        "GLOBAL_LLM_CONFIGS",
        [
            {
                "id": 1,
                "provider": "OPENROUTER",
                "model_name": "anthropic/claude-3-5-sonnet",
            }
        ],
    )

    register_pricing_from_global_configs()

    prefixed = "openrouter/anthropic/claude-3-5-sonnet"
    assert {prefixed, "anthropic/claude-3-5-sonnet"} <= spy.all_keys

    entry = spy.calls[0][prefixed]
    assert entry["input_cost_per_token"] == pytest.approx(3e-6)
    assert entry["output_cost_per_token"] == pytest.approx(15e-6)
    assert entry["litellm_provider"] == "openrouter"
def test_yaml_override_registers_under_alias_set(monkeypatch):
    """Operator-declared per-token costs on a YAML config register under
    every alias the YAML generator produces — including the ``azure/``
    normalisation for ``azure_openai`` providers."""
    from app.config import config
    from app.services.pricing_registration import register_pricing_from_global_configs

    spy = _patch_register(monkeypatch)
    _patch_openrouter_pricing(monkeypatch, {})
    monkeypatch.setattr(
        config,
        "GLOBAL_LLM_CONFIGS",
        [
            {
                "id": 1,
                "provider": "AZURE_OPENAI",
                "model_name": "gpt-5.4",
                "litellm_params": {
                    "base_model": "gpt-5.4",
                    "input_cost_per_token": 2e-6,
                    "output_cost_per_token": 8e-6,
                },
            }
        ],
    )

    register_pricing_from_global_configs()

    assert {"gpt-5.4", "azure_openai/gpt-5.4", "azure/gpt-5.4"} <= spy.all_keys
    entry = spy.calls[0]["gpt-5.4"]
    assert entry["input_cost_per_token"] == pytest.approx(2e-6)
    assert entry["output_cost_per_token"] == pytest.approx(8e-6)
    assert entry["litellm_provider"] == "azure"
def test_no_override_means_no_registration(monkeypatch):
    """A YAML config that omits both pricing fields is skipped entirely.

    Registering zero would override LiteLLM's native pricing for the
    ``base_model`` key (e.g. ``gpt-4o``) and silently make every user's
    bill drop to $0. Fail-safe is "skip and warn", never "register zero".
    """
    from app.config import config
    from app.services.pricing_registration import register_pricing_from_global_configs

    spy = _patch_register(monkeypatch)
    _patch_openrouter_pricing(monkeypatch, {})
    unpriced_cfg = {
        "id": 1,
        "provider": "OPENAI",
        "model_name": "gpt-4o",
        "litellm_params": {"base_model": "gpt-4o"},
    }
    monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [unpriced_cfg])

    register_pricing_from_global_configs()

    assert spy.calls == []


def test_openrouter_skipped_when_pricing_missing(monkeypatch):
    """No raw-pricing cache entry for a configured model → skip it.

    This covers a network blip during refresh, a model added later, etc.
    Skipping beats registering zero pricing.
    """
    from app.config import config
    from app.services.pricing_registration import register_pricing_from_global_configs

    spy = _patch_register(monkeypatch)
    _patch_openrouter_pricing(
        monkeypatch, {"some/other-model": {"prompt": "1", "completion": "1"}}
    )
    monkeypatch.setattr(
        config,
        "GLOBAL_LLM_CONFIGS",
        [
            {
                "id": 1,
                "provider": "OPENROUTER",
                "model_name": "anthropic/claude-3-5-sonnet",
            }
        ],
    )

    register_pricing_from_global_configs()

    assert spy.calls == []
def test_register_continues_after_individual_failure(monkeypatch):
    """A single bad ``register_model`` call (e.g. raising a LiteLLM error)
    must not abort registration of the remaining configs.

    The first (OpenRouter) config's registration raises; the second
    (YAML-priced) config must still be registered afterwards.

    Fix over the previous revision: the ``caplog`` fixture was requested
    but never read — dropped it. Also asserts the failing config never
    landed among the successful registrations.
    """
    from app.config import config
    from app.services.pricing_registration import register_pricing_from_global_configs

    failing_keys: set[str] = {"anthropic/claude-3-5-sonnet"}
    successful_calls: list[dict[str, Any]] = []

    def _maybe_fail(payload: dict[str, Any]) -> None:
        # Simulate LiteLLM blowing up on one specific registration.
        if any(key in failing_keys for key in payload):
            raise RuntimeError("boom")
        successful_calls.append(payload)

    monkeypatch.setattr(
        "app.services.pricing_registration.litellm.register_model",
        _maybe_fail,
        raising=False,
    )
    _patch_openrouter_pricing(
        monkeypatch,
        {
            "anthropic/claude-3-5-sonnet": {
                "prompt": "0.000003",
                "completion": "0.000015",
            }
        },
    )
    monkeypatch.setattr(
        config,
        "GLOBAL_LLM_CONFIGS",
        [
            {
                "id": 1,
                "provider": "OPENROUTER",
                "model_name": "anthropic/claude-3-5-sonnet",
            },
            {
                "id": 2,
                "provider": "OPENAI",
                "model_name": "custom-deployment",
                "litellm_params": {
                    "base_model": "custom-deployment",
                    "input_cost_per_token": 1e-6,
                    "output_cost_per_token": 2e-6,
                },
            },
        ],
    )

    register_pricing_from_global_configs()

    # The good config still registered...
    assert any("custom-deployment" in payload for payload in successful_calls)
    # ...and the failing one never slipped through.
    assert not any(
        "anthropic/claude-3-5-sonnet" in payload for payload in successful_calls
    )
def test_vision_configs_registered_with_chat_shape(monkeypatch):
    """The vision walk is its own iteration over ``GLOBAL_VISION_LLM_CONFIGS``.

    The chat list is left empty here to prove the vision configs aren't
    piggy-backed on it. Vision uses chat-shape token prices; image-gen
    pricing is intentionally NOT registered here (handled via
    ``response_cost`` in LiteLLM).
    """
    from app.config import config
    from app.services.pricing_registration import register_pricing_from_global_configs

    spy = _patch_register(monkeypatch)
    _patch_openrouter_pricing(
        monkeypatch,
        {"openai/gpt-4o": {"prompt": "0.000005", "completion": "0.000015"}},
    )
    monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [])
    monkeypatch.setattr(
        config,
        "GLOBAL_VISION_LLM_CONFIGS",
        [
            {
                "id": -1,
                "provider": "OPENROUTER",
                "model_name": "openai/gpt-4o",
                "billing_tier": "premium",
                "input_cost_per_token": 5e-6,
                "output_cost_per_token": 15e-6,
            }
        ],
    )

    register_pricing_from_global_configs()

    key = "openrouter/openai/gpt-4o"
    assert key in spy.all_keys
    entry = spy.calls[0][key]
    assert entry["mode"] == "chat"
    assert entry["litellm_provider"] == "openrouter"
    assert entry["input_cost_per_token"] == pytest.approx(5e-6)
    assert entry["output_cost_per_token"] == pytest.approx(15e-6)


def test_vision_with_inline_pricing_when_or_cache_missing(monkeypatch):
    """When the OpenRouter pricing cache misses a vision model (different
    catalogue surface), the walk falls back to the inline per-token costs
    carried on the cfg itself."""
    from app.config import config
    from app.services.pricing_registration import register_pricing_from_global_configs

    spy = _patch_register(monkeypatch)
    _patch_openrouter_pricing(monkeypatch, {})
    monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [])
    monkeypatch.setattr(
        config,
        "GLOBAL_VISION_LLM_CONFIGS",
        [
            {
                "id": -1,
                "provider": "OPENROUTER",
                "model_name": "google/gemini-2.5-flash",
                "billing_tier": "premium",
                "input_cost_per_token": 1e-6,
                "output_cost_per_token": 4e-6,
            }
        ],
    )

    register_pricing_from_global_configs()

    assert "openrouter/google/gemini-2.5-flash" in spy.all_keys

View file

@ -0,0 +1,107 @@
"""Unit tests for the shared ``api_base`` resolver.
The cascade exists so vision and image-gen call sites can't silently
inherit ``litellm.api_base`` (commonly set by ``AZURE_OPENAI_ENDPOINT``)
when an OpenRouter / Groq / etc. config ships an empty string. See
``provider_api_base`` module docstring for the original repro
(OpenRouter image-gen 404-ing against an Azure endpoint).
"""
from __future__ import annotations
import pytest
from app.services.provider_api_base import (
PROVIDER_DEFAULT_API_BASE,
PROVIDER_KEY_DEFAULT_API_BASE,
resolve_api_base,
)
pytestmark = pytest.mark.unit
def test_config_value_wins_over_defaults():
    """A non-empty config value is returned verbatim — the operator gets
    the last word even when the provider has a default."""
    mirror = "https://my-openrouter-mirror.example.com/v1"
    assert (
        resolve_api_base(
            provider="OPENROUTER",
            provider_prefix="openrouter",
            config_api_base=mirror,
        )
        == mirror
    )


def test_provider_key_default_when_config_missing():
    """DeepSeek shares the ``openai`` LiteLLM prefix but has its own base
    URL — the provider-key map must beat the prefix map so DeepSeek
    requests never go to OpenAI."""
    assert (
        resolve_api_base(
            provider="DEEPSEEK", provider_prefix="openai", config_api_base=None
        )
        == PROVIDER_KEY_DEFAULT_API_BASE["DEEPSEEK"]
    )


def test_provider_prefix_default_when_no_key_default():
    assert (
        resolve_api_base(
            provider="OPENROUTER", provider_prefix="openrouter", config_api_base=None
        )
        == PROVIDER_DEFAULT_API_BASE["openrouter"]
    )


def test_unknown_provider_returns_none():
    """Neither map matching yields ``None`` so the caller can let LiteLLM
    apply its own provider-integration default (Azure deployment URL,
    custom-provider URL, etc.)."""
    assert (
        resolve_api_base(
            provider="SOMETHING_NEW",
            provider_prefix="something_new",
            config_api_base=None,
        )
        is None
    )


def test_empty_string_config_treated_as_missing():
    """The original bug: OpenRouter dynamic configs ship ``api_base=""``;
    empty strings are falsy downstream, but the cascade has to step in
    anyway."""
    assert (
        resolve_api_base(
            provider="OPENROUTER", provider_prefix="openrouter", config_api_base=""
        )
        == PROVIDER_DEFAULT_API_BASE["openrouter"]
    )


def test_whitespace_only_config_treated_as_missing():
    """A whitespace-only value is a configuration mistake — treat it as
    missing instead of forwarding whitespace to LiteLLM."""
    assert (
        resolve_api_base(
            provider="OPENROUTER", provider_prefix="openrouter", config_api_base=" "
        )
        == PROVIDER_DEFAULT_API_BASE["openrouter"]
    )


def test_provider_case_insensitive():
    """Callers pass the provider lowercase (DB enum) or uppercase (YAML
    key); both spellings must resolve identically."""
    resolved = {
        resolve_api_base(
            provider=spelling, provider_prefix="openai", config_api_base=None
        )
        for spelling in ("DEEPSEEK", "deepseek")
    }
    assert resolved == {PROVIDER_KEY_DEFAULT_API_BASE["DEEPSEEK"]}


def test_all_inputs_none_returns_none():
    assert (
        resolve_api_base(provider=None, provider_prefix=None, config_api_base=None)
        is None
    )

View file

@ -0,0 +1,244 @@
"""Unit tests for the shared chat-image capability resolver.
Two resolvers, two intents:
- ``derive_supports_image_input`` best-effort True for the catalog and
selector. Default-allow on unknown / unmapped models. The streaming
task safety net never sees this value directly.
- ``is_known_text_only_chat_model`` strict opt-out for the safety net.
Returns True only when LiteLLM's model map *explicitly* sets
``supports_vision=False``. Anything else (missing key, exception,
True) returns False so the request flows through to the provider.
"""
from __future__ import annotations
import pytest
from app.services.provider_capabilities import (
derive_supports_image_input,
is_known_text_only_chat_model,
)
pytestmark = pytest.mark.unit
# ---------------------------------------------------------------------------
# derive_supports_image_input — OpenRouter modalities path (authoritative)
# ---------------------------------------------------------------------------


def test_or_modalities_with_image_returns_true():
    verdict = derive_supports_image_input(
        provider="OPENROUTER",
        model_name="openai/gpt-4o",
        openrouter_input_modalities=["text", "image"],
    )
    assert verdict is True


def test_or_modalities_text_only_returns_false():
    verdict = derive_supports_image_input(
        provider="OPENROUTER",
        model_name="deepseek/deepseek-v3.2-exp",
        openrouter_input_modalities=["text"],
    )
    assert verdict is False


def test_or_modalities_empty_list_returns_false():
    """An explicitly-published empty modality list is a definitive 'no
    inputs at all' signal — False, with no LiteLLM fallback."""
    verdict = derive_supports_image_input(
        provider="OPENROUTER",
        model_name="weird/empty-modalities",
        openrouter_input_modalities=[],
    )
    assert verdict is False


def test_or_modalities_none_falls_through_to_litellm():
    """A missing modality key (``None``) is not definitive — fall through
    to LiteLLM, whose map knows ``gpt-4o`` supports vision."""
    verdict = derive_supports_image_input(
        provider="OPENAI",
        model_name="gpt-4o",
        openrouter_input_modalities=None,
    )
    assert verdict is True
# ---------------------------------------------------------------------------
# derive_supports_image_input — LiteLLM model-map path
# ---------------------------------------------------------------------------


def test_litellm_known_vision_model_returns_true():
    assert (
        derive_supports_image_input(provider="OPENAI", model_name="gpt-4o") is True
    )


def test_litellm_base_model_wins_over_model_name():
    """Azure entries carry the canonical sku in ``base_model`` and the
    deployment id in ``model_name``; the resolver must consult the sku
    first, or the deployment id (unknown to LiteLLM) would shadow the
    real capability."""
    assert (
        derive_supports_image_input(
            provider="AZURE_OPENAI",
            model_name="my-azure-deployment-id",
            base_model="gpt-4o",
        )
        is True
    )


def test_litellm_unknown_model_default_allows():
    """Default-allow on unknown — the safety net is the actual block."""
    assert (
        derive_supports_image_input(
            provider="CUSTOM",
            model_name="brand-new-model-x9-unmapped",
            custom_provider="brand_new_proxy",
        )
        is True
    )


def test_litellm_known_text_only_returns_false():
    """Invariant: the resolver never *raises* on a known text-only model.

    Whether ``deepseek-chat`` appears in the map varies by LiteLLM
    version, so we accept either False (explicit no) or True
    (default-allow when unmapped) and only pin that a bool comes back.
    The behaviour-binding opt-out assertion lives in
    ``test_is_known_text_only_returns_true_on_explicit_false``.
    """
    verdict = derive_supports_image_input(
        provider="DEEPSEEK",
        model_name="deepseek-chat",
    )
    assert isinstance(verdict, bool)
# ---------------------------------------------------------------------------
# is_known_text_only_chat_model — strict opt-out semantics
# ---------------------------------------------------------------------------


def test_is_known_text_only_returns_false_for_vision_model():
    assert (
        is_known_text_only_chat_model(provider="OPENAI", model_name="gpt-4o")
        is False
    )


def test_is_known_text_only_returns_false_for_unknown_model():
    """Strict opt-out: absence from the map is NOT 'text-only'. The
    safety net must stay quiet for an unmapped model — that's the
    regression being fixed."""
    assert (
        is_known_text_only_chat_model(
            provider="CUSTOM",
            model_name="brand-new-model-x9-unmapped",
            custom_provider="brand_new_proxy",
        )
        is False
    )


def test_is_known_text_only_returns_false_when_lookup_raises(monkeypatch):
    """LiteLLM's ``get_model_info`` raises freely on parse errors; the
    helper swallows the exception and answers False so a transient
    lookup failure never fires the safety net."""
    import app.services.provider_capabilities as pc

    def _explode(**_kwargs):
        raise ValueError("intentional test failure")

    monkeypatch.setattr(pc.litellm, "get_model_info", _explode)
    assert (
        is_known_text_only_chat_model(provider="OPENAI", model_name="gpt-4o")
        is False
    )


def test_is_known_text_only_returns_true_on_explicit_false(monkeypatch):
    """Deterministic opt-out path: stub LiteLLM to return an explicit
    ``supports_vision: False``. Stubbing keeps the test stable across
    LiteLLM map updates."""
    import app.services.provider_capabilities as pc

    monkeypatch.setattr(
        pc.litellm,
        "get_model_info",
        lambda **_kwargs: {"supports_vision": False, "max_input_tokens": 8192},
    )
    assert (
        is_known_text_only_chat_model(provider="OPENAI", model_name="any-model")
        is True
    )


def test_is_known_text_only_returns_false_on_supports_vision_true(monkeypatch):
    import app.services.provider_capabilities as pc

    monkeypatch.setattr(
        pc.litellm, "get_model_info", lambda **_kwargs: {"supports_vision": True}
    )
    assert (
        is_known_text_only_chat_model(provider="OPENAI", model_name="any-model")
        is False
    )


def test_is_known_text_only_returns_false_on_missing_key(monkeypatch):
    """An entry without ``supports_vision`` at all reads as 'unknown',
    and strict opt-out maps unknown to False."""
    import app.services.provider_capabilities as pc

    monkeypatch.setattr(
        pc.litellm, "get_model_info", lambda **_kwargs: {"max_input_tokens": 8192}
    )
    assert (
        is_known_text_only_chat_model(provider="OPENAI", model_name="any-model")
        is False
    )

View file

@ -0,0 +1,157 @@
"""Unit tests for ``QuotaCheckedVisionLLM``.
Validates that:
* Calling ``ainvoke`` routes through ``billable_call`` (premium credit
enforcement) and forwards the inner LLM's response on success.
* The wrapper proxies non-overridden attributes to the inner LLM
(``__getattr__``) so ``invoke`` / ``astream`` / ``with_structured_output``
still work without quota gating (they're not used in indexing today).
* When ``billable_call`` raises ``QuotaInsufficientError`` the wrapper
bubbles it up the ETL pipeline catches that and falls back to OCR.
"""
from __future__ import annotations
import contextlib
from typing import Any
from uuid import uuid4
import pytest
pytestmark = pytest.mark.unit
class _FakeInnerLLM:
"""Stand-in for ``langchain_litellm.ChatLiteLLM``."""
def __init__(self, response: Any = "OCR'd content") -> None:
self._response = response
self.ainvoke_calls: list[Any] = []
async def ainvoke(self, input: Any, *args: Any, **kwargs: Any) -> Any:
self.ainvoke_calls.append(input)
return self._response
def some_other_method(self, x: int) -> int:
return x * 2
@contextlib.asynccontextmanager
async def _passthrough_billable_call(**_kwargs):
"""Stand-in for billable_call that always allows the call to run."""
class _Acc:
total_cost_micros = 0
total_prompt_tokens = 0
total_completion_tokens = 0
grand_total = 0
calls: list[Any] = []
def per_message_summary(self) -> dict[str, dict[str, int]]:
return {}
yield _Acc()
@pytest.mark.asyncio
async def test_ainvoke_routes_through_billable_call(monkeypatch):
    """``ainvoke`` wraps the inner call in ``billable_call`` with the
    constructor's billing context and forwards the inner response."""
    from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM

    seen_kwargs: list[dict[str, Any]] = []

    @contextlib.asynccontextmanager
    async def _spy_billable_call(**kwargs):
        seen_kwargs.append(kwargs)
        async with _passthrough_billable_call() as acc:
            yield acc

    monkeypatch.setattr(
        "app.services.quota_checked_vision_llm.billable_call",
        _spy_billable_call,
        raising=False,
    )

    inner = _FakeInnerLLM(response="A red apple on a white table")
    user_id = uuid4()
    wrapper = QuotaCheckedVisionLLM(
        inner,
        user_id=user_id,
        search_space_id=99,
        billing_tier="premium",
        base_model="openai/gpt-4o",
        quota_reserve_tokens=4000,
    )

    result = await wrapper.ainvoke([{"text": "what is this?"}])

    assert result == "A red apple on a white table"
    assert len(inner.ainvoke_calls) == 1
    (billing_ctx,) = seen_kwargs
    assert billing_ctx["user_id"] == user_id
    assert billing_ctx["search_space_id"] == 99
    assert billing_ctx["billing_tier"] == "premium"
    assert billing_ctx["base_model"] == "openai/gpt-4o"
    assert billing_ctx["quota_reserve_tokens"] == 4000
    assert billing_ctx["usage_type"] == "vision_extraction"


@pytest.mark.asyncio
async def test_ainvoke_propagates_quota_insufficient_error(monkeypatch):
    """A denied reservation bubbles ``QuotaInsufficientError`` up — the
    ETL pipeline catches it and falls back to OCR — and the inner LLM
    never runs."""
    from app.services.billable_calls import QuotaInsufficientError
    from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM

    @contextlib.asynccontextmanager
    async def _denying_billable_call(**_kwargs):
        raise QuotaInsufficientError(
            usage_type="vision_extraction",
            used_micros=5_000_000,
            limit_micros=5_000_000,
            remaining_micros=0,
        )
        yield  # unreachable but required for asynccontextmanager type

    monkeypatch.setattr(
        "app.services.quota_checked_vision_llm.billable_call",
        _denying_billable_call,
        raising=False,
    )

    inner = _FakeInnerLLM()
    wrapper = QuotaCheckedVisionLLM(
        inner,
        user_id=uuid4(),
        search_space_id=1,
        billing_tier="premium",
        base_model="openai/gpt-4o",
        quota_reserve_tokens=4000,
    )

    with pytest.raises(QuotaInsufficientError):
        await wrapper.ainvoke([{"text": "x"}])
    # Inner LLM never ran on a denied reservation.
    assert inner.ainvoke_calls == []


@pytest.mark.asyncio
async def test_proxies_non_overridden_attributes_to_inner():
    """``__getattr__`` forwards anything not defined on the proxy, so
    non-overridden methods (``invoke``, ``astream``,
    ``with_structured_output``, ...) still work — without quota gating,
    which is fine because the indexer only ever calls ``ainvoke``."""
    from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM

    wrapped = _FakeInnerLLM()
    proxy = QuotaCheckedVisionLLM(
        wrapped,
        user_id=uuid4(),
        search_space_id=1,
        billing_tier="premium",
        base_model="openai/gpt-4o",
        quota_reserve_tokens=4000,
    )
    # ``some_other_method`` is defined on the inner LLM only.
    assert proxy.some_other_method(7) == 14

View file

@ -0,0 +1,281 @@
"""Unit tests for the chat-catalog ``supports_image_input`` capability flag.
Capability is sourced from two places, in order of preference:
1. ``architecture.input_modalities`` for dynamic OpenRouter chat configs
(authoritative OpenRouter publishes per-model modalities directly).
2. LiteLLM's authoritative model map (``litellm.supports_vision``) for
YAML / BYOK configs that don't carry an explicit operator override.
The catalog default is *True* (conservative-allow): an unknown / unmapped
model is not pre-judged. The streaming-task safety net
(``is_known_text_only_chat_model``) is the only place a False actually
blocks a request and it requires LiteLLM to *explicitly* mark the model
as text-only.
"""
from __future__ import annotations
import pytest
from app.services.openrouter_integration_service import (
_OPENROUTER_DYNAMIC_MARKER,
_generate_configs,
_supports_image_input,
)
pytestmark = pytest.mark.unit

# Minimal OpenRouter-integration settings accepted by the config
# generators under test. Every test passes a fresh copy
# (``dict(_SETTINGS_BASE)``) so no state leaks between cases.
_SETTINGS_BASE: dict = {
    "api_key": "sk-or-test",
    "id_offset": -10_000,
    "rpm": 200,
    "tpm": 1_000_000,
    "free_rpm": 20,
    "free_tpm": 100_000,
    "anonymous_enabled_paid": False,
    "anonymous_enabled_free": True,
    "quota_reserve_tokens": 4000,
}
# ---------------------------------------------------------------------------
# _supports_image_input helper (OpenRouter modalities)
# ---------------------------------------------------------------------------


def test_supports_image_input_true_for_multimodal():
    entry = {
        "id": "openai/gpt-4o",
        "architecture": {
            "input_modalities": ["text", "image"],
            "output_modalities": ["text"],
        },
    }
    assert _supports_image_input(entry) is True


def test_supports_image_input_false_for_text_only():
    """DeepSeek V3 is text-in/text-out and would 404 if forwarded an
    ``image_url`` — exactly the failure mode the safety net guards
    against."""
    entry = {
        "id": "deepseek/deepseek-v3.2-exp",
        "architecture": {
            "input_modalities": ["text"],
            "output_modalities": ["text"],
        },
    }
    assert _supports_image_input(entry) is False


def test_supports_image_input_false_when_modalities_missing():
    """Defensive: missing / empty / null architecture data reads as
    text-only at this helper's level; the wider catalog resolver
    (``derive_supports_image_input``) only consults modalities when they
    are non-empty and otherwise falls back to LiteLLM."""
    for entry in (
        {"id": "weird/model"},
        {"id": "weird/model", "architecture": {}},
        {"id": "weird/model", "architecture": {"input_modalities": None}},
    ):
        assert _supports_image_input(entry) is False
# ---------------------------------------------------------------------------
# _generate_configs threads the flag onto every emitted chat config
# ---------------------------------------------------------------------------


def test_generate_configs_emits_supports_image_input():
    multimodal = {
        "id": "openai/gpt-4o",
        "architecture": {
            "input_modalities": ["text", "image"],
            "output_modalities": ["text"],
        },
        "supported_parameters": ["tools"],
        "context_length": 200_000,
        "pricing": {"prompt": "0.000005", "completion": "0.000015"},
    }
    text_only = {
        "id": "deepseek/deepseek-v3.2-exp",
        "architecture": {
            "input_modalities": ["text"],
            "output_modalities": ["text"],
        },
        "supported_parameters": ["tools"],
        "context_length": 200_000,
        "pricing": {"prompt": "0.000003", "completion": "0.000015"},
    }

    by_model = {
        cfg["model_name"]: cfg
        for cfg in _generate_configs([multimodal, text_only], dict(_SETTINGS_BASE))
    }

    assert by_model["openai/gpt-4o"]["supports_image_input"] is True
    assert by_model["openai/gpt-4o"][_OPENROUTER_DYNAMIC_MARKER] is True
    assert by_model["deepseek/deepseek-v3.2-exp"]["supports_image_input"] is False
    assert by_model["deepseek/deepseek-v3.2-exp"][_OPENROUTER_DYNAMIC_MARKER] is True
# ---------------------------------------------------------------------------
# YAML loader: defer to derive_supports_image_input on unannotated entries
# ---------------------------------------------------------------------------


def _write_global_llm_yaml(tmp_path, body: str) -> None:
    """Drop ``body`` into <tmp>/app/config/global_llm_config.yaml."""
    yaml_dir = tmp_path / "app" / "config"
    yaml_dir.mkdir(parents=True)
    (yaml_dir / "global_llm_config.yaml").write_text(body, encoding="utf-8")


def test_yaml_loader_resolves_unannotated_vision_model_to_true(tmp_path, monkeypatch):
    """The regression case: a YAML entry without a ``supports_image_input``
    override should resolve to True via LiteLLM's model map
    (``supports_vision: true``). Previously this defaulted to False,
    blocking every image turn for vision-capable YAML configs."""
    _write_global_llm_yaml(
        tmp_path,
        """
global_llm_configs:
  - id: -2
    name: Azure GPT-4o
    provider: AZURE_OPENAI
    model_name: gpt-4o
    api_key: sk-test
""",
    )
    from app import config as config_module

    monkeypatch.setattr(config_module, "BASE_DIR", tmp_path)
    configs = config_module.load_global_llm_configs()
    assert len(configs) == 1
    assert configs[0]["supports_image_input"] is True


def test_yaml_loader_respects_explicit_supports_image_input(tmp_path, monkeypatch):
    """Operator override always wins, even against LiteLLM's True."""
    _write_global_llm_yaml(
        tmp_path,
        """
global_llm_configs:
  - id: -1
    name: GPT-4o
    provider: OPENAI
    model_name: gpt-4o
    api_key: sk-test
    supports_image_input: false
""",
    )
    from app import config as config_module

    monkeypatch.setattr(config_module, "BASE_DIR", tmp_path)
    configs = config_module.load_global_llm_configs()
    assert len(configs) == 1
    assert configs[0]["supports_image_input"] is False


def test_yaml_loader_unknown_model_default_allows(tmp_path, monkeypatch):
    """Unknown / unmapped YAML model: default-allow. The streaming safety
    net (which demands an explicit-False from LiteLLM) is the only place
    a real block happens, so a freshly added third-party entry the
    catalog can't introspect doesn't lock the user out."""
    _write_global_llm_yaml(
        tmp_path,
        """
global_llm_configs:
  - id: -1
    name: Some Brand New Model
    provider: CUSTOM
    custom_provider: brand_new_proxy
    model_name: brand-new-model-x9
    api_key: sk-test
""",
    )
    from app import config as config_module

    monkeypatch.setattr(config_module, "BASE_DIR", tmp_path)
    configs = config_module.load_global_llm_configs()
    assert len(configs) == 1
    assert configs[0]["supports_image_input"] is True
# ---------------------------------------------------------------------------
# AgentConfig threads the flag through both YAML and Auto / BYOK
# ---------------------------------------------------------------------------
def test_agent_config_from_yaml_explicit_overrides_resolver():
    """An explicit ``supports_image_input`` key in a YAML entry is taken
    verbatim by ``AgentConfig.from_yaml_config`` — in both directions."""
    from app.agents.new_chat.llm_config import AgentConfig

    # gpt-4o is vision-capable per LiteLLM, so a False here can only come
    # from the explicit operator override.
    shared = {
        "provider": "openai",
        "model_name": "gpt-4o",
        "api_key": "sk-test",
    }
    text_only = AgentConfig.from_yaml_config(
        {
            **shared,
            "id": -1,
            "name": "Text Only Override",
            "supports_image_input": False,
        }
    )
    explicit_vision = AgentConfig.from_yaml_config(
        {
            **shared,
            "id": -2,
            "name": "GPT-4o",
            "supports_image_input": True,
        }
    )

    assert text_only.supports_image_input is False
    assert explicit_vision.supports_image_input is True
def test_agent_config_from_yaml_unannotated_uses_resolver():
    """Without an explicit YAML key, AgentConfig defers to the catalog
    resolver — for ``gpt-4o`` LiteLLM's map says supports_vision=True."""
    from app.agents.new_chat.llm_config import AgentConfig

    yaml_entry = {
        "id": -1,
        "name": "GPT-4o (no override)",
        "provider": "openai",
        "model_name": "gpt-4o",
        "api_key": "sk-test",
    }
    resolved = AgentConfig.from_yaml_config(yaml_entry)
    assert resolved.supports_image_input is True
def test_agent_config_auto_mode_supports_image_input():
    """Auto routes across the pool. We optimistically allow image input
    so users can keep their selection on Auto with a vision-capable
    deployment somewhere in the pool. The router's own `allowed_fails`
    handles non-vision deployments via fallback."""
    from app.agents.new_chat.llm_config import AgentConfig

    auto_config = AgentConfig.from_auto_mode()
    assert auto_config.supports_image_input is True

View file

@ -0,0 +1,515 @@
"""Cost-based premium quota unit tests.
Covers the USD-micro behaviour added in migration 140:
* ``TurnTokenAccumulator.total_cost_micros`` sums ``cost_micros`` across all
calls in a turn used as the debit amount when ``agent_config.is_premium``
is true, regardless of which underlying model produced each call. This
preserves the prior "premium turn → all calls in turn count" rule from the
token-based system.
* ``estimate_call_reserve_micros`` scales linearly with model pricing,
clamps to a sane floor when pricing is unknown, and respects the
``QUOTA_MAX_RESERVE_MICROS`` ceiling so a misconfigured "$1000/M" entry
can't lock the whole balance on one call.
"""
from __future__ import annotations
import pytest
pytestmark = pytest.mark.unit
# ---------------------------------------------------------------------------
# TurnTokenAccumulator — premium-turn debit semantics
# ---------------------------------------------------------------------------
def test_total_cost_micros_sums_premium_and_free_calls():
    """A premium turn that also called a free sub-agent debits the union.

    Per-call premium filtering relied on
    ``LLMRouterService._premium_model_strings`` which only covers
    router-pool deployments, so the "premium turn → all calls count" rule
    is preserved: ``total_cost_micros`` must include free-model calls
    (typically ``cost_micros == 0``) plus every premium call's real cost.
    """
    from app.services.token_tracking_service import TurnTokenAccumulator

    acc = TurnTokenAccumulator()
    # (model, prompt, completion, total, cost_micros): two premium-priced
    # calls and one zero-cost free sub-agent call (title-gen style).
    recorded_calls = [
        ("anthropic/claude-3-5-sonnet", 1200, 400, 1600, 12_345),
        ("gpt-4o-mini", 120, 20, 140, 0),
        ("anthropic/claude-3-5-sonnet", 800, 200, 1000, 7_500),
    ]
    for model, prompt, completion, total, cost in recorded_calls:
        acc.add(
            model=model,
            prompt_tokens=prompt,
            completion_tokens=completion,
            total_tokens=total,
            cost_micros=cost,
        )

    assert acc.total_cost_micros == 12_345 + 0 + 7_500
    # Token totals stay correct so the FE display path still works.
    assert acc.grand_total == 1600 + 140 + 1000
def test_total_cost_micros_zero_when_no_calls():
    """An empty accumulator must report zero cost (no division-by-zero, no None)."""
    from app.services.token_tracking_service import TurnTokenAccumulator

    empty = TurnTokenAccumulator()
    assert empty.total_cost_micros == 0
    assert empty.grand_total == 0
def test_per_message_summary_groups_cost_by_model():
    """``per_message_summary`` must accumulate ``cost_micros`` per model so the
    SSE ``model_breakdown`` payload reports actual USD spend per provider.
    """
    from app.services.token_tracking_service import TurnTokenAccumulator

    acc = TurnTokenAccumulator()
    # Two sonnet calls (should merge into one bucket) and one mini call.
    for model, prompt, completion, total, cost in (
        ("claude-3-5-sonnet", 100, 50, 150, 4_000),
        ("claude-3-5-sonnet", 200, 100, 300, 8_000),
        ("gpt-4o-mini", 50, 10, 60, 200),
    ):
        acc.add(
            model=model,
            prompt_tokens=prompt,
            completion_tokens=completion,
            total_tokens=total,
            cost_micros=cost,
        )

    summary = acc.per_message_summary()
    sonnet_bucket = summary["claude-3-5-sonnet"]
    assert sonnet_bucket["cost_micros"] == 12_000
    assert sonnet_bucket["total_tokens"] == 450
    assert summary["gpt-4o-mini"]["cost_micros"] == 200
def test_serialized_calls_includes_cost_micros():
    """``serialized_calls`` is what flows into the SSE ``call_details``
    payload; cost_micros must be present on each entry so the FE message-info
    dropdown can render per-call USD.
    """
    from app.services.token_tracking_service import TurnTokenAccumulator

    acc = TurnTokenAccumulator()
    acc.add(
        model="m",
        prompt_tokens=1,
        completion_tokens=1,
        total_tokens=2,
        cost_micros=42,
    )

    expected_entry = {
        "model": "m",
        "prompt_tokens": 1,
        "completion_tokens": 1,
        "total_tokens": 2,
        "cost_micros": 42,
        "call_kind": "chat",
    }
    assert acc.serialized_calls() == [expected_entry]
# ---------------------------------------------------------------------------
# estimate_call_reserve_micros — sizing and clamping
# ---------------------------------------------------------------------------
def test_reserve_returns_floor_when_model_unknown(monkeypatch):
    """If LiteLLM doesn't know the model, ``get_model_info`` raises and the
    helper falls back to the 100-micro floor — small enough that a user with
    $0.0001 left can still send a tiny request, but non-zero so we still gate
    against an empty balance.
    """
    import litellm

    from app.services import token_quota_service

    def _unknown_model(_name):
        raise KeyError("unknown")

    monkeypatch.setattr(litellm, "get_model_info", _unknown_model, raising=False)

    reserved = token_quota_service.estimate_call_reserve_micros(
        base_model="nonexistent-model",
        quota_reserve_tokens=4000,
    )

    assert reserved == token_quota_service._QUOTA_MIN_RESERVE_MICROS
    assert reserved == 100
def test_reserve_returns_floor_when_pricing_is_zero(monkeypatch):
    """LiteLLM may *return* a model with both cost-per-token fields at 0
    (pricing not yet registered). The helper must not multiply 0 x tokens
    and end up reserving 0 — it must clamp to the floor.
    """
    import litellm

    from app.services import token_quota_service

    zero_priced = {"input_cost_per_token": 0, "output_cost_per_token": 0}
    monkeypatch.setattr(
        litellm, "get_model_info", lambda _name: zero_priced, raising=False
    )

    reserved = token_quota_service.estimate_call_reserve_micros(
        base_model="some-pending-model",
        quota_reserve_tokens=4000,
    )
    assert reserved == token_quota_service._QUOTA_MIN_RESERVE_MICROS
def test_reserve_scales_with_model_cost(monkeypatch):
    """Claude-Opus-priced model with 4000 reserve_tokens reserves
    ~$0.36 = 360_000 micros. Critically this must NOT be clamped down to
    some small artificial cap — that was the bug the plan called out.
    """
    import litellm

    from app.config import config
    from app.services import token_quota_service

    # Opus-class pricing: $15/M input, $75/M output tokens.
    monkeypatch.setattr(
        litellm,
        "get_model_info",
        lambda _name: {
            "input_cost_per_token": 15e-6,
            "output_cost_per_token": 75e-6,
        },
        raising=False,
    )
    # Ceiling set comfortably above the expected reserve so no clamping
    # interferes with this assertion.
    monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_MICROS", 1_000_000, raising=False)

    micros = token_quota_service.estimate_call_reserve_micros(
        base_model="claude-3-opus",
        quota_reserve_tokens=4000,
    )

    # 4000 * (15e-6 + 75e-6) = 4000 * 90e-6 = 0.36 USD = 360_000 micros.
    assert micros == 360_000
def test_reserve_clamps_to_max_ceiling(monkeypatch):
    """A misconfigured "$1000 / M" model with 4000 reserve_tokens would
    nominally compute to $4 = 4_000_000 micros. The ceiling
    ``QUOTA_MAX_RESERVE_MICROS`` must clamp that so a bad pricing entry
    can't lock the user's whole balance on one call.
    """
    import litellm

    from app.config import config
    from app.services import token_quota_service

    # Absurd pricing: 1e-3 USD per input token = $1000 per million tokens.
    monkeypatch.setattr(
        litellm,
        "get_model_info",
        lambda _name: {
            "input_cost_per_token": 1e-3,
            "output_cost_per_token": 0,
        },
        raising=False,
    )
    monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_MICROS", 1_000_000, raising=False)

    micros = token_quota_service.estimate_call_reserve_micros(
        base_model="oops-misconfigured",
        quota_reserve_tokens=4000,
    )

    # Nominal 4_000_000 micros, clamped down to the 1_000_000 ceiling.
    assert micros == 1_000_000
def test_reserve_uses_default_when_quota_reserve_tokens_missing(monkeypatch):
    """Per-config ``quota_reserve_tokens`` is optional; when ``None`` or
    zero, the helper must fall back to the global ``QUOTA_MAX_RESERVE_PER_CALL``
    so anonymous-style configs still reserve the operator-tunable default.
    """
    import litellm

    from app.config import config
    from app.services import token_quota_service

    # Symmetric $1/M pricing keeps the expected arithmetic trivial.
    monkeypatch.setattr(
        litellm,
        "get_model_info",
        lambda _name: {
            "input_cost_per_token": 1e-6,
            "output_cost_per_token": 1e-6,
        },
        raising=False,
    )
    monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_PER_CALL", 2000, raising=False)
    monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_MICROS", 1_000_000, raising=False)

    # 2000 * (1e-6 + 1e-6) = 4e-3 USD = 4000 micros
    assert (
        token_quota_service.estimate_call_reserve_micros(
            base_model="cheap", quota_reserve_tokens=None
        )
        == 4000
    )
    # Zero is treated the same as missing — fall back to the global default.
    assert (
        token_quota_service.estimate_call_reserve_micros(
            base_model="cheap", quota_reserve_tokens=0
        )
        == 4000
    )
# ---------------------------------------------------------------------------
# TokenTrackingCallback — image vs chat usage shape
# ---------------------------------------------------------------------------
class _FakeImageUsage:
    """Mimics LiteLLM's ``ImageUsage`` (input_tokens / output_tokens shape)."""

    def __init__(
        self,
        input_tokens: int = 0,
        output_tokens: int = 0,
        total_tokens: int | None = None,
    ) -> None:
        self.input_tokens = input_tokens
        self.output_tokens = output_tokens
        if total_tokens is None:
            # Leave the attribute absent entirely — mirrors LiteLLM payloads
            # that omit total_tokens.
            return
        self.total_tokens = total_tokens
class _FakeImageResponse:
    """Mimics LiteLLM's ``ImageResponse`` — same name so the callback's
    ``type(...).__name__`` probe routes to the image branch.
    """

    def __init__(self, usage: _FakeImageUsage, response_cost: float | None = None):
        self.usage = usage
        if response_cost is None:
            # No hidden params at all when cost is absent — matches the
            # shape the callback probes with getattr.
            return
        self._hidden_params = {"response_cost": response_cost}


# Re-tag the helper class as ``ImageResponse`` for the type-name probe in
# the callback. We can't simply name the class ``ImageResponse`` because
# the test runner sometimes imports test modules in surprising ways and
# we want to be explicit.
_FakeImageResponse.__name__ = "ImageResponse"
class _FakeChatUsage:
    """Chat-shaped usage: prompt/completion counts plus a derived total."""

    def __init__(self, prompt: int, completion: int):
        self.prompt_tokens = prompt
        self.completion_tokens = completion
        self.total_tokens = self.prompt_tokens + self.completion_tokens
class _FakeChatResponse:
    """Chat-shaped response stub: only carries the ``usage`` object."""

    def __init__(self, usage: _FakeChatUsage):
        self.usage = usage
@pytest.mark.asyncio
async def test_callback_reads_image_usage_input_output_tokens():
    """``TokenTrackingCallback`` must read ``input_tokens``/``output_tokens``
    for ``ImageResponse`` (LiteLLM's ImageUsage shape), NOT
    prompt_tokens/completion_tokens — which is the chat shape.
    """
    from app.services.token_tracking_service import (
        TokenTrackingCallback,
        scoped_turn,
    )

    cb = TokenTrackingCallback()
    response = _FakeImageResponse(
        usage=_FakeImageUsage(input_tokens=42, output_tokens=8, total_tokens=50),
        response_cost=0.04,  # $0.04 per image
    )

    async with scoped_turn() as acc:
        await cb.async_log_success_event(
            kwargs={"model": "openai/gpt-image-1", "response_cost": 0.04},
            response_obj=response,
            start_time=None,
            end_time=None,
        )

    assert len(acc.calls) == 1
    call = acc.calls[0]
    # Image usage is mapped onto the chat-shaped call record:
    # input→prompt, output→completion.
    assert call.prompt_tokens == 42
    assert call.completion_tokens == 8
    assert call.total_tokens == 50
    # 0.04 USD = 40_000 micros
    assert call.cost_micros == 40_000
    assert call.call_kind == "image_generation"
@pytest.mark.asyncio
async def test_callback_chat_path_unchanged():
    """Chat responses must still read prompt_tokens/completion_tokens."""
    from app.services.token_tracking_service import (
        TokenTrackingCallback,
        scoped_turn,
    )

    cb = TokenTrackingCallback()
    response = _FakeChatResponse(_FakeChatUsage(prompt=120, completion=30))

    async with scoped_turn() as acc:
        await cb.async_log_success_event(
            kwargs={
                "model": "openrouter/anthropic/claude-3-5-sonnet",
                "response_cost": 0.0036,
            },
            response_obj=response,
            start_time=None,
            end_time=None,
        )

    assert len(acc.calls) == 1
    call = acc.calls[0]
    assert call.prompt_tokens == 120
    assert call.completion_tokens == 30
    assert call.total_tokens == 150
    # 0.0036 USD = 3_600 micros
    assert call.cost_micros == 3_600
    assert call.call_kind == "chat"
@pytest.mark.asyncio
async def test_callback_image_missing_response_cost_falls_back_to_zero(monkeypatch):
    """When OpenRouter omits ``usage.cost`` LiteLLM's
    ``default_image_cost_calculator`` raises. The defensive image branch in
    ``_extract_cost_usd`` must NOT call ``cost_per_token`` (which is
    chat-shaped and would raise too) — it returns 0 with a WARNING log.
    """
    import litellm

    from app.services.token_tracking_service import (
        TokenTrackingCallback,
        scoped_turn,
    )

    # Force completion_cost to raise the same way OpenRouter image-gen fails.
    def _boom(*_args, **_kwargs):
        raise ValueError("model_cost: missing entry for openrouter image model")

    monkeypatch.setattr(litellm, "completion_cost", _boom, raising=False)

    # And make sure cost_per_token is NEVER called for the image path —
    # if it were, our ``is_image=True`` branch is broken.
    cost_per_token_calls: list = []

    def _record_cost_per_token(**kwargs):
        cost_per_token_calls.append(kwargs)
        return (0.0, 0.0)

    monkeypatch.setattr(
        litellm, "cost_per_token", _record_cost_per_token, raising=False
    )

    cb = TokenTrackingCallback()
    # No response_cost anywhere: neither in kwargs nor in _hidden_params.
    response = _FakeImageResponse(
        usage=_FakeImageUsage(input_tokens=7, output_tokens=0)
    )

    async with scoped_turn() as acc:
        await cb.async_log_success_event(
            kwargs={"model": "openrouter/google/gemini-2.5-flash-image"},
            response_obj=response,
            start_time=None,
            end_time=None,
        )

    assert len(acc.calls) == 1
    assert acc.calls[0].cost_micros == 0
    assert acc.calls[0].call_kind == "image_generation"
    # The image branch must short-circuit before cost_per_token.
    assert cost_per_token_calls == []
# ---------------------------------------------------------------------------
# scoped_turn — ContextVar reset semantics (issue B)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_scoped_turn_restores_outer_accumulator():
    """``scoped_turn`` must restore the previous ContextVar value on exit
    so a per-call wrapper inside an outer chat turn doesn't leak its
    accumulator outward (which would cause double-debit at chat-turn exit).
    """
    from app.services.token_tracking_service import (
        get_current_accumulator,
        scoped_turn,
        start_turn,
    )

    outer = start_turn()
    assert get_current_accumulator() is outer

    async with scoped_turn() as inner:
        # Inside the scope the inner accumulator is current and distinct
        # from the outer one.
        assert get_current_accumulator() is inner
        assert inner is not outer
        inner.add(
            model="x",
            prompt_tokens=1,
            completion_tokens=1,
            total_tokens=2,
            cost_micros=5,
        )

    # After exit the outer accumulator is restored unchanged.
    assert get_current_accumulator() is outer
    assert outer.total_cost_micros == 0
    assert len(outer.calls) == 0
    # The inner accumulator captured the call but didn't bleed into outer.
    assert inner.total_cost_micros == 5
@pytest.mark.asyncio
async def test_scoped_turn_resets_to_none_when_no_outer():
    """Running ``scoped_turn`` outside any chat turn (e.g. a background
    indexing job) must leave the ContextVar at ``None`` on exit so the
    next *unrelated* request starts clean.
    """
    from app.services.token_tracking_service import (
        _turn_accumulator,
        get_current_accumulator,
        scoped_turn,
    )

    # ContextVar default is None for a fresh test isolated context. We
    # simulate "no outer" explicitly to be robust against test order.
    token = _turn_accumulator.set(None)
    try:
        assert get_current_accumulator() is None
        async with scoped_turn() as acc:
            assert get_current_accumulator() is acc
        # Back outside the scope: no accumulator may remain current.
        assert get_current_accumulator() is None
    finally:
        # Undo our explicit set so later tests see their own context state.
        _turn_accumulator.reset(token)

View file

@ -0,0 +1,89 @@
"""Defense-in-depth: vision-LLM resolution must not leak ``api_base``
defaults from ``litellm.api_base`` either.
Vision shares the same shape as image-gen global YAML / OpenRouter
dynamic configs ship ``api_base=""`` and the pre-fix ``get_vision_llm``
call sites would silently drop the empty string and inherit
``AZURE_OPENAI_ENDPOINT``. ``ChatLiteLLM(...)`` doesn't 404 on
construction so we test the kwargs we hand to it instead.
"""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
pytestmark = pytest.mark.unit
@pytest.mark.asyncio
async def test_get_vision_llm_global_openrouter_sets_api_base():
    """Global negative-ID branch: an OpenRouter vision config with
    ``api_base=""`` must end up calling ``SanitizedChatLiteLLM`` with
    ``api_base="https://openrouter.ai/api/v1"`` — never an empty string,
    never silently absent."""
    from app.services import llm_service

    # Global (negative-id) OpenRouter vision config carrying the
    # problematic empty api_base that used to be dropped silently.
    cfg = {
        "id": -30_001,
        "name": "GPT-4o Vision (OpenRouter)",
        "provider": "OPENROUTER",
        "model_name": "openai/gpt-4o",
        "api_key": "sk-or-test",
        "api_base": "",
        "api_version": None,
        "litellm_params": {},
        "billing_tier": "free",
    }

    # Search-space row whose vision config points at the global entry above.
    search_space = MagicMock()
    search_space.id = 1
    search_space.user_id = "user-x"
    search_space.vision_llm_config_id = cfg["id"]

    # Async-session stub: execute() resolves to the search space via the
    # .scalars().first() chain the service uses.
    session = AsyncMock()
    scalars = MagicMock()
    scalars.first.return_value = search_space
    result = MagicMock()
    result.scalars.return_value = scalars
    session.execute.return_value = result

    captured: dict = {}

    class FakeSanitized:
        # Records the constructor kwargs the real ChatLiteLLM would get.
        def __init__(self, **kwargs):
            captured.update(kwargs)

    with (
        patch(
            "app.services.vision_llm_router_service.get_global_vision_llm_config",
            return_value=cfg,
        ),
        patch(
            "app.agents.new_chat.llm_config.SanitizedChatLiteLLM",
            new=FakeSanitized,
        ),
    ):
        await llm_service.get_vision_llm(session=session, search_space_id=1)

    assert captured.get("api_base") == "https://openrouter.ai/api/v1"
    assert captured["model"] == "openrouter/openai/gpt-4o"
def test_vision_router_deployment_sets_api_base_when_config_empty():
    """Auto-mode vision router: deployments are fed to ``litellm.Router``,
    so the resolver has to apply at deployment construction time too."""
    from app.services.vision_llm_router_service import VisionLLMRouterService

    raw_config = {
        "model_name": "openai/gpt-4o",
        "provider": "OPENROUTER",
        "api_key": "sk-or-test",
        "api_base": "",
    }
    deployment = VisionLLMRouterService._config_to_deployment(raw_config)

    assert deployment is not None
    litellm_params = deployment["litellm_params"]
    assert litellm_params["api_base"] == "https://openrouter.ai/api/v1"
    assert litellm_params["model"] == "openrouter/openai/gpt-4o"

View file

@ -51,22 +51,34 @@ class _FakeToolMessage:
tool_call_id: str | None = None
@dataclass
class _FakeInterrupt:
value: dict[str, Any]
@dataclass
class _FakeTask:
interrupts: tuple[_FakeInterrupt, ...] = ()
class _FakeAgentState:
"""Stand-in for ``StateSnapshot`` returned by ``aget_state``."""
def __init__(self) -> None:
def __init__(self, tasks: list[Any] | None = None) -> None:
# Empty values keeps the cloud-fallback safety-net branch a no-op,
# and an empty ``tasks`` list keeps the post-stream interrupt
# check a no-op too.
# and empty ``tasks`` keep the post-stream interrupt check a no-op too.
self.values: dict[str, Any] = {}
self.tasks: list[Any] = []
self.tasks: list[Any] = tasks or []
class _FakeAgent:
"""Replays a list of ``astream_events`` events."""
def __init__(self, events: list[dict[str, Any]]) -> None:
def __init__(
self, events: list[dict[str, Any]], state: _FakeAgentState | None = None
) -> None:
self._events = events
self._state = state or _FakeAgentState()
async def astream_events( # type: ignore[no-untyped-def]
self, _input_data: Any, *, config: dict[str, Any], version: str
@ -79,7 +91,7 @@ class _FakeAgent:
# Called once after astream_events drains so the cloud-fallback
# safety net can inspect staged filesystem work. The fake stays
# empty so the safety net is a no-op.
return _FakeAgentState()
return self._state
def _model_stream(
@ -170,11 +182,13 @@ def parity_v2_off(monkeypatch: pytest.MonkeyPatch) -> None:
)
async def _drain(events: list[dict[str, Any]]) -> list[dict[str, Any]]:
async def _drain(
events: list[dict[str, Any]], state: _FakeAgentState | None = None
) -> list[dict[str, Any]]:
"""Run ``_stream_agent_events`` against a fake agent and return the
SSE payloads (parsed JSON) it yielded.
"""
agent = _FakeAgent(events)
agent = _FakeAgent(events, state=state)
service = VercelStreamingService()
result = StreamResult()
config = {"configurable": {"thread_id": "test-thread"}}
@ -525,3 +539,31 @@ async def test_unmatched_fallback_still_attaches_lc_id(
assert len(starts) == 1
assert starts[0]["toolCallId"].startswith("call_run-1")
assert starts[0]["langchainToolCallId"] == "lc-orphan"
@pytest.mark.asyncio
async def test_interrupt_request_uses_task_that_contains_interrupt(
    parity_v2_on: None,
) -> None:
    """The post-stream interrupt check must scan the state's tasks and
    surface the one that actually carries an interrupt — not assume it
    lives on ``tasks[0]``."""
    interrupt_payload = {
        "type": "calendar_event_create",
        "action": {
            "tool": "create_calendar_event",
            "params": {"summary": "mom bday"},
        },
        "context": {},
    }
    # First task is interrupt-free; the real interrupt sits at index 1.
    state = _FakeAgentState(
        tasks=[
            _FakeTask(interrupts=()),
            _FakeTask(interrupts=(_FakeInterrupt(value=interrupt_payload),)),
        ]
    )

    # No stream events at all — only the post-stream state inspection runs.
    payloads = await _drain([], state=state)

    interrupts = _of_type(payloads, "data-interrupt-request")
    assert len(interrupts) == 1
    assert (
        interrupts[0]["data"]["action_requests"][0]["name"] == "create_calendar_event"
    )

View file

@ -0,0 +1,318 @@
"""Regression tests for ``run_async_celery_task``.
These tests pin down the production bug observed on 2026-05-02 where
the video-presentation Celery task hung at ``[billable_call] finalize``
because the shared ``app.db.engine`` had pooled asyncpg connections
bound to a *previous* task's now-closed event loop. Reusing such a
connection on a fresh loop crashes inside ``pool_pre_ping`` with::
AttributeError: 'NoneType' object has no attribute 'send'
(the proactor is None because the loop is gone) and can hang forever
inside the asyncpg ``Connection._cancel`` cleanup coroutine.
The fix is ``run_async_celery_task``: a small helper that runs every
async celery task body inside a fresh event loop and disposes the
shared engine pool both before (defends against a previous task that
crashed) and after (releases connections we opened on this loop).
Tests here exercise the helper with a stub engine that records
``dispose()`` calls and panics if a coroutine produced by one loop is
awaited on another — mirroring the real asyncpg behaviour.
"""
from __future__ import annotations
import asyncio
import gc
import sys
from collections.abc import Iterator
from contextlib import contextmanager
from unittest.mock import patch
import pytest
pytestmark = pytest.mark.unit
# ---------------------------------------------------------------------------
# Stub engine that emulates the asyncpg-on-stale-loop crash
# ---------------------------------------------------------------------------
class _StaleLoopEngine:
    """Tiny stand-in for ``app.db.engine`` that tracks dispose() calls.

    ``dispose()`` is async (matches ``AsyncEngine.dispose``) and records
    the running event loop id so tests can assert it ran on *each*
    fresh loop.
    """

    def __init__(self) -> None:
        # One entry per dispose() call, holding id() of the loop it ran on.
        self.dispose_loop_ids: list[int] = []

    async def dispose(self) -> None:
        current_loop = asyncio.get_running_loop()
        self.dispose_loop_ids.append(id(current_loop))
@contextmanager
def _patch_shared_engine(stub: _StaleLoopEngine) -> Iterator[None]:
    """Patch ``from app.db import engine as shared_engine`` lookup.

    The helper imports lazily inside the function body, so we have to
    patch the attribute on the already-loaded ``app.db`` module.
    """
    import app.db as app_db

    # A sentinel distinguishes "attribute missing" from "attribute set to
    # None" — getattr(..., None) conflated the two.
    _missing = object()
    original = getattr(app_db, "engine", _missing)
    app_db.engine = stub  # type: ignore[attr-defined]
    try:
        yield
    finally:
        # BUG FIX: the old cleanup for the "no original attribute" case
        # wrapped a read in ``pytest.raises(AttributeError)`` but never
        # removed the stub — the read succeeded, the raises-block failed,
        # and the stub leaked into every later test. Remove the stub
        # instead so the module is restored to its pre-patch state.
        if original is _missing:
            delattr(app_db, "engine")
        else:
            app_db.engine = original  # type: ignore[attr-defined]
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
def test_runner_returns_value_and_disposes_engine_around_call() -> None:
    """Happy path: the coroutine result is returned, and the shared
    engine is disposed both before and after the task body runs.
    """
    from app.tasks.celery_tasks import run_async_celery_task

    stub = _StaleLoopEngine()

    async def _task_body() -> str:
        # The pre-body dispose must already have happened by now.
        assert len(stub.dispose_loop_ids) == 1
        return "ok"

    with _patch_shared_engine(stub):
        outcome = run_async_celery_task(_task_body)

    assert outcome == "ok"
    # Once before the body, once after (in finally) — and both on the
    # SAME fresh loop the task body used.
    assert len(stub.dispose_loop_ids) == 2
    assert stub.dispose_loop_ids[0] == stub.dispose_loop_ids[1]
def test_runner_creates_fresh_loop_per_invocation() -> None:
    """Each call must spin its own loop. Without this guarantee a
    previous task's loop would be reused and the asyncpg-stale-loop
    crash would never be avoided.
    """
    import app.tasks.celery_tasks as celery_tasks_pkg

    stub = _StaleLoopEngine()
    new_loop_calls = 0
    closed_loops: list[bool] = []
    real_new_event_loop = asyncio.new_event_loop

    def _counting_new_loop() -> asyncio.AbstractEventLoop:
        # Counts loop creations, delegating to the real factory.
        nonlocal new_loop_calls
        new_loop_calls += 1
        loop = real_new_event_loop()
        # Hook close() so we can verify each loop was closed properly
        # before the next one was created.
        original_close = loop.close

        def _tracked_close() -> None:
            closed_loops.append(True)
            original_close()

        loop.close = _tracked_close  # type: ignore[method-assign]
        return loop

    async def _body() -> None:
        # Loop is alive and current at body execution time.
        running = asyncio.get_running_loop()
        assert not running.is_closed()

    with (
        _patch_shared_engine(stub),
        patch.object(asyncio, "new_event_loop", _counting_new_loop),
    ):
        for _ in range(3):
            celery_tasks_pkg.run_async_celery_task(_body)

    # Three invocations → three fresh loops, each closed afterwards.
    assert new_loop_calls == 3
    assert closed_loops == [True, True, True]
    # Each invocation disposed twice (before + after).
    assert len(stub.dispose_loop_ids) == 6
def test_runner_disposes_engine_even_when_body_raises() -> None:
    """Cleanup MUST run on the failure path too — otherwise stale
    connections leak into the next task and cause the original hang.
    """
    from app.tasks.celery_tasks import run_async_celery_task

    class _BoomError(RuntimeError):
        pass

    async def _exploding_body() -> None:
        raise _BoomError("kaboom")

    stub = _StaleLoopEngine()
    with _patch_shared_engine(stub), pytest.raises(_BoomError):
        run_async_celery_task(_exploding_body)

    # Both the before-body and the after-body dispose still ran.
    assert len(stub.dispose_loop_ids) == 2
def test_runner_swallows_dispose_errors() -> None:
    """A flaky engine.dispose() must NEVER take down a celery task.

    Production scenario: the very first dispose (before the body runs)
    might hit a partially-initialised engine; the helper logs and
    moves on. The task body still runs; the result is still returned.
    """
    from app.tasks.celery_tasks import run_async_celery_task

    class _AngryEngine:
        """dispose() always blows up; counts how often it was attempted."""

        def __init__(self) -> None:
            self.calls = 0

        async def dispose(self) -> None:
            self.calls += 1
            raise RuntimeError("dispose() blew up")

    angry = _AngryEngine()

    async def _body() -> int:
        return 42

    with _patch_shared_engine(angry):
        result = run_async_celery_task(_body)

    assert result == 42
    # before + after both attempted despite each raising
    assert angry.calls == 2
def test_runner_propagates_value_from_async_body() -> None:
    """Sanity: pass-through of any pickleable celery return value."""
    from app.tasks.celery_tasks import run_async_celery_task

    expected = {"status": "ready", "video_presentation_id": 19}

    async def _body() -> dict[str, object]:
        return dict(expected)

    with _patch_shared_engine(_StaleLoopEngine()):
        assert run_async_celery_task(_body) == expected
def test_video_presentation_task_uses_runner_helper() -> None:
    """Defence-in-depth: confirm the celery task module imports
    ``run_async_celery_task``. If a future refactor inlines a
    ``loop = asyncio.new_event_loop(); ... loop.close()`` block again,
    the original hang will return.
    """
    # The module's task body should not contain a manual new_event_loop
    # call — that's exactly what the helper exists to centralise.
    import inspect

    from app.tasks.celery_tasks import video_presentation_tasks

    # Source-level check: crude, but it survives any internal refactor of
    # the task body as long as the helper stays in use.
    src = inspect.getsource(video_presentation_tasks)
    assert "run_async_celery_task" in src, (
        "video_presentation_tasks.py must use run_async_celery_task; "
        "manual asyncio.new_event_loop() in a celery task hangs on the "
        "shared SQLAlchemy pool when reused across tasks."
    )
    assert "asyncio.new_event_loop" not in src, (
        "video_presentation_tasks.py contains a raw asyncio.new_event_loop "
        "call — route every async task through run_async_celery_task to "
        "avoid the stale-pool hang."
    )
def test_podcast_task_uses_runner_helper() -> None:
    """Symmetric assertion for the podcast task — same root cause, same
    fix, same regression risk.
    """
    import inspect

    from app.tasks.celery_tasks import podcast_tasks

    source = inspect.getsource(podcast_tasks)
    # Helper present, raw loop construction absent — the same contract the
    # video-presentation check enforces.
    assert "run_async_celery_task" in source
    assert "asyncio.new_event_loop" not in source
def test_runner_runs_shutdown_asyncgens_before_close() -> None:
    """If the task body created any async generators that didn't get
    fully iterated, we must still call ``loop.shutdown_asyncgens()``
    before closing — otherwise we leak event-loop bound resources
    that re-emerge as ``RuntimeError: Event loop is closed`` later.
    """
    from app.tasks.celery_tasks import run_async_celery_task

    stub = _StaleLoopEngine()

    async def _agen():
        try:
            yield 1
            yield 2
        finally:
            pass

    async def _body() -> None:
        # Iterate the agen partially, then leave it dangling — exactly
        # the situation shutdown_asyncgens() is designed to clean up.
        async for v in _agen():
            if v == 1:
                break

    with _patch_shared_engine(stub):
        run_async_celery_task(_body)
    # By the time the helper returns, garbage collection + shutdown_asyncgens
    # should have ensured no live async-gen references remain. We don't
    # assert agen.closed directly (it depends on GC ordering); the real
    # contract is "no warnings, no event-loop-closed errors". A successful
    # second invocation proves the loop was cleaned up properly.
    with _patch_shared_engine(stub):
        run_async_celery_task(_body)

    # Force a GC pass to surface any 'coroutine was never awaited'
    # warnings that would indicate the cleanup is broken.
    gc.collect()
def test_runner_uses_proactor_loop_on_windows() -> None:
    """On Windows the celery worker preselects a Proactor policy so
    subprocess (ffmpeg) calls work. The helper must not silently fall
    back to a Selector loop and re-break video/podcast generation.
    """
    if not sys.platform.startswith("win"):
        pytest.skip("Windows-specific event-loop policy assertion")

    from app.tasks.celery_tasks import run_async_celery_task

    stub = _StaleLoopEngine()
    # Mirror the policy set at the top of every Windows celery task.
    asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())

    observed: list[str] = []

    async def _body() -> None:
        # Record the concrete loop class the helper actually created.
        observed.append(type(asyncio.get_running_loop()).__name__)

    with _patch_shared_engine(stub):
        run_async_celery_task(_body)

    assert observed == ["ProactorEventLoop"]

View file

@ -0,0 +1,388 @@
"""Unit tests for podcast Celery task billing integration.
Validates ``_generate_content_podcast`` correctly wraps
``podcaster_graph.ainvoke`` in a ``billable_call`` envelope, propagates the
search-space owner's billing decision, and degrades cleanly when the
resolver fails or premium credit is exhausted.
Coverage:
* Happy-path free config: resolver ``billable_call`` enters with
``usage_type='podcast_generation'`` and the configured reserve override,
graph runs, podcast row flips to ``READY``.
* Happy-path premium config: same wiring with ``billing_tier='premium'``.
* Quota denial: ``billable_call`` raises ``QuotaInsufficientError``
graph is *not* invoked, podcast row flips to ``FAILED``, return dict
carries ``reason='premium_quota_exhausted'``.
* Resolver failure: ``ValueError`` from the resolver podcast row flips
to ``FAILED``, return dict carries ``reason='billing_resolution_failed'``.
"""
from __future__ import annotations
import contextlib
from types import SimpleNamespace
from typing import Any
from uuid import uuid4
import pytest
pytestmark = pytest.mark.unit
# ---------------------------------------------------------------------------
# Fakes
# ---------------------------------------------------------------------------
class _FakeExecResult:
def __init__(self, obj):
self._obj = obj
def scalars(self):
return self
def first(self):
return self._obj
def filter(self, *_args, **_kwargs):
return self
class _FakeSession:
def __init__(self, podcast):
self._podcast = podcast
self.commit_count = 0
async def execute(self, _stmt):
return _FakeExecResult(self._podcast)
async def commit(self):
self.commit_count += 1
async def __aenter__(self):
return self
async def __aexit__(self, *args):
return None
class _FakeSessionMaker:
def __init__(self, session: _FakeSession):
self._session = session
def __call__(self):
return self._session
def _make_podcast(podcast_id: int = 7, thread_id: int = 99) -> SimpleNamespace:
"""Stand-in for a ``Podcast`` row. Importing ``PodcastStatus`` lazily
inside helpers keeps this fixture cheap."""
return SimpleNamespace(
id=podcast_id,
title="Test Podcast",
thread_id=thread_id,
status=None,
podcast_transcript=None,
file_location=None,
)
@contextlib.asynccontextmanager
async def _ok_billable_call(**kwargs):
"""Stand-in for ``billable_call`` that records its kwargs and yields a
no-op accumulator-shaped object."""
_CALL_LOG.append(kwargs)
yield SimpleNamespace()
_CALL_LOG: list[dict[str, Any]] = []
@contextlib.asynccontextmanager
async def _denying_billable_call(**kwargs):
    """``billable_call`` double that refuses the reservation outright."""
    from app.services.billable_calls import QuotaInsufficientError

    _CALL_LOG.append(kwargs)
    denial = QuotaInsufficientError(
        usage_type=kwargs.get("usage_type", "?"),
        used_micros=5_000_000,
        limit_micros=5_000_000,
        remaining_micros=0,
    )
    raise denial
    yield SimpleNamespace()  # pragma: no cover — unreachable; keeps this a generator
@contextlib.asynccontextmanager
async def _settlement_failing_billable_call(**kwargs):
    """``billable_call`` double whose finalize step blows up *after* the
    wrapped body already ran (raises on context exit, not on entry)."""
    from app.services.billable_calls import BillingSettlementError

    _CALL_LOG.append(kwargs)
    yield SimpleNamespace()
    raise BillingSettlementError(
        usage_type=kwargs.get("usage_type", "?"),
        user_id=kwargs["user_id"],
        cause=RuntimeError("finalize failed"),
    )
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
@pytest.fixture(autouse=True)
def _reset_call_log():
    """Isolate each test's view of the shared ``_CALL_LOG`` recorder."""
    del _CALL_LOG[:]
    yield
    del _CALL_LOG[:]
@pytest.mark.asyncio
async def test_billable_call_invoked_with_correct_kwargs_for_free_config(monkeypatch):
    """Happy path: free billing tier still wraps the graph call so the
    audit row is recorded. Verifies kwargs threading."""
    from app.config import config as app_config
    from app.db import PodcastStatus
    from app.tasks.celery_tasks import podcast_tasks

    podcast = _make_podcast(podcast_id=7, thread_id=99)
    session = _FakeSession(podcast)
    monkeypatch.setattr(
        podcast_tasks,
        "get_celery_session_maker",
        lambda: _FakeSessionMaker(session),
    )
    user_id = uuid4()

    # The resolver double also asserts the task forwards the right
    # search-space / thread identifiers.
    async def _fake_resolver(sess, search_space_id, *, thread_id=None):
        assert search_space_id == 555
        assert thread_id == 99
        return user_id, "free", "openrouter/some-free-model"

    monkeypatch.setattr(
        podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver
    )
    monkeypatch.setattr(podcast_tasks, "billable_call", _ok_billable_call)

    # Graph stub returns a minimal two-speaker transcript plus a file path.
    async def _fake_graph_invoke(state, config):
        return {
            "podcast_transcript": [
                SimpleNamespace(speaker_id=0, dialog="Hi"),
                SimpleNamespace(speaker_id=1, dialog="Hello"),
            ],
            "final_podcast_file_path": "/tmp/podcast.wav",
        }

    monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke)
    result = await podcast_tasks._generate_content_podcast(
        podcast_id=7,
        source_content="hello world",
        search_space_id=555,
        user_prompt="make it short",
    )
    assert result["status"] == "ready"
    assert result["podcast_id"] == 7
    assert podcast.status == PodcastStatus.READY
    assert podcast.file_location == "/tmp/podcast.wav"
    # Exactly one billable_call envelope, carrying the resolved billing
    # identity and the podcast-specific reserve override.
    assert len(_CALL_LOG) == 1
    call = _CALL_LOG[0]
    assert call["user_id"] == user_id
    assert call["search_space_id"] == 555
    assert call["billing_tier"] == "free"
    assert call["base_model"] == "openrouter/some-free-model"
    assert call["usage_type"] == "podcast_generation"
    assert (
        call["quota_reserve_micros_override"]
        == app_config.QUOTA_DEFAULT_PODCAST_RESERVE_MICROS
    )
    # Background artifact audit rows intentionally omit the TokenUsage.thread_id
    # FK to avoid coupling Celery audit commits to an active chat transaction.
    assert "thread_id" not in call
    assert call["call_details"] == {
        "podcast_id": 7,
        "title": "Test Podcast",
        "thread_id": 99,
    }
    assert callable(call["billable_session_factory"])
@pytest.mark.asyncio
async def test_billable_call_invoked_with_premium_tier(monkeypatch):
    """Premium resolution flows through to ``billable_call`` so the
    reserve/finalize path triggers."""
    from app.tasks.celery_tasks import podcast_tasks

    podcast = _make_podcast()
    session = _FakeSession(podcast)
    monkeypatch.setattr(
        podcast_tasks,
        "get_celery_session_maker",
        lambda: _FakeSessionMaker(session),
    )
    user_id = uuid4()

    async def _fake_resolver(sess, search_space_id, *, thread_id=None):
        return user_id, "premium", "gpt-5.4"

    monkeypatch.setattr(
        podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver
    )
    monkeypatch.setattr(podcast_tasks, "billable_call", _ok_billable_call)

    async def _fake_graph_invoke(state, config):
        return {"podcast_transcript": [], "final_podcast_file_path": "x.wav"}

    monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke)
    await podcast_tasks._generate_content_podcast(
        podcast_id=7,
        source_content="hi",
        search_space_id=555,
        user_prompt=None,
    )
    # Only tier/model threading matters here; the full kwarg surface is
    # asserted by the free-config test.
    assert _CALL_LOG[0]["billing_tier"] == "premium"
    assert _CALL_LOG[0]["base_model"] == "gpt-5.4"
@pytest.mark.asyncio
async def test_quota_insufficient_marks_podcast_failed_and_skips_graph(monkeypatch):
    """When ``billable_call`` denies the reservation, the graph never
    runs and the podcast row flips to FAILED with the documented reason
    code."""
    from app.db import PodcastStatus
    from app.tasks.celery_tasks import podcast_tasks

    podcast = _make_podcast(podcast_id=8)
    session = _FakeSession(podcast)
    monkeypatch.setattr(
        podcast_tasks,
        "get_celery_session_maker",
        lambda: _FakeSessionMaker(session),
    )

    async def _fake_resolver(sess, search_space_id, *, thread_id=None):
        return uuid4(), "premium", "gpt-5.4"

    monkeypatch.setattr(
        podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver
    )
    # This double raises QuotaInsufficientError before yielding.
    monkeypatch.setattr(podcast_tasks, "billable_call", _denying_billable_call)
    graph_invoked = []

    async def _fake_graph_invoke(state, config):
        graph_invoked.append(True)
        return {}

    monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke)
    result = await podcast_tasks._generate_content_podcast(
        podcast_id=8,
        source_content="hi",
        search_space_id=555,
        user_prompt=None,
    )
    assert result == {
        "status": "failed",
        "podcast_id": 8,
        "reason": "premium_quota_exhausted",
    }
    assert podcast.status == PodcastStatus.FAILED
    assert graph_invoked == []  # Graph never ran on denied reservation.
@pytest.mark.asyncio
async def test_billing_settlement_failure_marks_podcast_failed(monkeypatch):
    """A ``BillingSettlementError`` raised when the ``billable_call``
    context exits (i.e. after the graph already ran) must still flip the
    podcast row to FAILED with ``reason='billing_settlement_failed'``."""
    from app.db import PodcastStatus
    from app.tasks.celery_tasks import podcast_tasks

    podcast = _make_podcast(podcast_id=10)
    session = _FakeSession(podcast)
    monkeypatch.setattr(
        podcast_tasks,
        "get_celery_session_maker",
        lambda: _FakeSessionMaker(session),
    )

    async def _fake_resolver(sess, search_space_id, *, thread_id=None):
        return uuid4(), "premium", "gpt-5.4"

    monkeypatch.setattr(
        podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver
    )
    # This double yields normally, then raises on context exit.
    monkeypatch.setattr(
        podcast_tasks, "billable_call", _settlement_failing_billable_call
    )

    async def _fake_graph_invoke(state, config):
        return {"podcast_transcript": [], "final_podcast_file_path": "x.wav"}

    monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke)
    result = await podcast_tasks._generate_content_podcast(
        podcast_id=10,
        source_content="hi",
        search_space_id=555,
        user_prompt=None,
    )
    assert result == {
        "status": "failed",
        "podcast_id": 10,
        "reason": "billing_settlement_failed",
    }
    assert podcast.status == PodcastStatus.FAILED
@pytest.mark.asyncio
async def test_resolver_failure_marks_podcast_failed(monkeypatch):
    """If the resolver raises (e.g. search-space deleted), the task fails
    cleanly without invoking the graph."""
    from app.db import PodcastStatus
    from app.tasks.celery_tasks import podcast_tasks

    podcast = _make_podcast(podcast_id=9)
    session = _FakeSession(podcast)
    monkeypatch.setattr(
        podcast_tasks,
        "get_celery_session_maker",
        lambda: _FakeSessionMaker(session),
    )

    async def _failing_resolver(sess, search_space_id, *, thread_id=None):
        raise ValueError("Search space 555 not found")

    monkeypatch.setattr(
        podcast_tasks, "_resolve_agent_billing_for_search_space", _failing_resolver
    )
    graph_invoked = []

    async def _fake_graph_invoke(state, config):
        graph_invoked.append(True)
        return {}

    monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke)
    result = await podcast_tasks._generate_content_podcast(
        podcast_id=9,
        source_content="hi",
        search_space_id=555,
        user_prompt=None,
    )
    assert result == {
        "status": "failed",
        "podcast_id": 9,
        "reason": "billing_resolution_failed",
    }
    assert podcast.status == PodcastStatus.FAILED
    assert graph_invoked == []  # Resolver failure short-circuits the graph.

View file

@ -0,0 +1,119 @@
"""Predicate-level test for the chat streaming safety net.
The safety net in ``stream_new_chat`` rejects an image turn early with
a friendly ``MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT`` SSE error when the
selected model is *known* to be text-only. The earlier round of this
work used a strict opt-in flag (``supports_image_input`` defaulting to
False on every YAML entry) which blocked vision-capable Azure GPT-5.x
deployments — this is the regression we're fixing.
The new predicate is :func:`is_known_text_only_chat_model`, which
returns True only when LiteLLM's authoritative model map *explicitly*
sets ``supports_vision=False``. Anything else (vision True, missing
key, exception) returns False so the request flows through to the
provider.
We exercise the predicate directly here rather than driving the full
``stream_new_chat`` generator — covering the gate in isolation keeps
the test focused on the regression while the generator's wider behavior
is exercised by the integration suite.
"""
from __future__ import annotations
import pytest
from app.services.provider_capabilities import is_known_text_only_chat_model
pytestmark = pytest.mark.unit
def test_safety_net_does_not_fire_for_azure_gpt_4o():
    """Regression guard: LiteLLM's model map marks ``azure/gpt-4o`` (and the
    GPT-5.x variants) as vision-capable, so the predicate must not flag it
    as text-only the way the previous blanket-False default did."""
    verdict = is_known_text_only_chat_model(
        provider="AZURE_OPENAI",
        model_name="my-azure-deployment",
        base_model="gpt-4o",
    )
    assert verdict is False
def test_safety_net_does_not_fire_for_unknown_model():
    """Default-pass on unknown: the safety net blocks only definitive
    text-only confirmations, so a model LiteLLM has never heard of must
    flow through to the provider untouched."""
    verdict = is_known_text_only_chat_model(
        provider="CUSTOM",
        custom_provider="brand_new_proxy",
        model_name="brand-new-model-x9",
    )
    assert verdict is False
def test_safety_net_does_not_fire_when_lookup_raises(monkeypatch):
    """A transient ``litellm.get_model_info`` failure must read as
    'unknown' (False), never as a block — the helper swallows the error."""
    import app.services.provider_capabilities as pc

    def _boom(**_kwargs):
        raise RuntimeError("intentional test failure")

    monkeypatch.setattr(pc.litellm, "get_model_info", _boom)
    verdict = is_known_text_only_chat_model(
        provider="OPENAI",
        model_name="gpt-4o",
    )
    assert verdict is False
def test_safety_net_fires_only_on_explicit_false(monkeypatch):
    """Only an explicit ``supports_vision=False`` entry may return True;
    a True value or a missing key must leave the predicate at False."""
    import app.services.provider_capabilities as pc

    # (stubbed get_model_info payload, model name, expected predicate result)
    cases = [
        (
            {"supports_vision": False, "max_input_tokens": 8192},
            "text-only-stub",
            True,
        ),
        ({"supports_vision": True}, "vision-stub", False),
        ({"max_input_tokens": 8192}, "missing-key-stub", False),
    ]
    for payload, model_name, expected in cases:
        monkeypatch.setattr(
            pc.litellm,
            "get_model_info",
            lambda _p=payload, **_kwargs: _p,
        )
        assert (
            is_known_text_only_chat_model(
                provider="OPENAI",
                model_name=model_name,
            )
            is expected
        )

View file

@ -0,0 +1,398 @@
"""Unit tests for video-presentation Celery task billing integration.
Mirrors ``test_podcast_billing.py`` for the video-presentation task.
Validates the same wrap-graph-in-billable_call pattern and ensures the
larger ``QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS`` reservation is
threaded through.
Coverage:
* Free config: graph runs, ``billable_call`` invoked with the video
reserve override.
* Premium config: same wiring with ``billing_tier='premium'``.
* Quota denial: graph not invoked, row FAILED, reason code surfaced.
* Resolver failure: row FAILED with ``billing_resolution_failed``.
"""
from __future__ import annotations
import contextlib
from types import SimpleNamespace
from typing import Any
from uuid import uuid4
import pytest
pytestmark = pytest.mark.unit
# ---------------------------------------------------------------------------
# Fakes
# ---------------------------------------------------------------------------
class _FakeExecResult:
def __init__(self, obj):
self._obj = obj
def scalars(self):
return self
def first(self):
return self._obj
def filter(self, *_args, **_kwargs):
return self
class _FakeSession:
def __init__(self, video):
self._video = video
self.commit_count = 0
async def execute(self, _stmt):
return _FakeExecResult(self._video)
async def commit(self):
self.commit_count += 1
async def __aenter__(self):
return self
async def __aexit__(self, *args):
return None
class _FakeSessionMaker:
def __init__(self, session: _FakeSession):
self._session = session
def __call__(self):
return self._session
def _make_video(video_id: int = 11, thread_id: int = 99) -> SimpleNamespace:
return SimpleNamespace(
id=video_id,
title="Test Presentation",
thread_id=thread_id,
status=None,
slides=None,
scene_codes=None,
)
_CALL_LOG: list[dict[str, Any]] = []
@contextlib.asynccontextmanager
async def _ok_billable_call(**kwargs):
_CALL_LOG.append(kwargs)
yield SimpleNamespace()
@contextlib.asynccontextmanager
async def _denying_billable_call(**kwargs):
    """``billable_call`` double that refuses the reservation outright."""
    from app.services.billable_calls import QuotaInsufficientError

    _CALL_LOG.append(kwargs)
    denial = QuotaInsufficientError(
        usage_type=kwargs.get("usage_type", "?"),
        used_micros=5_000_000,
        limit_micros=5_000_000,
        remaining_micros=0,
    )
    raise denial
    yield SimpleNamespace()  # pragma: no cover — unreachable; keeps this a generator
@contextlib.asynccontextmanager
async def _settlement_failing_billable_call(**kwargs):
    """``billable_call`` double whose finalize step blows up *after* the
    wrapped body already ran (raises on context exit, not on entry)."""
    from app.services.billable_calls import BillingSettlementError

    _CALL_LOG.append(kwargs)
    yield SimpleNamespace()
    raise BillingSettlementError(
        usage_type=kwargs.get("usage_type", "?"),
        user_id=kwargs["user_id"],
        cause=RuntimeError("finalize failed"),
    )
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
@pytest.fixture(autouse=True)
def _reset_call_log():
    """Isolate each test's view of the shared ``_CALL_LOG`` recorder."""
    del _CALL_LOG[:]
    yield
    del _CALL_LOG[:]
@pytest.mark.asyncio
async def test_billable_call_invoked_with_correct_kwargs_for_free_config(monkeypatch):
    """Happy path (free tier): the graph runs, the row flips to READY, and
    ``billable_call`` receives the resolved billing identity plus the
    video-presentation reserve override."""
    from app.config import config as app_config
    from app.db import VideoPresentationStatus
    from app.tasks.celery_tasks import video_presentation_tasks

    video = _make_video(video_id=11, thread_id=99)
    session = _FakeSession(video)
    monkeypatch.setattr(
        video_presentation_tasks,
        "get_celery_session_maker",
        lambda: _FakeSessionMaker(session),
    )
    user_id = uuid4()

    # The resolver double also asserts the task forwards the right
    # search-space / thread identifiers.
    async def _fake_resolver(sess, search_space_id, *, thread_id=None):
        assert search_space_id == 777
        assert thread_id == 99
        return user_id, "free", "openrouter/some-free-model"

    monkeypatch.setattr(
        video_presentation_tasks,
        "_resolve_agent_billing_for_search_space",
        _fake_resolver,
    )
    monkeypatch.setattr(video_presentation_tasks, "billable_call", _ok_billable_call)

    async def _fake_graph_invoke(state, config):
        return {"slides": [], "slide_audio_results": [], "slide_scene_codes": []}

    monkeypatch.setattr(
        video_presentation_tasks.video_presentation_graph,
        "ainvoke",
        _fake_graph_invoke,
    )
    result = await video_presentation_tasks._generate_video_presentation(
        video_presentation_id=11,
        source_content="content",
        search_space_id=777,
        user_prompt=None,
    )
    assert result["status"] == "ready"
    assert result["video_presentation_id"] == 11
    assert video.status == VideoPresentationStatus.READY
    # Exactly one billable_call envelope, carrying the resolved billing
    # identity and the (larger) video reserve override.
    assert len(_CALL_LOG) == 1
    call = _CALL_LOG[0]
    assert call["user_id"] == user_id
    assert call["search_space_id"] == 777
    assert call["billing_tier"] == "free"
    assert call["base_model"] == "openrouter/some-free-model"
    assert call["usage_type"] == "video_presentation_generation"
    assert (
        call["quota_reserve_micros_override"]
        == app_config.QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS
    )
    # Background artifact audit rows intentionally omit the TokenUsage.thread_id
    # FK to avoid coupling Celery audit commits to an active chat transaction.
    assert "thread_id" not in call
    assert call["call_details"] == {
        "video_presentation_id": 11,
        "title": "Test Presentation",
        "thread_id": 99,
    }
    assert callable(call["billable_session_factory"])
@pytest.mark.asyncio
async def test_billable_call_invoked_with_premium_tier(monkeypatch):
    """Premium resolution threads ``billing_tier``/``base_model`` through
    to ``billable_call`` unchanged."""
    from app.tasks.celery_tasks import video_presentation_tasks

    video = _make_video()
    session = _FakeSession(video)
    monkeypatch.setattr(
        video_presentation_tasks,
        "get_celery_session_maker",
        lambda: _FakeSessionMaker(session),
    )
    user_id = uuid4()

    async def _fake_resolver(sess, search_space_id, *, thread_id=None):
        return user_id, "premium", "gpt-5.4"

    monkeypatch.setattr(
        video_presentation_tasks,
        "_resolve_agent_billing_for_search_space",
        _fake_resolver,
    )
    monkeypatch.setattr(video_presentation_tasks, "billable_call", _ok_billable_call)

    async def _fake_graph_invoke(state, config):
        return {"slides": [], "slide_audio_results": [], "slide_scene_codes": []}

    monkeypatch.setattr(
        video_presentation_tasks.video_presentation_graph,
        "ainvoke",
        _fake_graph_invoke,
    )
    await video_presentation_tasks._generate_video_presentation(
        video_presentation_id=11,
        source_content="content",
        search_space_id=777,
        user_prompt=None,
    )
    # Only tier/model threading matters here; the full kwarg surface is
    # asserted by the free-config test.
    assert _CALL_LOG[0]["billing_tier"] == "premium"
    assert _CALL_LOG[0]["base_model"] == "gpt-5.4"
@pytest.mark.asyncio
async def test_quota_insufficient_marks_video_failed_and_skips_graph(monkeypatch):
    """A denied reservation short-circuits the task: the graph is never
    invoked and the row flips to FAILED with the quota reason code."""
    from app.db import VideoPresentationStatus
    from app.tasks.celery_tasks import video_presentation_tasks

    video = _make_video(video_id=12)
    session = _FakeSession(video)
    monkeypatch.setattr(
        video_presentation_tasks,
        "get_celery_session_maker",
        lambda: _FakeSessionMaker(session),
    )

    async def _fake_resolver(sess, search_space_id, *, thread_id=None):
        return uuid4(), "premium", "gpt-5.4"

    monkeypatch.setattr(
        video_presentation_tasks,
        "_resolve_agent_billing_for_search_space",
        _fake_resolver,
    )
    # This double raises QuotaInsufficientError before yielding.
    monkeypatch.setattr(
        video_presentation_tasks, "billable_call", _denying_billable_call
    )
    graph_invoked = []

    async def _fake_graph_invoke(state, config):
        graph_invoked.append(True)
        return {}

    monkeypatch.setattr(
        video_presentation_tasks.video_presentation_graph,
        "ainvoke",
        _fake_graph_invoke,
    )
    result = await video_presentation_tasks._generate_video_presentation(
        video_presentation_id=12,
        source_content="content",
        search_space_id=777,
        user_prompt=None,
    )
    assert result == {
        "status": "failed",
        "video_presentation_id": 12,
        "reason": "premium_quota_exhausted",
    }
    assert video.status == VideoPresentationStatus.FAILED
    assert graph_invoked == []  # Graph never ran on denied reservation.
@pytest.mark.asyncio
async def test_billing_settlement_failure_marks_video_failed(monkeypatch):
    """A ``BillingSettlementError`` raised when the ``billable_call``
    context exits (i.e. after the graph already ran) must still flip the
    row to FAILED with ``reason='billing_settlement_failed'``."""
    from app.db import VideoPresentationStatus
    from app.tasks.celery_tasks import video_presentation_tasks

    video = _make_video(video_id=14)
    session = _FakeSession(video)
    monkeypatch.setattr(
        video_presentation_tasks,
        "get_celery_session_maker",
        lambda: _FakeSessionMaker(session),
    )

    async def _fake_resolver(sess, search_space_id, *, thread_id=None):
        return uuid4(), "premium", "gpt-5.4"

    monkeypatch.setattr(
        video_presentation_tasks,
        "_resolve_agent_billing_for_search_space",
        _fake_resolver,
    )
    # This double yields normally, then raises on context exit.
    monkeypatch.setattr(
        video_presentation_tasks,
        "billable_call",
        _settlement_failing_billable_call,
    )

    async def _fake_graph_invoke(state, config):
        return {"slides": [], "slide_audio_results": [], "slide_scene_codes": []}

    monkeypatch.setattr(
        video_presentation_tasks.video_presentation_graph,
        "ainvoke",
        _fake_graph_invoke,
    )
    result = await video_presentation_tasks._generate_video_presentation(
        video_presentation_id=14,
        source_content="content",
        search_space_id=777,
        user_prompt=None,
    )
    assert result == {
        "status": "failed",
        "video_presentation_id": 14,
        "reason": "billing_settlement_failed",
    }
    assert video.status == VideoPresentationStatus.FAILED
@pytest.mark.asyncio
async def test_resolver_failure_marks_video_failed(monkeypatch):
    """If the billing resolver raises (e.g. search-space deleted), the task
    fails cleanly without invoking the graph."""
    from app.db import VideoPresentationStatus
    from app.tasks.celery_tasks import video_presentation_tasks

    video = _make_video(video_id=13)
    session = _FakeSession(video)
    monkeypatch.setattr(
        video_presentation_tasks,
        "get_celery_session_maker",
        lambda: _FakeSessionMaker(session),
    )

    async def _failing_resolver(sess, search_space_id, *, thread_id=None):
        raise ValueError("Search space 777 not found")

    monkeypatch.setattr(
        video_presentation_tasks,
        "_resolve_agent_billing_for_search_space",
        _failing_resolver,
    )
    graph_invoked = []

    async def _fake_graph_invoke(state, config):
        graph_invoked.append(True)
        return {}

    monkeypatch.setattr(
        video_presentation_tasks.video_presentation_graph,
        "ainvoke",
        _fake_graph_invoke,
    )
    result = await video_presentation_tasks._generate_video_presentation(
        video_presentation_id=13,
        source_content="content",
        search_space_id=777,
        user_prompt=None,
    )
    assert result == {
        "status": "failed",
        "video_presentation_id": 13,
        "reason": "billing_resolution_failed",
    }
    assert video.status == VideoPresentationStatus.FAILED
    assert graph_invoked == []  # Resolver failure short-circuits the graph.

View file

@ -271,6 +271,66 @@ async def test_preflight_skipped_for_auto_router_model():
await _preflight_llm(fake_llm)
@pytest.mark.asyncio
async def test_settle_speculative_agent_build_swallows_exceptions():
    """``_settle_speculative_agent_build`` must always return cleanly.

    The helper is awaited purely to release any in-flight ``AsyncSession``
    usage of a speculative build being discarded; its outcome is
    irrelevant. Crashing, succeeding, and still-running builds must all
    settle without an exception escaping to the caller.
    """
    import asyncio

    from app.tasks.chat.stream_new_chat import _settle_speculative_agent_build

    async def _crashing() -> None:
        raise RuntimeError("speculative build crashed")

    async def _returning() -> str:
        return "agent"

    async def _dawdling() -> None:
        await asyncio.sleep(0.05)

    for factory in (_crashing, _returning, _dawdling):
        pending = asyncio.create_task(factory())
        await _settle_speculative_agent_build(pending)
        assert pending.done()
@pytest.mark.asyncio
async def test_settle_speculative_agent_build_handles_already_done_task():
    """Settling a task that already finished (either way) must not raise."""
    import asyncio

    from app.tasks.chat.stream_new_chat import _settle_speculative_agent_build

    async def _fine() -> str:
        return "ok"

    async def _broken() -> None:
        raise ValueError("nope")

    done_ok = asyncio.create_task(_fine())
    done_bad = asyncio.create_task(_broken())
    # Two scheduler passes drive both tasks to completion before settling.
    await asyncio.sleep(0)
    await asyncio.sleep(0)
    await _settle_speculative_agent_build(done_ok)
    await _settle_speculative_agent_build(done_bad)
    assert done_ok.result() == "ok"
    # The settle helper observed the failure without clearing it, so the
    # original exception is still retrievable afterwards.
    assert isinstance(done_bad.exception(), ValueError)
def test_stream_exception_classifies_thread_busy():
exc = BusyError(request_id="thread-123")
kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(