mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-08 07:12:39 +02:00
Merge remote-tracking branch 'upstream/dev' into fix/memory-extraction
This commit is contained in:
commit
b981b51ab1
176 changed files with 20407 additions and 6258 deletions
268
surfsense_backend/tests/unit/agents/new_chat/test_agent_cache.py
Normal file
268
surfsense_backend/tests/unit/agents/new_chat/test_agent_cache.py
Normal file
|
|
@ -0,0 +1,268 @@
|
|||
"""Regression tests for the compiled-agent cache.
|
||||
|
||||
Covers the cache primitive itself (TTL, LRU, in-flight de-duplication,
|
||||
build-failure non-caching) and the cache-key signature helpers that
|
||||
``create_surfsense_deep_agent`` relies on. The integration with
|
||||
``create_surfsense_deep_agent`` is covered separately by the streaming
|
||||
contract tests; this module focuses on the primitives so a regression
|
||||
in the cache implementation is caught before it reaches the agent
|
||||
factory.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.new_chat.agent_cache import (
|
||||
flags_signature,
|
||||
reload_for_tests,
|
||||
stable_hash,
|
||||
system_prompt_hash,
|
||||
tools_signature,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# stable_hash + signature helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_stable_hash_is_deterministic_across_calls() -> None:
    """Hashing the same argument tuple twice must yield identical digests."""
    first = stable_hash("v1", 42, "thread-9", None, ["x", "y"])
    second = stable_hash("v1", 42, "thread-9", None, ["x", "y"])
    assert first == second
|
||||
|
||||
|
||||
def test_stable_hash_changes_when_any_part_changes() -> None:
    """Perturbing any single component of the key must rotate the digest."""
    reference = stable_hash("v1", 42, "thread-9")
    variants = (
        stable_hash("v1", 42, "thread-10"),  # thread id changed
        stable_hash("v2", 42, "thread-9"),  # version changed
        stable_hash("v1", 43, "thread-9"),  # numeric part changed
    )
    for variant in variants:
        assert variant != reference
|
||||
|
||||
|
||||
def test_tools_signature_keys_on_name_and_description_not_identity() -> None:
    """The signature must key on the tool *surface*, not object identity.

    Fresh requests construct fresh ``BaseTool`` instances every time, so
    hashing on ``(name, description)`` — rather than on the Python objects
    themselves — is what keeps the cache hot across requests whose tool
    surfaces are identical.
    """

    @dataclass
    class FakeTool:
        name: str
        description: str

    forward = [FakeTool("alpha", "does alpha"), FakeTool("beta", "does beta")]
    backward = [FakeTool("beta", "does beta"), FakeTool("alpha", "does alpha")]
    sig_forward = tools_signature(
        forward, available_connectors=["NOTION"], available_document_types=["FILE"]
    )
    sig_backward = tools_signature(
        backward, available_connectors=["NOTION"], available_document_types=["FILE"]
    )
    assert sig_forward == sig_backward, "tool order must not affect the signature"

    # Adding a tool rotates the key.
    extended = [*forward, FakeTool("gamma", "does gamma")]
    sig_extended = tools_signature(
        extended, available_connectors=["NOTION"], available_document_types=["FILE"]
    )
    assert sig_extended != sig_forward
|
||||
|
||||
|
||||
def test_tools_signature_rotates_when_connector_set_changes() -> None:
    """The connector list is part of the signature: growing it must rotate it."""

    @dataclass
    class FakeTool:
        name: str
        description: str

    tool_list = [FakeTool("a", "x")]
    one_connector = tools_signature(
        tool_list, available_connectors=["NOTION"], available_document_types=["FILE"]
    )
    two_connectors = tools_signature(
        tool_list,
        available_connectors=["NOTION", "SLACK"],
        available_document_types=["FILE"],
    )
    assert one_connector != two_connectors, (
        "adding a connector must rotate the cache key"
    )
|
||||
|
||||
|
||||
def test_flags_signature_changes_when_flag_flips() -> None:
    """Flipping a single boolean flag must change the flags signature."""

    @dataclass(frozen=True)
    class Flags:
        a: bool = True
        b: bool = False

    default_sig = flags_signature(Flags())
    flipped_sig = flags_signature(Flags(b=True))
    assert default_sig != flipped_sig
|
||||
|
||||
|
||||
def test_system_prompt_hash_is_stable_and_distinct() -> None:
    """Equal prompts hash equal; even a one-character delta hashes differently."""
    prompt = "You are a helpful assistant."
    prompt_variant = "You are a helpful assistant!"  # one-character delta
    assert system_prompt_hash(prompt) == system_prompt_hash(prompt)
    assert system_prompt_hash(prompt) != system_prompt_hash(prompt_variant)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _AgentCache: hit / miss / TTL / LRU / coalescing / failure-not-cached
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_cache_hit_returns_same_instance_on_second_call() -> None:
    """A warm key returns the cached object without re-running the builder."""
    cache = reload_for_tests(maxsize=8, ttl_seconds=60.0)
    build_count = 0

    async def builder() -> object:
        nonlocal build_count
        build_count += 1
        return object()

    first = await cache.get_or_build("k", builder=builder)
    second = await cache.get_or_build("k", builder=builder)
    assert first is second, "cache must return the SAME object across hits"
    assert build_count == 1, "builder must run exactly once"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_cache_different_keys_get_different_instances() -> None:
    """Distinct keys must never alias to the same cached object."""
    cache = reload_for_tests(maxsize=8, ttl_seconds=60.0)

    async def builder() -> object:
        return object()

    value_one = await cache.get_or_build("k1", builder=builder)
    value_two = await cache.get_or_build("k2", builder=builder)
    assert value_one is not value_two
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_cache_stale_entries_get_rebuilt() -> None:
    """With ``ttl=0`` every read observes the entry as already expired."""
    cache = reload_for_tests(maxsize=8, ttl_seconds=0.0)
    build_count = 0

    async def builder() -> object:
        nonlocal build_count
        build_count += 1
        return object()

    first = await cache.get_or_build("k", builder=builder)
    second = await cache.get_or_build("k", builder=builder)
    assert first is not second, "stale entry must rebuild a fresh instance"
    assert build_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_cache_evicts_lru_when_full() -> None:
    """Filling the cache past ``maxsize`` must evict the least-recently-used key.

    Fix: the original version's trailing comment promised to confirm that
    "'b' is gone (rebuild)", but the code never re-fetched "b" — the claim
    went unverified. The final block now actually re-fetches "b" and
    asserts a fresh object is built.
    """
    cache = reload_for_tests(maxsize=2, ttl_seconds=60.0)

    async def builder() -> object:
        return object()

    a = await cache.get_or_build("a", builder=builder)
    b = await cache.get_or_build("b", builder=builder)
    # Re-touch "a" so "b" is now the LRU victim.
    a_again = await cache.get_or_build("a", builder=builder)
    assert a_again is a
    # Inserting "c" should evict "b" (LRU), not "a".
    _ = await cache.get_or_build("c", builder=builder)
    assert cache.stats()["size"] == 2

    # Confirm "a" is still hot (no rebuild) ...
    a_hit = await cache.get_or_build("a", builder=builder)
    assert a_hit is a, "LRU must keep the most-recently-used 'a' entry"
    # ... and "b" really was evicted: the next lookup must build a fresh
    # object rather than return the original instance. ("a" was just
    # touched, so this insert evicts "c", leaving "a" intact.)
    b_rebuilt = await cache.get_or_build("b", builder=builder)
    assert b_rebuilt is not b, "evicted 'b' must be rebuilt on next access"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_cache_concurrent_misses_coalesce_to_single_build() -> None:
    """Two concurrent get_or_build calls on the same key must share one builder."""
    cache = reload_for_tests(maxsize=8, ttl_seconds=60.0)
    build_started = asyncio.Event()
    build_count = 0

    async def slow_builder() -> object:
        nonlocal build_count
        build_count += 1
        build_started.set()
        # Yield control so the second waiter can race against us.
        await asyncio.sleep(0.05)
        return object()

    first_task = asyncio.create_task(cache.get_or_build("k", builder=slow_builder))
    # Wait until the first builder has started, then race a second waiter.
    await build_started.wait()
    second_task = asyncio.create_task(cache.get_or_build("k", builder=slow_builder))

    first, second = await asyncio.gather(first_task, second_task)
    assert first is second, "coalesced waiters must observe the same value"
    assert build_count == 1, "concurrent cold misses must collapse to ONE build"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_cache_does_not_store_failed_builds() -> None:
    """A builder that raises must NOT poison the cache.

    The next caller for the same key re-runs the builder instead of
    re-raising a cached exception.
    """
    cache = reload_for_tests(maxsize=8, ttl_seconds=60.0)
    attempt_count = 0

    async def flaky_builder() -> object:
        nonlocal attempt_count
        attempt_count += 1
        if attempt_count == 1:
            raise RuntimeError("transient")
        return object()

    with pytest.raises(RuntimeError, match="transient"):
        await cache.get_or_build("k", builder=flaky_builder)

    # Second call must retry — not re-raise the cached exception.
    rebuilt = await cache.get_or_build("k", builder=flaky_builder)
    assert rebuilt is not None
    assert attempt_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_cache_invalidate_drops_entry() -> None:
    """``invalidate`` removes the entry so the next lookup rebuilds."""
    cache = reload_for_tests(maxsize=8, ttl_seconds=60.0)

    async def builder() -> object:
        return object()

    before = await cache.get_or_build("k", builder=builder)
    assert cache.invalidate("k") is True
    after = await cache.get_or_build("k", builder=builder)
    assert before is not after, "post-invalidation lookup must rebuild"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_cache_invalidate_prefix_drops_matching_entries() -> None:
    """``invalidate_prefix`` drops every key under the prefix and nothing else.

    Fix: the original's final comment claimed the ``user:2`` entry "must
    still be hot (no rebuild)" but only asserted the value was not None —
    a rebuilt object would also pass. The survivor's identity is now
    captured at setup and pinned after the prefix invalidation.
    """
    cache = reload_for_tests(maxsize=16, ttl_seconds=60.0)

    async def builder() -> object:
        return object()

    await cache.get_or_build("user:1:thread:1", builder=builder)
    await cache.get_or_build("user:1:thread:2", builder=builder)
    survivor = await cache.get_or_build("user:2:thread:1", builder=builder)

    removed = cache.invalidate_prefix("user:1:")
    assert removed == 2
    assert cache.stats()["size"] == 1

    # The user:2 entry must still be hot (no rebuild).
    survivor_value = await cache.get_or_build("user:2:thread:1", builder=builder)
    assert survivor_value is not None
    assert survivor_value is survivor, (
        "non-matching entry must survive prefix invalidation without a rebuild"
    )
|
||||
|
|
@ -31,18 +31,45 @@ def _clear_all(monkeypatch: pytest.MonkeyPatch) -> None:
|
|||
"SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE",
|
||||
"SURFSENSE_ENABLE_ACTION_LOG",
|
||||
"SURFSENSE_ENABLE_REVERT_ROUTE",
|
||||
"SURFSENSE_ENABLE_STREAM_PARITY_V2",
|
||||
"SURFSENSE_ENABLE_PLUGIN_LOADER",
|
||||
"SURFSENSE_ENABLE_OTEL",
|
||||
"SURFSENSE_ENABLE_AGENT_CACHE",
|
||||
"SURFSENSE_ENABLE_AGENT_CACHE_SHARE_GP_SUBAGENT",
|
||||
]:
|
||||
monkeypatch.delenv(name, raising=False)
|
||||
|
||||
|
||||
def test_defaults_all_off(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_defaults_match_shipped_agent_stack(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
_clear_all(monkeypatch)
|
||||
flags = reload_for_tests()
|
||||
assert isinstance(flags, AgentFeatureFlags)
|
||||
assert flags.disable_new_agent_stack is False
|
||||
assert flags.any_new_middleware_enabled() is False
|
||||
assert flags.enable_context_editing is True
|
||||
assert flags.enable_compaction_v2 is True
|
||||
assert flags.enable_retry_after is True
|
||||
assert flags.enable_model_fallback is False
|
||||
assert flags.enable_model_call_limit is True
|
||||
assert flags.enable_tool_call_limit is True
|
||||
assert flags.enable_tool_call_repair is True
|
||||
assert flags.enable_doom_loop is True
|
||||
assert flags.enable_permission is True
|
||||
assert flags.enable_busy_mutex is True
|
||||
assert flags.enable_llm_tool_selector is False
|
||||
assert flags.enable_skills is True
|
||||
assert flags.enable_specialized_subagents is True
|
||||
assert flags.enable_kb_planner_runnable is True
|
||||
assert flags.enable_action_log is True
|
||||
assert flags.enable_revert_route is True
|
||||
assert flags.enable_stream_parity_v2 is True
|
||||
assert flags.enable_plugin_loader is False
|
||||
assert flags.enable_otel is False
|
||||
# Phase 2: agent cache is now default-on (the prerequisite tool
|
||||
# ``db_session`` refactor landed). The companion gp-subagent share
|
||||
# flag stays default-off pending data on cold-miss frequency.
|
||||
assert flags.enable_agent_cache is True
|
||||
assert flags.enable_agent_cache_share_gp_subagent is False
|
||||
assert flags.any_new_middleware_enabled() is True
|
||||
|
||||
|
||||
def test_master_kill_switch_overrides_individual_flags(
|
||||
|
|
@ -100,21 +127,13 @@ def test_each_flag_can_be_set_independently(monkeypatch: pytest.MonkeyPatch) ->
|
|||
"enable_kb_planner_runnable": "SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE",
|
||||
"enable_action_log": "SURFSENSE_ENABLE_ACTION_LOG",
|
||||
"enable_revert_route": "SURFSENSE_ENABLE_REVERT_ROUTE",
|
||||
"enable_stream_parity_v2": "SURFSENSE_ENABLE_STREAM_PARITY_V2",
|
||||
"enable_plugin_loader": "SURFSENSE_ENABLE_PLUGIN_LOADER",
|
||||
"enable_otel": "SURFSENSE_ENABLE_OTEL",
|
||||
}
|
||||
|
||||
# `enable_otel` is intentionally orthogonal — it does NOT count toward
|
||||
# ``any_new_middleware_enabled`` because OTel is observability-only and
|
||||
# ships under its own ``OTEL_EXPORTER_OTLP_ENDPOINT`` requirement.
|
||||
counts_toward_middleware = {k for k in flag_to_env if k != "enable_otel"}
|
||||
|
||||
for attr, env_name in flag_to_env.items():
|
||||
_clear_all(monkeypatch)
|
||||
monkeypatch.setenv(env_name, "true")
|
||||
monkeypatch.setenv(env_name, "false")
|
||||
flags = reload_for_tests()
|
||||
assert getattr(flags, attr) is True, f"{attr} did not flip on for {env_name}"
|
||||
if attr in counts_toward_middleware:
|
||||
assert flags.any_new_middleware_enabled() is True
|
||||
else:
|
||||
assert flags.any_new_middleware_enabled() is False
|
||||
assert getattr(flags, attr) is False, f"{attr} did not flip off for {env_name}"
|
||||
|
|
|
|||
|
|
@ -0,0 +1,344 @@
|
|||
"""Tests for ``FlattenSystemMessageMiddleware``.
|
||||
|
||||
The middleware exists to defend against Anthropic's "Found 5 cache_control
|
||||
blocks" 400 when our deepagent middleware stack stacks 5+ text blocks on
|
||||
the system message and the OpenRouter→Anthropic adapter redistributes
|
||||
``cache_control`` across all of them. The flattening collapses every
|
||||
all-text system content list to a single string before the LLM call.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
|
||||
from app.agents.new_chat.middleware.flatten_system import (
|
||||
FlattenSystemMessageMiddleware,
|
||||
_flatten_text_blocks,
|
||||
_flattened_request,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _flatten_text_blocks — pure helper, the heart of the middleware.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFlattenTextBlocks:
    """Unit tests for the pure text-block flattening helper."""

    def test_joins_text_blocks_with_double_newline(self) -> None:
        content = [
            {"type": "text", "text": "<surfsense base>"},
            {"type": "text", "text": "<filesystem section>"},
            {"type": "text", "text": "<skills section>"},
        ]
        expected = "<surfsense base>\n\n<filesystem section>\n\n<skills section>"
        assert _flatten_text_blocks(content) == expected

    def test_handles_single_text_block(self) -> None:
        single = [{"type": "text", "text": "only one"}]
        assert _flatten_text_blocks(single) == "only one"

    def test_handles_empty_list(self) -> None:
        assert _flatten_text_blocks([]) == ""

    def test_passes_through_bare_string_blocks(self) -> None:
        # LangChain content can mix bare strings and dict blocks.
        mixed = ["raw string", {"type": "text", "text": "dict block"}]
        assert _flatten_text_blocks(mixed) == "raw string\n\ndict block"

    def test_returns_none_for_image_block(self) -> None:
        # System messages with images are rare — but we never want to
        # silently lose the image payload by joining as text.
        content = [
            {"type": "text", "text": "look at this"},
            {"type": "image_url", "image_url": {"url": "data:image/png..."}},
        ]
        assert _flatten_text_blocks(content) is None

    def test_returns_none_for_non_dict_non_str_block(self) -> None:
        content = [{"type": "text", "text": "hi"}, 42]  # type: ignore[list-item]
        assert _flatten_text_blocks(content) is None

    def test_returns_none_when_text_field_missing(self) -> None:
        content = [{"type": "text"}]  # no ``text`` key
        assert _flatten_text_blocks(content) is None

    def test_returns_none_when_text_is_not_string(self) -> None:
        content = [{"type": "text", "text": ["nested", "list"]}]
        assert _flatten_text_blocks(content) is None

    def test_drops_cache_control_from_inner_blocks(self) -> None:
        # The whole point: existing cache_control on inner blocks is
        # discarded so LiteLLM's ``cache_control_injection_points`` can
        # re-attach exactly one breakpoint after flattening.
        content = [
            {"type": "text", "text": "first"},
            {
                "type": "text",
                "text": "second",
                "cache_control": {"type": "ephemeral"},
            },
        ]
        joined = _flatten_text_blocks(content)
        assert joined == "first\n\nsecond"
        assert "cache_control" not in joined  # type: ignore[operator]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _flattened_request — decides when to override and when to no-op.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_request(system_message: SystemMessage | None) -> Any:
    """Build a minimal ModelRequest stub.

    Only ``.system_message`` and ``.override(system_message=...)`` are
    needed — the middleware never touches other fields.
    """
    stub = MagicMock()
    stub.system_message = system_message

    def override(**kwargs: Any) -> Any:
        # Mirror the real ``override`` contract: any field not supplied
        # falls back to the current request's value.
        replacement = MagicMock()
        replacement.system_message = kwargs.get("system_message", stub.system_message)
        replacement.messages = kwargs.get("messages", getattr(stub, "messages", []))
        replacement.tools = kwargs.get("tools", getattr(stub, "tools", []))
        return replacement

    stub.override = override
    return stub
|
||||
|
||||
|
||||
class TestFlattenedRequest:
    """Exercise the decision logic: when to override and when to no-op."""

    def test_collapses_multi_block_system_to_string(self) -> None:
        system = SystemMessage(
            content=[
                {"type": "text", "text": "<base>"},
                {"type": "text", "text": "<todo>"},
                {"type": "text", "text": "<filesystem>"},
                {"type": "text", "text": "<skills>"},
                {"type": "text", "text": "<subagents>"},
            ]
        )
        result = _flattened_request(_make_request(system))

        assert result is not None
        assert isinstance(result.system_message, SystemMessage)
        assert result.system_message.content == (
            "<base>\n\n<todo>\n\n<filesystem>\n\n<skills>\n\n<subagents>"
        )

    def test_no_op_for_string_content(self) -> None:
        request = _make_request(SystemMessage(content="already a string"))
        assert _flattened_request(request) is None

    def test_no_op_for_single_block_list(self) -> None:
        # One block already produces one breakpoint — no need to flatten.
        request = _make_request(
            SystemMessage(content=[{"type": "text", "text": "single"}])
        )
        assert _flattened_request(request) is None

    def test_no_op_when_system_message_missing(self) -> None:
        assert _flattened_request(_make_request(None)) is None

    def test_no_op_when_list_contains_non_text_block(self) -> None:
        system = SystemMessage(
            content=[
                {"type": "text", "text": "look"},
                {"type": "image_url", "image_url": {"url": "data:..."}},
            ]
        )
        assert _flattened_request(_make_request(system)) is None

    def test_preserves_additional_kwargs_and_metadata(self) -> None:
        # Defensive: nothing in the current chain sets these on a system
        # message, but losing them silently when something does in the
        # future would be a regression. ``name`` in particular is the only
        # ``additional_kwargs`` field that ChatLiteLLM's
        # ``_convert_message_to_dict`` propagates onto the wire.
        system = SystemMessage(
            content=[
                {"type": "text", "text": "a"},
                {"type": "text", "text": "b"},
            ],
            additional_kwargs={"name": "surfsense_system", "x": 1},
            response_metadata={"tokens": 42},
        )
        system.id = "sys-msg-1"

        result = _flattened_request(_make_request(system))
        assert result is not None
        assert result.system_message.content == "a\n\nb"
        assert result.system_message.additional_kwargs == {
            "name": "surfsense_system",
            "x": 1,
        }
        assert result.system_message.response_metadata == {"tokens": 42}
        assert result.system_message.id == "sys-msg-1"

    def test_idempotent_when_run_twice(self) -> None:
        system = SystemMessage(
            content=[
                {"type": "text", "text": "a"},
                {"type": "text", "text": "b"},
            ]
        )
        first_pass = _flattened_request(_make_request(system))
        assert first_pass is not None

        # A second pass over the already-flattened request must be a
        # no-op. Re-wrap in a fresh stub since the helper inspects
        # ``request.system_message.content``.
        rewrapped = _make_request(first_pass.system_message)
        assert _flattened_request(rewrapped) is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Middleware integration — verify the handler sees a flattened request.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMiddlewareWrap:
    """Verify both wrap hooks hand the handler a flattened request."""

    @pytest.mark.asyncio
    async def test_async_passes_flattened_request_to_handler(self) -> None:
        system = SystemMessage(
            content=[
                {"type": "text", "text": "alpha"},
                {"type": "text", "text": "beta"},
            ]
        )
        request = _make_request(system)
        seen: dict[str, Any] = {}

        async def handler(req: Any) -> str:
            seen["request"] = req
            return "ok"

        middleware = FlattenSystemMessageMiddleware()
        result = await middleware.awrap_model_call(request, handler)

        assert result == "ok"
        assert isinstance(seen["request"].system_message, SystemMessage)
        assert seen["request"].system_message.content == "alpha\n\nbeta"

    @pytest.mark.asyncio
    async def test_async_passes_through_when_already_string(self) -> None:
        request = _make_request(SystemMessage(content="just a string"))
        seen: dict[str, Any] = {}

        async def handler(req: Any) -> str:
            seen["request"] = req
            return "ok"

        middleware = FlattenSystemMessageMiddleware()
        await middleware.awrap_model_call(request, handler)

        # Same request object: no override happened.
        assert seen["request"] is request

    def test_sync_passes_flattened_request_to_handler(self) -> None:
        system = SystemMessage(
            content=[
                {"type": "text", "text": "alpha"},
                {"type": "text", "text": "beta"},
            ]
        )
        request = _make_request(system)
        seen: dict[str, Any] = {}

        def handler(req: Any) -> str:
            seen["request"] = req
            return "ok"

        middleware = FlattenSystemMessageMiddleware()
        result = middleware.wrap_model_call(request, handler)

        assert result == "ok"
        assert seen["request"].system_message.content == "alpha\n\nbeta"

    def test_sync_passes_through_when_no_system_message(self) -> None:
        request = _make_request(None)
        seen: dict[str, Any] = {}

        def handler(req: Any) -> str:
            seen["request"] = req
            return "ok"

        middleware = FlattenSystemMessageMiddleware()
        middleware.wrap_model_call(request, handler)
        assert seen["request"] is request
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Regression guard — pin the worst-case shape that triggered the
|
||||
# "Found 5" 400 in production. Confirms we collapse 5 blocks to 1 so the
|
||||
# downstream cache_control_injection_points can only place 1 breakpoint
|
||||
# on the system message regardless of provider redistribution quirks.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_regression_five_block_system_collapses_to_one_block() -> None:
    """Pin the worst-case production shape: five stacked system text blocks.

    The exact join format doesn't matter for the cache_control accounting —
    only that exactly ONE content block remains when LiteLLM's
    AnthropicCacheControlHook later targets ``role: system``.
    """
    system = SystemMessage(
        content=[
            {"type": "text", "text": "<surfsense base + BASE_AGENT_PROMPT>"},
            {"type": "text", "text": "<TodoListMiddleware section>"},
            {"type": "text", "text": "<SurfSenseFilesystemMiddleware section>"},
            {"type": "text", "text": "<SkillsMiddleware section>"},
            {"type": "text", "text": "<SubAgentMiddleware section>"},
        ]
    )
    result = _flattened_request(_make_request(system))

    assert result is not None
    assert isinstance(result.system_message.content, str)
    assert "<surfsense base" in result.system_message.content
    assert "<SubAgentMiddleware" in result.system_message.content
|
||||
|
||||
|
||||
def test_regression_human_message_not_modified() -> None:
    """The middleware must only touch the system message, never user turns.

    Multi-block user content is the path that carries image attachments
    and would lose its image_url block on an accidental flatten.
    """
    system = SystemMessage(
        content=[
            {"type": "text", "text": "a"},
            {"type": "text", "text": "b"},
        ]
    )
    user = HumanMessage(
        content=[
            {"type": "text", "text": "look at this"},
            {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
        ]
    )
    request = _make_request(system)
    request.messages = [user]

    result = _flattened_request(request)
    assert result is not None
    # System flattened to string …
    assert isinstance(result.system_message.content, str)
    # … user message is untouched (the helper does not even look at it).
    assert result.messages == [user]
    assert isinstance(user.content, list)
    assert len(user.content) == 2
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
"""Tests for ``apply_litellm_prompt_caching`` in
|
||||
r"""Tests for ``apply_litellm_prompt_caching`` in
|
||||
:mod:`app.agents.new_chat.prompt_caching`.
|
||||
|
||||
The helper replaces the legacy ``AnthropicPromptCachingMiddleware`` (which
|
||||
|
|
@ -6,9 +6,12 @@ never activated for our LiteLLM stack) with LiteLLM-native multi-provider
|
|||
prompt caching. It mutates ``llm.model_kwargs`` so the kwargs flow to
|
||||
``litellm.completion(...)``. The tests below pin its public contract:
|
||||
|
||||
1. Always sets BOTH ``role: system`` and ``index: -1`` injection points so
|
||||
1. Always sets BOTH ``index: 0`` and ``index: -1`` injection points so
|
||||
savings compound across multi-turn conversations on Anthropic-family
|
||||
providers.
|
||||
providers. ``index: 0`` is used (rather than ``role: system``) because
|
||||
the deepagent stack accumulates multiple ``SystemMessage``\ s in
|
||||
``state["messages"]`` and ``role: system`` would tag every one of
|
||||
them, blowing past Anthropic's 4-block ``cache_control`` cap.
|
||||
2. Adds ``prompt_cache_key``/``prompt_cache_retention`` only for
|
||||
single-model OPENAI/DEEPSEEK/XAI configs (where OpenAI's automatic
|
||||
prompt-cache surface is available).
|
||||
|
|
@ -92,11 +95,28 @@ def test_sets_both_cache_control_injection_points_with_no_config() -> None:
|
|||
apply_litellm_prompt_caching(llm)
|
||||
|
||||
points = llm.model_kwargs["cache_control_injection_points"]
|
||||
assert {"location": "message", "role": "system"} in points
|
||||
assert {"location": "message", "index": 0} in points
|
||||
assert {"location": "message", "index": -1} in points
|
||||
assert len(points) == 2
|
||||
|
||||
|
||||
def test_does_not_inject_role_system_breakpoint() -> None:
|
||||
"""Regression: deliberately AVOID ``role: system`` so we don't tag
|
||||
every SystemMessage the deepagent ``before_agent`` injectors push
|
||||
into ``state["messages"]`` (priority, tree, memory, file-intent,
|
||||
anonymous-doc). Tagging all of them overflows Anthropic's 4-block
|
||||
``cache_control`` cap and surfaces as
|
||||
``OpenrouterException: A maximum of 4 blocks with cache_control may
|
||||
be provided. Found N`` 400s.
|
||||
"""
|
||||
llm = _FakeLLM()
|
||||
apply_litellm_prompt_caching(llm)
|
||||
points = llm.model_kwargs["cache_control_injection_points"]
|
||||
assert all(p.get("role") != "system" for p in points), (
|
||||
f"Expected no role=system breakpoint, got: {points}"
|
||||
)
|
||||
|
||||
|
||||
def test_injection_points_set_for_anthropic_config() -> None:
|
||||
"""Anthropic-family configs need the marker — verify it lands."""
|
||||
cfg = _make_cfg(provider="ANTHROPIC", model_name="claude-3-5-sonnet")
|
||||
|
|
|
|||
|
|
@ -475,3 +475,190 @@ class TestKBSearchPlanSchema:
|
|||
)
|
||||
)
|
||||
assert plan.is_recency_query is False
|
||||
|
||||
|
||||
# ── mentioned_document_ids cross-turn drain ────────────────────────────
|
||||
|
||||
|
||||
class TestKnowledgePriorityMentionDrain:
    """Regression tests for the cross-turn ``mentioned_document_ids`` drain.

    The compiled-agent cache reuses one middleware instance across turns of
    the same thread, so ``mentioned_document_ids`` can reach it two ways:

    1. The constructor closure (``__init__(mentioned_document_ids=...)``),
       seeded by the cache-miss build on turn 1.
    2. ``runtime.context.mentioned_document_ids``, supplied freshly per
       turn by the streaming task.

    Without the drain fix an empty ``runtime.context.mentioned_document_ids``
    on turn 2 fell through to the closure (``[]`` is falsy in Python) and
    replayed turn 1's mentions.  These tests pin the corrected behaviour:
    the runtime path is authoritative even when empty, and the closure is
    drained the first time the runtime path fires so no later turn can ever
    resurrect stale state.
    """

    @staticmethod
    def _make_runtime(mention_ids: list[int]):
        """Minimal runtime stub exposing only ``runtime.context.mentioned_document_ids``."""
        from types import SimpleNamespace

        ctx = SimpleNamespace(mentioned_document_ids=mention_ids)
        return SimpleNamespace(context=ctx)

    @staticmethod
    def _planner_llm() -> "FakeLLM":
        # A stable, non-recency plan keeps every test in the hybrid-search
        # branch, where ``fetch_mentioned_documents`` is invoked alongside
        # the main search.
        plan = {
            "optimized_query": "follow up question",
            "start_date": None,
            "end_date": None,
            "is_recency_query": False,
        }
        return FakeLLM(json.dumps(plan))

    async def test_runtime_context_overrides_closure_and_drains_it(self, monkeypatch):
        """Turn 1 with mentions in BOTH closure and runtime context: the
        runtime path wins AND the closure is drained so a future turn
        cannot replay it.
        """
        seen: list[list[int]] = []

        async def _fake_fetch(*, document_ids, search_space_id):
            seen.append(list(document_ids))
            return []

        async def _fake_search(**_kwargs):
            return []

        monkeypatch.setattr(
            "app.agents.new_chat.middleware.knowledge_search.fetch_mentioned_documents",
            _fake_fetch,
        )
        monkeypatch.setattr(
            "app.agents.new_chat.middleware.knowledge_search.search_knowledge_base",
            _fake_search,
        )

        middleware = KnowledgeBaseSearchMiddleware(
            llm=self._planner_llm(),
            search_space_id=42,
            mentioned_document_ids=[1, 2, 3],
        )

        await middleware.abefore_agent(
            {"messages": [HumanMessage(content="what is in those docs?")]},
            runtime=self._make_runtime([1, 2, 3]),
        )

        assert seen == [[1, 2, 3]], (
            "runtime.context mentions must be the source of truth on turn 1"
        )
        assert middleware.mentioned_document_ids == [], (
            "closure must be drained the first time the runtime path fires "
            "so no later turn can replay stale mentions"
        )

    async def test_empty_runtime_context_does_not_replay_closure_mentions(
        self, monkeypatch
    ):
        """Regression: turn 2 with NO mentions must not surface turn 1's
        mentions from the constructor closure.

        Before the fix, ``if ctx_mentions:`` treated an empty list as
        absent and fell through to ``elif self.mentioned_document_ids:``,
        replaying turn 1's mentions.
        """
        seen: list[list[int]] = []

        async def _fake_fetch(*, document_ids, search_space_id):
            seen.append(list(document_ids))
            return []

        async def _fake_search(**_kwargs):
            return []

        monkeypatch.setattr(
            "app.agents.new_chat.middleware.knowledge_search.fetch_mentioned_documents",
            _fake_fetch,
        )
        monkeypatch.setattr(
            "app.agents.new_chat.middleware.knowledge_search.search_knowledge_base",
            _fake_search,
        )

        # A cached middleware instance whose closure was seeded by a
        # previous turn's cache-miss build (mentions=[1, 2, 3]).
        middleware = KnowledgeBaseSearchMiddleware(
            llm=self._planner_llm(),
            search_space_id=42,
            mentioned_document_ids=[1, 2, 3],
        )

        # Turn 2: the streaming task supplies an EMPTY mention list (no
        # mentions on this follow-up turn).
        await middleware.abefore_agent(
            {"messages": [HumanMessage(content="what about the next steps?")]},
            runtime=self._make_runtime([]),
        )

        assert seen == [], (
            "fetch_mentioned_documents must NOT be called when the runtime "
            "context says there are no mentions for this turn"
        )

    async def test_legacy_path_fires_only_when_runtime_context_absent(
        self, monkeypatch
    ):
        """Backward-compat: when a caller supplies no runtime.context (old
        non-streaming code path), the closure-injected mentions are still
        honoured exactly once and then drained.
        """
        seen: list[list[int]] = []

        async def _fake_fetch(*, document_ids, search_space_id):
            seen.append(list(document_ids))
            return []

        async def _fake_search(**_kwargs):
            return []

        monkeypatch.setattr(
            "app.agents.new_chat.middleware.knowledge_search.fetch_mentioned_documents",
            _fake_fetch,
        )
        monkeypatch.setattr(
            "app.agents.new_chat.middleware.knowledge_search.search_knowledge_base",
            _fake_search,
        )

        middleware = KnowledgeBaseSearchMiddleware(
            llm=self._planner_llm(),
            search_space_id=42,
            mentioned_document_ids=[7, 8],
        )

        # First call: no runtime -> legacy path consumes the closure.
        await middleware.abefore_agent(
            {"messages": [HumanMessage(content="initial question")]},
            runtime=None,
        )
        # Second call: still no runtime -- closure already drained, so
        # nothing can be replayed.
        await middleware.abefore_agent(
            {"messages": [HumanMessage(content="follow up")]},
            runtime=None,
        )

        assert seen == [[7, 8]], (
            "legacy path must honour the closure exactly once and then drain it"
        )
        assert middleware.mentioned_document_ids == []
|
||||
|
|
|
|||
|
|
@ -0,0 +1,110 @@
|
|||
"""Unit tests for ``supports_image_input`` derivation on BYOK chat config
|
||||
endpoints (``GET /new-llm-configs`` list, ``GET /new-llm-configs/{id}``).
|
||||
|
||||
There is no DB column for ``supports_image_input`` on
|
||||
``NewLLMConfig`` — the value is resolved at the API boundary by
|
||||
``derive_supports_image_input`` so the new-chat selector / streaming
|
||||
task can read the same field shape regardless of source (BYOK vs YAML
|
||||
vs OpenRouter dynamic). Default-allow on unknown so we don't lock the
|
||||
user out of their own model choice.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from app.db import LiteLLMProvider
|
||||
from app.routes import new_llm_config_routes
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def _byok_row(
    *,
    id_: int,
    model_name: str,
    base_model: str | None = None,
    provider: LiteLLMProvider = LiteLLMProvider.OPENAI,
    custom_provider: str | None = None,
) -> object:
    """Stand-in for the SQLAlchemy row's attribute surface.

    ``model_validate`` walks ``from_attributes=True``, so a plain
    ``SimpleNamespace`` is sufficient.  ``provider`` is a real
    ``LiteLLMProvider`` enum member so Pydantic's enum validator accepts
    it — exactly what the ORM row would carry.
    """
    litellm_params = {"base_model": base_model} if base_model else None
    return SimpleNamespace(
        id=id_,
        name=f"BYOK-{id_}",
        description=None,
        provider=provider,
        custom_provider=custom_provider,
        model_name=model_name,
        api_key="sk-byok",
        api_base=None,
        litellm_params=litellm_params,
        system_instructions="",
        use_default_system_instructions=True,
        citations_enabled=True,
        created_at=datetime.now(tz=UTC),
        search_space_id=42,
        user_id=uuid4(),
    )
||||
|
||||
|
||||
def test_serialize_byok_known_vision_model_resolves_true():
    """The catalog resolver consults LiteLLM's map for ``gpt-4o`` ->
    True; the serialized row must carry that value through to the
    ``NewLLMConfigRead`` schema."""
    result = new_llm_config_routes._serialize_byok_config(
        _byok_row(id_=1, model_name="gpt-4o")
    )

    assert result.id == 1
    assert result.model_name == "gpt-4o"
    assert result.supports_image_input is True
|
||||
|
||||
|
||||
def test_serialize_byok_unknown_model_default_allows():
    """Unknown / unmapped: default-allow.  The streaming-task safety net
    is the actual block, and it requires LiteLLM to *explicitly* say
    text-only — so a brand new BYOK model should not be pre-judged."""
    unknown = _byok_row(
        id_=2,
        model_name="brand-new-model-x9-unmapped",
        provider=LiteLLMProvider.CUSTOM,
        custom_provider="brand_new_proxy",
    )
    result = new_llm_config_routes._serialize_byok_config(unknown)

    assert result.supports_image_input is True
|
||||
|
||||
|
||||
def test_serialize_byok_uses_base_model_when_present():
    """Azure-style: ``model_name`` is the deployment id; ``base_model``
    inside ``litellm_params`` is the canonical sku LiteLLM knows.  The
    helper must consult ``base_model`` first or unrecognised deployment
    ids would shadow the real capability."""
    azure_like = _byok_row(
        id_=3,
        model_name="my-azure-deployment-id-no-litellm-knows-this",
        base_model="gpt-4o",
        provider=LiteLLMProvider.AZURE_OPENAI,
    )
    result = new_llm_config_routes._serialize_byok_config(azure_like)

    assert result.supports_image_input is True
|
||||
|
||||
|
||||
def test_serialize_byok_returns_pydantic_read_model():
    """The route returns ``NewLLMConfigRead`` (not the raw ORM row), so
    the schema additions are guaranteed to be present in the API surface.
    Guards against a future regression where the augmentation step is
    deleted and the route falls back to ORM passthrough."""
    from app.schemas import NewLLMConfigRead

    result = new_llm_config_routes._serialize_byok_config(
        _byok_row(id_=4, model_name="gpt-4o")
    )
    assert isinstance(result, NewLLMConfigRead)
|
||||
|
|
@ -0,0 +1,184 @@
|
|||
"""Unit tests for ``is_premium`` derivation on the global image-gen and
|
||||
vision-LLM list endpoints.
|
||||
|
||||
Chat globals (``GET /global-llm-configs``) already emit
|
||||
``is_premium = (billing_tier == "premium")``. Image and vision did not,
|
||||
which made the new-chat ``model-selector`` render the Free/Premium badge
|
||||
on the Chat tab but skip it on the Image and Vision tabs (the selector
|
||||
keys its badge logic off ``is_premium``). These tests pin parity:
|
||||
|
||||
* YAML free entry → ``is_premium=False``
|
||||
* YAML premium entry → ``is_premium=True``
|
||||
* OpenRouter dynamic premium entry → ``is_premium=True``
|
||||
* Auto stub (always emitted when at least one config is present)
|
||||
→ ``is_premium=False``
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
# Global image-generation configs as the YAML loader / OpenRouter
# integration would emit them.  Negative ids mark global (non-BYOK)
# entries; ids below -20_000 are the OpenRouter dynamic range.
_IMAGE_FIXTURE: list[dict] = [
    # YAML entry on the free tier.
    {
        "id": -1,
        "name": "DALL-E 3",
        "provider": "OPENAI",
        "model_name": "dall-e-3",
        "api_key": "sk-test",
        "billing_tier": "free",
    },
    # YAML entry on the premium tier.
    {
        "id": -2,
        "name": "GPT-Image 1 (premium)",
        "provider": "OPENAI",
        "model_name": "gpt-image-1",
        "api_key": "sk-test",
        "billing_tier": "premium",
    },
    # OpenRouter dynamic entry (premium) — carries an api_base.
    {
        "id": -20_001,
        "name": "google/gemini-2.5-flash-image (OpenRouter)",
        "provider": "OPENROUTER",
        "model_name": "google/gemini-2.5-flash-image",
        "api_key": "sk-or-test",
        "api_base": "https://openrouter.ai/api/v1",
        "billing_tier": "premium",
    },
]
|
||||
|
||||
|
||||
# Global vision-LLM configs mirroring the image fixture above; ids below
# -30_000 are the OpenRouter dynamic range for vision.
_VISION_FIXTURE: list[dict] = [
    # YAML entry on the free tier.
    {
        "id": -1,
        "name": "GPT-4o Vision",
        "provider": "OPENAI",
        "model_name": "gpt-4o",
        "api_key": "sk-test",
        "billing_tier": "free",
    },
    # YAML entry on the premium tier.
    {
        "id": -2,
        "name": "Claude 3.5 Sonnet (premium)",
        "provider": "ANTHROPIC",
        "model_name": "claude-3-5-sonnet",
        "api_key": "sk-ant-test",
        "billing_tier": "premium",
    },
    # OpenRouter dynamic entry (premium) — carries an api_base.
    {
        "id": -30_001,
        "name": "openai/gpt-4o (OpenRouter)",
        "provider": "OPENROUTER",
        "model_name": "openai/gpt-4o",
        "api_key": "sk-or-test",
        "api_base": "https://openrouter.ai/api/v1",
        "billing_tier": "premium",
    },
]
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Image generation
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_global_image_gen_configs_emit_is_premium(monkeypatch):
    """Every emitted config must carry ``is_premium`` derived server-side
    from ``billing_tier``; the Auto stub is always free.
    """
    from app.config import config
    from app.routes import image_generation_routes

    monkeypatch.setattr(
        config, "GLOBAL_IMAGE_GEN_CONFIGS", _IMAGE_FIXTURE, raising=False
    )

    payload = await image_generation_routes.get_global_image_gen_configs(user=None)
    indexed = {entry["id"]: entry for entry in payload}

    # The Auto stub appears whenever at least one global config exists,
    # and it must always declare itself free (Auto-mode billing-tier
    # surfacing is a separate follow-up).
    assert 0 in indexed, "Auto stub should be emitted when at least one config exists"
    assert indexed[0]["is_premium"] is False
    assert indexed[0]["billing_tier"] == "free"

    # YAML free entry.
    assert indexed[-1]["is_premium"] is False
    assert indexed[-1]["billing_tier"] == "free"

    # YAML premium entry.
    assert indexed[-2]["is_premium"] is True
    assert indexed[-2]["billing_tier"] == "premium"

    # OpenRouter dynamic premium entry — same field, same derivation.
    assert indexed[-20_001]["is_premium"] is True
    assert indexed[-20_001]["billing_tier"] == "premium"

    # The field must be a bool on every emitted dict (including Auto).
    for entry in payload:
        assert "is_premium" in entry, f"is_premium missing from {entry.get('id')}"
        assert isinstance(entry["is_premium"], bool)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_global_image_gen_configs_no_globals_no_auto_stub(monkeypatch):
    """With no global configs at all, the endpoint emits an empty list
    (no Auto stub) — Auto mode would have nothing to route to.
    """
    from app.config import config
    from app.routes import image_generation_routes

    monkeypatch.setattr(config, "GLOBAL_IMAGE_GEN_CONFIGS", [], raising=False)

    assert await image_generation_routes.get_global_image_gen_configs(user=None) == []
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Vision LLM
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_global_vision_llm_configs_emit_is_premium(monkeypatch):
    """Vision globals must derive ``is_premium`` exactly like chat/image."""
    from app.config import config
    from app.routes import vision_llm_routes

    monkeypatch.setattr(
        config, "GLOBAL_VISION_LLM_CONFIGS", _VISION_FIXTURE, raising=False
    )

    payload = await vision_llm_routes.get_global_vision_llm_configs(user=None)
    indexed = {entry["id"]: entry for entry in payload}

    assert 0 in indexed, "Auto stub should be emitted when at least one config exists"

    # id -> (expected is_premium, expected billing_tier); the Auto stub (0)
    # is always free, the YAML/OpenRouter entries follow their tiers.
    expectations = {
        0: (False, "free"),
        -1: (False, "free"),
        -2: (True, "premium"),
        -30_001: (True, "premium"),
    }
    for cfg_id, (premium, tier) in expectations.items():
        assert indexed[cfg_id]["is_premium"] is premium
        assert indexed[cfg_id]["billing_tier"] == tier

    # The field must be a bool on every emitted dict.
    for entry in payload:
        assert "is_premium" in entry, f"is_premium missing from {entry.get('id')}"
        assert isinstance(entry["is_premium"], bool)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_global_vision_llm_configs_no_globals_no_auto_stub(monkeypatch):
    """No global vision configs -> empty payload, no Auto stub."""
    from app.config import config
    from app.routes import vision_llm_routes

    monkeypatch.setattr(config, "GLOBAL_VISION_LLM_CONFIGS", [], raising=False)

    assert await vision_llm_routes.get_global_vision_llm_configs(user=None) == []
|
||||
|
|
@ -0,0 +1,106 @@
|
|||
"""Unit tests for ``supports_image_input`` derivation on the chat global
|
||||
config endpoint (``GET /global-new-llm-configs``).
|
||||
|
||||
Resolution order (matches ``new_llm_config_routes.get_global_new_llm_configs``):
|
||||
|
||||
1. Explicit ``supports_image_input`` on the cfg dict (set by the YAML
|
||||
loader for operator overrides, or by the OpenRouter integration from
|
||||
``architecture.input_modalities``) — wins.
|
||||
2. ``derive_supports_image_input`` helper — default-allow on unknown
|
||||
models, only False when LiteLLM / OR modalities are definitive.
|
||||
|
||||
The flag is purely informational at the API boundary. The streaming
|
||||
task safety net (``is_known_text_only_chat_model``) is the actual block,
|
||||
and it requires LiteLLM to *explicitly* mark the model as text-only.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
# Global chat configs covering the full resolution matrix for
# ``supports_image_input``: explicit True, explicit False, unannotated
# LiteLLM-known, and unannotated unknown.
_FIXTURE: list[dict] = [
    # Explicit True (operator YAML override) — must be preserved as-is.
    {
        "id": -1,
        "name": "GPT-4o (explicit true)",
        "description": "vision-capable, explicit YAML override",
        "provider": "OPENAI",
        "model_name": "gpt-4o",
        "api_key": "sk-test",
        "billing_tier": "free",
        "supports_image_input": True,
    },
    # Explicit False (OpenRouter modality-derived) — must be preserved.
    {
        "id": -2,
        "name": "DeepSeek V3 (explicit false)",
        "description": "OpenRouter dynamic — modality-derived false",
        "provider": "OPENROUTER",
        "model_name": "deepseek/deepseek-v3.2-exp",
        "api_key": "sk-or-test",
        "api_base": "https://openrouter.ai/api/v1",
        "billing_tier": "free",
        "supports_image_input": False,
    },
    # Unannotated model LiteLLM knows — resolver should derive True.
    {
        "id": -10_010,
        "name": "Unannotated GPT-4o",
        "description": "no flag set — resolver should derive True via LiteLLM",
        "provider": "OPENAI",
        "model_name": "gpt-4o",
        "api_key": "sk-test",
        "billing_tier": "free",
        # supports_image_input intentionally absent
    },
    # Unannotated model nobody has mapped — default-allow True.
    {
        "id": -10_011,
        "name": "Unannotated unknown model",
        "description": "unmapped — default-allow True",
        "provider": "CUSTOM",
        "custom_provider": "brand_new_proxy",
        "model_name": "brand-new-model-x9",
        "api_key": "sk-test",
        "billing_tier": "free",
    },
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_global_new_llm_configs_emit_supports_image_input(monkeypatch):
    """Each emitted chat config carries ``supports_image_input`` as a
    bool.  Explicit values win; unannotated entries are resolved via the
    helper (default-allow True)."""
    from app.config import config
    from app.routes import new_llm_config_routes

    monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", _FIXTURE, raising=False)

    payload = await new_llm_config_routes.get_global_new_llm_configs(user=None)
    indexed = {entry["id"]: entry for entry in payload}

    # Auto stub: optimistic True so the user can keep Auto selected while
    # vision-capable deployments exist somewhere in the pool.
    assert 0 in indexed, "Auto stub should be emitted when configs exist"
    assert indexed[0]["supports_image_input"] is True
    assert indexed[0]["is_auto_mode"] is True

    # Explicit True is preserved.
    assert indexed[-1]["supports_image_input"] is True

    # Explicit False is preserved (the exact failure mode the safety net
    # guards against — DeepSeek V3 over OpenRouter would 404 with "No
    # endpoints found that support image input").
    assert indexed[-2]["supports_image_input"] is False

    # Unannotated GPT-4o: resolver consults LiteLLM, which says vision.
    assert indexed[-10_010]["supports_image_input"] is True

    # Unknown / unmapped model: default-allow rather than pre-judge.
    assert indexed[-10_011]["supports_image_input"] is True

    for entry in payload:
        assert "supports_image_input" in entry, (
            f"supports_image_input missing from {entry.get('id')}"
        )
        assert isinstance(entry["supports_image_input"], bool)
|
||||
138
surfsense_backend/tests/unit/routes/test_image_gen_quota.py
Normal file
138
surfsense_backend/tests/unit/routes/test_image_gen_quota.py
Normal file
|
|
@ -0,0 +1,138 @@
|
|||
"""Unit tests for the image-generation route's billing-resolution helper.
|
||||
|
||||
End-to-end "POST /image-generations returns 402" coverage requires the
|
||||
integration harness (real DB, real auth) and lives in
|
||||
``tests/integration/document_upload/`` alongside the other quota tests.
|
||||
This unit test focuses on the new ``_resolve_billing_for_image_gen``
|
||||
helper which:
|
||||
|
||||
* Returns ``free`` for Auto mode, even when premium configs exist
|
||||
(Auto-mode billing-tier surfacing is a follow-up).
|
||||
* Returns ``free`` for user-owned BYOK configs (positive IDs).
|
||||
* Returns the global config's ``billing_tier`` for negative IDs.
|
||||
* Honours the per-config ``quota_reserve_micros`` override when present.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_resolve_billing_for_auto_mode(monkeypatch):
    """Auto mode is always billed free, even when premium configs exist."""
    from app.routes import image_generation_routes
    from app.services.billable_calls import DEFAULT_IMAGE_RESERVE_MICROS

    space = SimpleNamespace(image_generation_config_id=None)
    tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen(
        session=None,  # Not consumed on this code path.
        config_id=0,  # IMAGE_GEN_AUTO_MODE_ID
        search_space=space,
    )

    assert (tier, model) == ("free", "auto")
    assert reserve == DEFAULT_IMAGE_RESERVE_MICROS
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_resolve_billing_for_premium_global_config(monkeypatch):
    """Negative ids resolve against the global config table: tier comes
    from ``billing_tier`` and ``quota_reserve_micros`` overrides the
    default reserve when present."""
    from app.config import config
    from app.routes import image_generation_routes
    from app.services.billable_calls import DEFAULT_IMAGE_RESERVE_MICROS

    globals_fixture = [
        {
            "id": -1,
            "provider": "OPENAI",
            "model_name": "gpt-image-1",
            "billing_tier": "premium",
            "quota_reserve_micros": 75_000,
        },
        {
            "id": -2,
            "provider": "OPENROUTER",
            "model_name": "google/gemini-2.5-flash-image",
            "billing_tier": "free",
        },
    ]
    monkeypatch.setattr(
        config, "GLOBAL_IMAGE_GEN_CONFIGS", globals_fixture, raising=False
    )

    space = SimpleNamespace(image_generation_config_id=None)

    # Premium entry with an explicit reserve override.
    tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen(
        session=None, config_id=-1, search_space=space
    )
    assert (tier, model, reserve) == ("premium", "openai/gpt-image-1", 75_000)

    # Free entry without an override falls back to the default reserve.
    tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen(
        session=None, config_id=-2, search_space=space
    )
    assert tier == "free"
    # OpenRouter model strings come back provider-prefixed.
    assert "google/gemini-2.5-flash-image" in model
    assert reserve == DEFAULT_IMAGE_RESERVE_MICROS
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_resolve_billing_for_user_owned_byok_is_free():
    """User-owned BYOK configs (positive IDs) cost the user nothing on
    our side — they pay the provider directly.  Always free.
    """
    from app.routes import image_generation_routes
    from app.services.billable_calls import DEFAULT_IMAGE_RESERVE_MICROS

    space = SimpleNamespace(image_generation_config_id=None)
    tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen(
        session=None, config_id=42, search_space=space
    )

    assert (tier, model) == ("free", "user_byok")
    assert reserve == DEFAULT_IMAGE_RESERVE_MICROS
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_resolve_billing_falls_back_to_search_space_default(monkeypatch):
    """When the request omits ``image_generation_config_id``, the helper
    must consult the search space's default — so a search space pinned
    to a premium global config still gates new requests by quota.
    """
    from app.config import config
    from app.routes import image_generation_routes

    monkeypatch.setattr(
        config,
        "GLOBAL_IMAGE_GEN_CONFIGS",
        [
            {
                "id": -7,
                "provider": "OPENAI",
                "model_name": "gpt-image-1",
                "billing_tier": "premium",
            }
        ],
        raising=False,
    )

    space = SimpleNamespace(image_generation_config_id=-7)
    tier, model, _reserve = await (
        image_generation_routes._resolve_billing_for_image_gen(
            session=None, config_id=None, search_space=space
        )
    )

    assert tier == "premium"
    assert model == "openai/gpt-image-1"
|
||||
|
|
@ -0,0 +1,436 @@
|
|||
"""Unit tests for ``_resolve_agent_billing_for_search_space``.
|
||||
|
||||
Validates the resolver used by Celery podcast/video tasks to compute
|
||||
``(owner_user_id, billing_tier, base_model)`` from a search space and its
|
||||
agent LLM config. The resolver mirrors chat's billing-resolution pattern at
|
||||
``stream_new_chat.py:2294-2351`` and is the single integration point that
|
||||
prevents Auto-mode podcast/video from leaking premium credit.
|
||||
|
||||
Coverage:
|
||||
|
||||
* Auto mode + ``thread_id`` set, pin resolves to a negative-id premium
|
||||
global → returns ``("premium", <base_model>)``.
|
||||
* Auto mode + ``thread_id`` set, pin resolves to a negative-id free
|
||||
global → returns ``("free", <base_model>)``.
|
||||
* Auto mode + ``thread_id`` set, pin resolves to a positive-id BYOK config
|
||||
→ always ``"free"``.
|
||||
* Auto mode + ``thread_id=None`` → fallback to ``("free", "auto")`` without
|
||||
hitting the pin service.
|
||||
* Negative id (no Auto) → uses ``get_global_llm_config``'s
|
||||
``billing_tier``.
|
||||
* Positive id (user BYOK) → always ``"free"``.
|
||||
* Search space not found → raises ``ValueError``.
|
||||
* ``agent_llm_id`` is None → raises ``ValueError``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from types import SimpleNamespace
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fakes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _FakeExecResult:
|
||||
def __init__(self, obj):
|
||||
self._obj = obj
|
||||
|
||||
def scalars(self):
|
||||
return self
|
||||
|
||||
def first(self):
|
||||
return self._obj
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
"""Tiny AsyncSession stub.
|
||||
|
||||
``responses`` is a list of objects to return from successive
|
||||
``execute()`` calls (in order). The resolver makes at most two
|
||||
``execute()`` calls (search-space lookup, then optionally NewLLMConfig
|
||||
lookup), so two queued responses cover the matrix.
|
||||
"""
|
||||
|
||||
def __init__(self, responses: list):
|
||||
self._responses = list(responses)
|
||||
|
||||
async def execute(self, _stmt):
|
||||
if not self._responses:
|
||||
return _FakeExecResult(None)
|
||||
return _FakeExecResult(self._responses.pop(0))
|
||||
|
||||
async def commit(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
class _FakePinResolution:
    """Shape-compatible stand-in for the pin service's resolution result."""

    # Concrete config id the pin service settled on (negative = global).
    resolved_llm_config_id: int
    # Billing tier the pin service reports for that config.
    resolved_tier: str = "premium"
    # Whether the resolution came from an already-existing thread pin.
    from_existing_pin: bool = False
|
||||
|
||||
|
||||
def _make_search_space(*, agent_llm_id: int | None, user_id: UUID) -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
id=42,
|
||||
agent_llm_id=agent_llm_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
|
||||
def _make_byok_config(
|
||||
*, id_: int, base_model: str | None = None, model_name: str = "gpt-byok"
|
||||
) -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
id=id_,
|
||||
model_name=model_name,
|
||||
litellm_params={"base_model": base_model} if base_model else {},
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_auto_mode_with_thread_id_resolves_to_premium_global(monkeypatch):
    """Auto + thread → pin service resolves to a negative-id premium
    config → resolver returns ``("premium", <base_model>)``."""
    from app.services.billable_calls import _resolve_agent_billing_for_search_space

    user_id = uuid4()
    session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)])

    async def _pin_stub(
        sess,
        *,
        thread_id,
        search_space_id,
        user_id,
        selected_llm_config_id,
        force_repin_free=False,
    ):
        # The resolver must forward the Auto sentinel and the thread id.
        assert selected_llm_config_id == 0
        assert thread_id == 99
        return _FakePinResolution(resolved_llm_config_id=-1, resolved_tier="premium")

    def _global_lookup(cfg_id):
        if cfg_id != -1:
            return None
        return {
            "id": -1,
            "model_name": "gpt-5.4",
            "billing_tier": "premium",
            "litellm_params": {"base_model": "gpt-5.4"},
        }

    # The resolver imports these lazily — patch the *target* modules so
    # the imported names resolve to our fakes.
    import app.services.auto_model_pin_service as pin_module
    import app.services.llm_service as llm_module

    monkeypatch.setattr(
        pin_module, "resolve_or_get_pinned_llm_config_id", _pin_stub
    )
    monkeypatch.setattr(llm_module, "get_global_llm_config", _global_lookup)

    owner, tier, base_model = await _resolve_agent_billing_for_search_space(
        session, search_space_id=42, thread_id=99
    )

    assert (owner, tier, base_model) == (user_id, "premium", "gpt-5.4")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_auto_mode_with_thread_id_resolves_to_free_global(monkeypatch):
    """Auto + thread → pin returns a negative-id free config → resolver
    returns ``("free", <base_model>)``.  Same path the pin service takes
    for out-of-credit users (graceful degradation)."""
    from app.services.billable_calls import _resolve_agent_billing_for_search_space

    user_id = uuid4()
    session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)])

    async def _pin_stub(
        sess,
        *,
        thread_id,
        search_space_id,
        user_id,
        selected_llm_config_id,
        force_repin_free=False,
    ):
        return _FakePinResolution(resolved_llm_config_id=-3, resolved_tier="free")

    def _global_lookup(cfg_id):
        if cfg_id != -3:
            return None
        return {
            "id": -3,
            "model_name": "openrouter/free-model",
            "billing_tier": "free",
            "litellm_params": {"base_model": "openrouter/free-model"},
        }

    # Patch the lazily-imported targets so the resolver picks up the fakes.
    import app.services.auto_model_pin_service as pin_module
    import app.services.llm_service as llm_module

    monkeypatch.setattr(
        pin_module, "resolve_or_get_pinned_llm_config_id", _pin_stub
    )
    monkeypatch.setattr(llm_module, "get_global_llm_config", _global_lookup)

    owner, tier, base_model = await _resolve_agent_billing_for_search_space(
        session, search_space_id=42, thread_id=99
    )

    assert (owner, tier, base_model) == (user_id, "free", "openrouter/free-model")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_auto_mode_with_thread_id_resolves_to_byok_is_free(monkeypatch):
    """Auto + thread → pin returns positive-id BYOK config → resolver
    returns ``("free", ...)`` (BYOK is always free per
    ``AgentConfig.from_new_llm_config``)."""
    import app.services.auto_model_pin_service as pin_module

    from app.services.billable_calls import _resolve_agent_billing_for_search_space

    owner_id = uuid4()
    byok_row = _make_byok_config(
        id_=17, base_model="anthropic/claude-3-haiku", model_name="my-claude"
    )
    fake_session = _FakeSession(
        [_make_search_space(agent_llm_id=0, user_id=owner_id), byok_row]
    )

    # Pin stub: Auto resolves to the user's own BYOK config (positive id).
    async def _pin_stub(
        sess,
        *,
        thread_id,
        search_space_id,
        user_id,
        selected_llm_config_id,
        force_repin_free=False,
    ):
        return _FakePinResolution(resolved_llm_config_id=17, resolved_tier="free")

    monkeypatch.setattr(
        pin_module, "resolve_or_get_pinned_llm_config_id", _pin_stub
    )

    owner, tier, base_model = await _resolve_agent_billing_for_search_space(
        fake_session, search_space_id=42, thread_id=99
    )

    assert (owner, tier, base_model) == (
        owner_id,
        "free",
        "anthropic/claude-3-haiku",
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_auto_mode_without_thread_id_falls_back_to_free():
    """Auto + ``thread_id=None`` → ``("free", "auto")`` without invoking
    the pin service. Forward-compat fallback for any future direct-API
    entrypoint that doesn't have a chat thread."""
    from app.services.billable_calls import _resolve_agent_billing_for_search_space

    owner_id = uuid4()
    fake_session = _FakeSession(
        [_make_search_space(agent_llm_id=0, user_id=owner_id)]
    )

    owner, tier, base_model = await _resolve_agent_billing_for_search_space(
        fake_session, search_space_id=42, thread_id=None
    )

    assert (owner, tier, base_model) == (owner_id, "free", "auto")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_auto_mode_pin_failure_falls_back_to_free(monkeypatch):
    """If the pin service raises ``ValueError`` (thread missing /
    mismatched search space), the resolver should log and return free
    rather than killing the whole task."""
    import app.services.auto_model_pin_service as pin_module

    from app.services.billable_calls import _resolve_agent_billing_for_search_space

    owner_id = uuid4()
    fake_session = _FakeSession(
        [_make_search_space(agent_llm_id=0, user_id=owner_id)]
    )

    # Pin stub that always blows up the way the real service does when
    # the thread row is missing.
    async def _exploding_pin(*_args, **_kwargs):
        raise ValueError("thread missing")

    monkeypatch.setattr(
        pin_module, "resolve_or_get_pinned_llm_config_id", _exploding_pin
    )

    owner, tier, base_model = await _resolve_agent_billing_for_search_space(
        fake_session, search_space_id=42, thread_id=99
    )

    assert (owner, tier, base_model) == (owner_id, "free", "auto")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_negative_id_premium_global_returns_premium(monkeypatch):
    """Explicit negative agent_llm_id → ``get_global_llm_config`` →
    return its ``billing_tier``."""
    import app.services.llm_service as llm_module

    from app.services.billable_calls import _resolve_agent_billing_for_search_space

    owner_id = uuid4()
    fake_session = _FakeSession(
        [_make_search_space(agent_llm_id=-1, user_id=owner_id)]
    )

    # Every global id resolves to a premium gpt-5.4 entry.
    def _global_stub(cfg_id):
        return {
            "id": cfg_id,
            "model_name": "gpt-5.4",
            "billing_tier": "premium",
            "litellm_params": {"base_model": "gpt-5.4"},
        }

    monkeypatch.setattr(llm_module, "get_global_llm_config", _global_stub)

    owner, tier, base_model = await _resolve_agent_billing_for_search_space(
        fake_session, search_space_id=42, thread_id=99
    )

    assert (owner, tier, base_model) == (owner_id, "premium", "gpt-5.4")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_negative_id_free_global_returns_free(monkeypatch):
    """Explicit negative agent_llm_id pointing at a free global config →
    resolver returns ``("free", <base_model>)``."""
    from app.services.billable_calls import _resolve_agent_billing_for_search_space

    user_id = uuid4()
    session = _FakeSession([_make_search_space(agent_llm_id=-2, user_id=user_id)])

    # Global-config lookup stub: every id maps to a free OpenRouter entry.
    def _fake_get_global(cfg_id):
        return {
            "id": cfg_id,
            "model_name": "openrouter/some-free",
            "billing_tier": "free",
            "litellm_params": {"base_model": "openrouter/some-free"},
        }

    import app.services.llm_service as llm_module

    monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global)

    owner, tier, base_model = await _resolve_agent_billing_for_search_space(
        session, search_space_id=42, thread_id=None
    )

    assert owner == user_id
    assert tier == "free"
    assert base_model == "openrouter/some-free"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_negative_id_missing_base_model_falls_back_to_model_name(monkeypatch):
    """When the global config has no ``litellm_params.base_model``, the
    resolver falls back to ``model_name`` — matching chat's behavior."""
    import app.services.llm_service as llm_module

    from app.services.billable_calls import _resolve_agent_billing_for_search_space

    owner_id = uuid4()
    fake_session = _FakeSession(
        [_make_search_space(agent_llm_id=-5, user_id=owner_id)]
    )

    # Global entry deliberately lacks litellm_params.
    def _global_stub(cfg_id):
        return {
            "id": cfg_id,
            "model_name": "fallback-model",
            "billing_tier": "premium",
        }

    monkeypatch.setattr(llm_module, "get_global_llm_config", _global_stub)

    _, tier, base_model = await _resolve_agent_billing_for_search_space(
        fake_session, search_space_id=42
    )

    assert (tier, base_model) == ("premium", "fallback-model")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_positive_id_byok_is_always_free():
    """Positive agent_llm_id → user-owned BYOK NewLLMConfig → always free,
    regardless of underlying provider tier."""
    from app.services.billable_calls import _resolve_agent_billing_for_search_space

    owner_id = uuid4()
    rows = [
        _make_search_space(agent_llm_id=23, user_id=owner_id),
        _make_byok_config(id_=23, base_model="anthropic/claude-3.5-sonnet"),
    ]
    fake_session = _FakeSession(rows)

    owner, tier, base_model = await _resolve_agent_billing_for_search_space(
        fake_session, search_space_id=42
    )

    assert (owner, tier, base_model) == (
        owner_id,
        "free",
        "anthropic/claude-3.5-sonnet",
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_positive_id_byok_missing_returns_free_with_empty_base_model():
    """If the BYOK config row is missing/deleted but the search space still
    points at it, the resolver still returns free (no debit) with an empty
    base_model — billable_call's premium path is skipped, no harm done."""
    from app.services.billable_calls import _resolve_agent_billing_for_search_space

    owner_id = uuid4()
    # Search space references BYOK id 99, but no such config row exists.
    fake_session = _FakeSession(
        [_make_search_space(agent_llm_id=99, user_id=owner_id)]
    )

    owner, tier, base_model = await _resolve_agent_billing_for_search_space(
        fake_session, search_space_id=42
    )

    assert (owner, tier, base_model) == (owner_id, "free", "")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_search_space_not_found_raises_value_error():
    """Unknown search_space_id → the resolver raises ``ValueError``
    mentioning "Search space" rather than silently defaulting a tier."""
    from app.services.billable_calls import _resolve_agent_billing_for_search_space

    # The fake session yields None for the search-space lookup.
    session = _FakeSession([None])

    with pytest.raises(ValueError, match="Search space"):
        await _resolve_agent_billing_for_search_space(session, search_space_id=999)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_agent_llm_id_none_raises_value_error():
    """A search space whose ``agent_llm_id`` is unset → ``ValueError``
    mentioning "agent_llm_id" (treated as misconfiguration, not a
    billing default)."""
    from app.services.billable_calls import _resolve_agent_billing_for_search_space

    user_id = uuid4()
    session = _FakeSession([_make_search_space(agent_llm_id=None, user_id=user_id)])

    with pytest.raises(ValueError, match="agent_llm_id"):
        await _resolve_agent_billing_for_search_space(session, search_space_id=42)
|
||||
|
|
@ -101,11 +101,116 @@ async def test_auto_first_turn_pins_one_model(monkeypatch):
|
|||
user_id="00000000-0000-0000-0000-000000000001",
|
||||
selected_llm_config_id=0,
|
||||
)
|
||||
assert result.resolved_llm_config_id in {-1, -2}
|
||||
assert result.resolved_llm_config_id == -1
|
||||
assert session.thread.pinned_llm_config_id == result.resolved_llm_config_id
|
||||
assert session.commit_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_premium_eligible_auto_prefers_premium_over_free(monkeypatch):
    """A premium-eligible user in Auto mode pins the premium cfg (-1) even
    though the free cfg (-2) has a far higher quality score."""
    from app.config import config

    session = _FakeSession(_thread())
    # Free cfg deliberately outscores the premium one (100 vs 10), so a
    # bug that ranks purely by quality would surface as -2 being pinned.
    monkeypatch.setattr(
        config,
        "GLOBAL_LLM_CONFIGS",
        [
            {
                "id": -2,
                "provider": "OPENAI",
                "model_name": "gpt-free",
                "api_key": "k1",
                "billing_tier": "free",
                "quality_score": 100,
            },
            {
                "id": -1,
                "provider": "OPENAI",
                "model_name": "gpt-prem",
                "api_key": "k2",
                "billing_tier": "premium",
                "quality_score": 10,
            },
        ],
    )

    # Quota stub: user has premium credit available.
    async def _allowed(*_args, **_kwargs):
        return _FakeQuotaResult(allowed=True)

    monkeypatch.setattr(
        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
        _allowed,
    )

    result = await resolve_or_get_pinned_llm_config_id(
        session,
        thread_id=1,
        search_space_id=10,
        user_id="00000000-0000-0000-0000-000000000001",
        selected_llm_config_id=0,
    )
    assert result.resolved_llm_config_id == -1
    assert result.resolved_tier == "premium"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_premium_eligible_auto_prefers_azure_gpt_5_4(monkeypatch):
    """Auto pin selects the tier-A Azure ``gpt-5.4`` cfg (-2) even though a
    tier-A ``gpt-5.1`` scores higher (100 vs 10) and a tier-B OpenRouter
    ``gpt-5.4`` mirror also exists."""
    from app.config import config

    session = _FakeSession(_thread())
    monkeypatch.setattr(
        config,
        "GLOBAL_LLM_CONFIGS",
        [
            {
                "id": -1,
                "provider": "AZURE_OPENAI",
                "model_name": "gpt-5.1",
                "api_key": "k1",
                "billing_tier": "premium",
                "auto_pin_tier": "A",
                "quality_score": 100,
            },
            {
                "id": -2,
                "provider": "AZURE_OPENAI",
                "model_name": "gpt-5.4",
                "api_key": "k2",
                "billing_tier": "premium",
                "auto_pin_tier": "A",
                "quality_score": 10,
            },
            {
                # Same model via OpenRouter, but lower auto-pin tier (B).
                "id": -3,
                "provider": "OPENROUTER",
                "model_name": "openai/gpt-5.4",
                "api_key": "k3",
                "billing_tier": "premium",
                "auto_pin_tier": "B",
                "quality_score": 100,
            },
        ],
    )

    # Quota stub: premium credit is available.
    async def _allowed(*_args, **_kwargs):
        return _FakeQuotaResult(allowed=True)

    monkeypatch.setattr(
        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
        _allowed,
    )

    result = await resolve_or_get_pinned_llm_config_id(
        session,
        thread_id=1,
        search_space_id=10,
        user_id="00000000-0000-0000-0000-000000000001",
        selected_llm_config_id=0,
    )
    assert result.resolved_llm_config_id == -2
    assert result.resolved_tier == "premium"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_next_turn_reuses_existing_pin(monkeypatch):
|
||||
from app.config import config
|
||||
|
|
@ -361,12 +466,12 @@ async def test_invalid_pinned_config_repairs_with_new_pin(monkeypatch):
|
|||
],
|
||||
)
|
||||
|
||||
async def _allowed(*_args, **_kwargs):
|
||||
return _FakeQuotaResult(allowed=True)
|
||||
async def _blocked(*_args, **_kwargs):
|
||||
return _FakeQuotaResult(allowed=False)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
|
||||
_allowed,
|
||||
_blocked,
|
||||
)
|
||||
|
||||
result = await resolve_or_get_pinned_llm_config_id(
|
||||
|
|
|
|||
|
|
@ -0,0 +1,286 @@
|
|||
"""Image-aware extension of the Auto-pin resolver.
|
||||
|
||||
When the current chat turn carries an ``image_url`` block, the pin
|
||||
resolver must:
|
||||
|
||||
1. Filter the candidate pool to vision-capable cfgs so a freshly
|
||||
selected pin can never be text-only.
|
||||
2. Treat any existing pin whose capability is False as invalid (force
|
||||
re-pin), even when it would otherwise be reused as the thread's
|
||||
stable model.
|
||||
3. Raise ``ValueError`` (mapped to the friendly
|
||||
``MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT`` SSE error in the streaming
|
||||
task) when no vision-capable cfg is available — instead of silently
|
||||
pinning text-only and 404-ing at the provider.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.auto_model_pin_service import (
|
||||
clear_healthy,
|
||||
clear_runtime_cooldown,
|
||||
resolve_or_get_pinned_llm_config_id,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _reset_caches():
    """Wipe the pin service's cooldown/health caches around every test."""

    def _wipe():
        clear_runtime_cooldown()
        clear_healthy()

    _wipe()
    yield
    _wipe()
|
||||
|
||||
|
||||
@dataclass
class _FakeQuotaResult:
    """Quota-service result stand-in; the resolver only reads ``allowed``."""

    allowed: bool
|
||||
|
||||
|
||||
class _FakeExecResult:
|
||||
def __init__(self, thread):
|
||||
self._thread = thread
|
||||
|
||||
def unique(self):
|
||||
return self
|
||||
|
||||
def scalar_one_or_none(self):
|
||||
return self._thread
|
||||
|
||||
|
||||
class _FakeSession:
    """Minimal async-session stand-in: serves a single thread row and
    counts commits so tests can assert on pin persistence."""

    def __init__(self, thread):
        # The one chat-thread row every query "returns".
        self.thread = thread
        # Number of commits the resolver performed (pin writes).
        self.commit_count = 0

    async def execute(self, _stmt):
        # Any statement resolves to the stored thread row.
        return _FakeExecResult(self.thread)

    async def commit(self):
        self.commit_count += 1
|
||||
|
||||
|
||||
def _thread(*, pinned: int | None = None):
|
||||
return SimpleNamespace(id=1, search_space_id=10, pinned_llm_config_id=pinned)
|
||||
|
||||
|
||||
def _vision_cfg(id_: int, *, tier: str = "free", quality: int = 80) -> dict:
|
||||
return {
|
||||
"id": id_,
|
||||
"provider": "OPENAI",
|
||||
"model_name": f"vision-{id_}",
|
||||
"api_key": "k",
|
||||
"billing_tier": tier,
|
||||
"supports_image_input": True,
|
||||
"auto_pin_tier": "A",
|
||||
"quality_score": quality,
|
||||
}
|
||||
|
||||
|
||||
def _text_only_cfg(id_: int, *, tier: str = "free", quality: int = 90) -> dict:
|
||||
return {
|
||||
"id": id_,
|
||||
"provider": "OPENAI",
|
||||
"model_name": f"text-{id_}",
|
||||
"api_key": "k",
|
||||
"billing_tier": tier,
|
||||
# Higher quality than the vision cfgs — so a bug that ignores
|
||||
# the image flag would surface as the resolver picking this one.
|
||||
"supports_image_input": False,
|
||||
"auto_pin_tier": "A",
|
||||
"quality_score": quality,
|
||||
}
|
||||
|
||||
|
||||
async def _premium_allowed(*_args, **_kwargs):
    """Quota stub: premium usage is always allowed."""
    result = _FakeQuotaResult(allowed=True)
    return result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_image_turn_filters_out_text_only_candidates(monkeypatch):
    """With an image in the turn, the fresh pin must land on the vision
    cfg even though the text-only cfg carries a higher quality score."""
    from app.config import config

    fake_session = _FakeSession(_thread())
    monkeypatch.setattr(
        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
        _premium_allowed,
    )
    monkeypatch.setattr(
        config,
        "GLOBAL_LLM_CONFIGS",
        [_text_only_cfg(-1), _vision_cfg(-2)],
    )

    result = await resolve_or_get_pinned_llm_config_id(
        fake_session,
        thread_id=1,
        search_space_id=10,
        user_id=None,
        selected_llm_config_id=0,
        requires_image_input=True,
    )

    # Vision cfg wins despite the text-only cfg's higher score, and the
    # thread pin records it.
    assert result.resolved_llm_config_id == -2
    assert fake_session.thread.pinned_llm_config_id == -2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_image_turn_force_repins_stale_text_only_pin(monkeypatch):
    """An existing text-only pin must be invalidated when the next turn
    requires image input. The non-image path would happily reuse it."""
    from app.config import config

    # Thread arrives already pinned to the text-only cfg -1.
    fake_session = _FakeSession(_thread(pinned=-1))
    monkeypatch.setattr(
        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
        _premium_allowed,
    )
    monkeypatch.setattr(
        config,
        "GLOBAL_LLM_CONFIGS",
        [_text_only_cfg(-1), _vision_cfg(-2)],
    )

    result = await resolve_or_get_pinned_llm_config_id(
        fake_session,
        thread_id=1,
        search_space_id=10,
        user_id=None,
        selected_llm_config_id=0,
        requires_image_input=True,
    )

    # The stale pin was dropped and replaced by the vision cfg.
    assert result.resolved_llm_config_id == -2
    assert result.from_existing_pin is False
    assert fake_session.thread.pinned_llm_config_id == -2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_image_turn_reuses_existing_vision_pin(monkeypatch):
    """If the thread is already pinned to a vision-capable cfg, reuse it
    — same as the non-image path. Image-aware filtering must not force
    spurious re-pins."""
    from app.config import config

    fake_session = _FakeSession(_thread(pinned=-2))
    monkeypatch.setattr(
        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
        _premium_allowed,
    )
    monkeypatch.setattr(
        config,
        "GLOBAL_LLM_CONFIGS",
        [_text_only_cfg(-1), _vision_cfg(-2), _vision_cfg(-3, quality=70)],
    )

    result = await resolve_or_get_pinned_llm_config_id(
        fake_session,
        thread_id=1,
        search_space_id=10,
        user_id=None,
        selected_llm_config_id=0,
        requires_image_input=True,
    )

    # The existing vision pin is valid for an image turn — keep it.
    assert result.resolved_llm_config_id == -2
    assert result.from_existing_pin is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_image_turn_with_no_vision_candidates_raises(monkeypatch):
    """The friendly-error path: no vision-capable cfg in the pool -> raise
    ``ValueError`` whose message contains ``vision-capable`` so the
    streaming task can map it to ``MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT``."""
    from app.config import config

    fake_session = _FakeSession(_thread())
    monkeypatch.setattr(
        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
        _premium_allowed,
    )
    # Pool contains only text-only cfgs — nothing can serve an image turn.
    monkeypatch.setattr(
        config,
        "GLOBAL_LLM_CONFIGS",
        [_text_only_cfg(-1), _text_only_cfg(-2)],
    )

    with pytest.raises(ValueError, match="vision-capable"):
        await resolve_or_get_pinned_llm_config_id(
            fake_session,
            thread_id=1,
            search_space_id=10,
            user_id=None,
            selected_llm_config_id=0,
            requires_image_input=True,
        )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_non_image_turn_keeps_text_only_in_pool(monkeypatch):
    """Regression guard: the image flag must default False and not affect
    a normal text-only turn — text-only cfgs remain selectable."""
    from app.config import config

    fake_session = _FakeSession(_thread())
    monkeypatch.setattr(
        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
        _premium_allowed,
    )
    monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [_text_only_cfg(-1)])

    # NOTE: requires_image_input deliberately omitted — default path.
    result = await resolve_or_get_pinned_llm_config_id(
        fake_session,
        thread_id=1,
        search_space_id=10,
        user_id=None,
        selected_llm_config_id=0,
    )
    assert result.resolved_llm_config_id == -1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_image_turn_unannotated_cfg_resolves_via_helper(monkeypatch):
    """A YAML cfg that omits ``supports_image_input`` falls through to
    ``derive_supports_image_input`` (LiteLLM-driven). For ``gpt-4o``
    that returns True, so the cfg should be a valid candidate."""
    from app.config import config

    fake_session = _FakeSession(_thread())
    # No supports_image_input key — capability must be derived from the
    # model name ("gpt-4o" is a known vision model in the LiteLLM map).
    unannotated_vision = {
        "id": -2,
        "provider": "OPENAI",
        "model_name": "gpt-4o",
        "api_key": "k",
        "billing_tier": "free",
        "auto_pin_tier": "A",
        "quality_score": 80,
    }
    monkeypatch.setattr(
        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
        _premium_allowed,
    )
    monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [unannotated_vision])

    result = await resolve_or_get_pinned_llm_config_id(
        fake_session,
        thread_id=1,
        search_space_id=10,
        user_id=None,
        selected_llm_config_id=0,
        requires_image_input=True,
    )
    assert result.resolved_llm_config_id == -2
|
||||
559
surfsense_backend/tests/unit/services/test_billable_call.py
Normal file
559
surfsense_backend/tests/unit/services/test_billable_call.py
Normal file
|
|
@ -0,0 +1,559 @@
|
|||
"""Unit tests for the ``billable_call`` async context manager.
|
||||
|
||||
Covers the per-call premium-credit lifecycle for image generation and
|
||||
vision LLM extraction:
|
||||
|
||||
* Free configs bypass reserve/finalize but still write an audit row.
|
||||
* Premium reserve denial raises ``QuotaInsufficientError`` (HTTP 402 in the
|
||||
route layer).
|
||||
* Successful premium calls reserve, yield the accumulator, then finalize
|
||||
with the LiteLLM-reported actual cost — and write an audit row.
|
||||
* Failed premium calls release the reservation so credit isn't leaked.
|
||||
* All quota DB ops happen inside their OWN ``shielded_async_session``,
|
||||
isolating them from the caller's transaction (issue A).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fakes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _FakeQuotaResult:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
allowed: bool,
|
||||
used: int = 0,
|
||||
limit: int = 5_000_000,
|
||||
remaining: int = 5_000_000,
|
||||
) -> None:
|
||||
self.allowed = allowed
|
||||
self.used = used
|
||||
self.limit = limit
|
||||
self.remaining = remaining
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
"""Minimal AsyncSession stub — record commits for assertion."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.committed = False
|
||||
self.added: list[Any] = []
|
||||
|
||||
def add(self, obj: Any) -> None:
|
||||
self.added.append(obj)
|
||||
|
||||
async def commit(self) -> None:
|
||||
self.committed = True
|
||||
|
||||
async def rollback(self) -> None:
|
||||
pass
|
||||
|
||||
async def close(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
async def _fake_shielded_session():
    # Replacement for shielded_async_session: hand out a fresh fake session
    # and remember it so tests can assert on per-operation commits.
    s = _FakeSession()
    _SESSIONS_USED.append(s)
    yield s


# Every fake session handed out by _fake_shielded_session, in order.
_SESSIONS_USED: list[_FakeSession] = []
|
||||
|
||||
|
||||
def _patch_isolation_layer(
    monkeypatch, *, reserve_result, finalize_result=None, finalize_exc=None
):
    """Wire fake reserve/finalize/release/session helpers.

    Returns a dict of spy lists (``reserve``/``finalize``/``release``/
    ``record``) so tests can assert exactly which quota ops ran and with
    what arguments. ``reserve_result`` is what the fake reserve returns;
    ``finalize_exc``, when set, makes finalize raise instead of recording.
    """
    # Fresh slate for the module-level session spy.
    _SESSIONS_USED.clear()
    reserve_calls: list[dict[str, Any]] = []
    finalize_calls: list[dict[str, Any]] = []
    release_calls: list[dict[str, Any]] = []

    async def _fake_reserve(*, db_session, user_id, request_id, reserve_micros):
        reserve_calls.append(
            {
                "user_id": user_id,
                "reserve_micros": reserve_micros,
                "request_id": request_id,
            }
        )
        return reserve_result

    async def _fake_finalize(
        *, db_session, user_id, request_id, actual_micros, reserved_micros
    ):
        # Optionally simulate a finalize failure before recording the call.
        if finalize_exc is not None:
            raise finalize_exc
        finalize_calls.append(
            {
                "user_id": user_id,
                "actual_micros": actual_micros,
                "reserved_micros": reserved_micros,
            }
        )
        return finalize_result or _FakeQuotaResult(allowed=True)

    async def _fake_release(*, db_session, user_id, reserved_micros):
        release_calls.append({"user_id": user_id, "reserved_micros": reserved_micros})

    record_calls: list[dict[str, Any]] = []

    async def _fake_record(session, **kwargs):
        # Audit-row writer spy; the real helper returns an ORM object.
        record_calls.append(kwargs)
        return object()

    # raising=False so patching survives attributes resolved only at runtime.
    monkeypatch.setattr(
        "app.services.billable_calls.TokenQuotaService.premium_reserve",
        _fake_reserve,
        raising=False,
    )
    monkeypatch.setattr(
        "app.services.billable_calls.TokenQuotaService.premium_finalize",
        _fake_finalize,
        raising=False,
    )
    monkeypatch.setattr(
        "app.services.billable_calls.TokenQuotaService.premium_release",
        _fake_release,
        raising=False,
    )
    monkeypatch.setattr(
        "app.services.billable_calls.shielded_async_session",
        _fake_shielded_session,
        raising=False,
    )
    monkeypatch.setattr(
        "app.services.billable_calls.record_token_usage",
        _fake_record,
        raising=False,
    )

    return {
        "reserve": reserve_calls,
        "finalize": finalize_calls,
        "release": release_calls,
        "record": record_calls,
    }
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_free_path_skips_reserve_but_writes_audit_row(monkeypatch):
    """Free tier: no reserve/finalize/release, but an audit row is still
    written with the observed cost."""
    from app.services.billable_calls import billable_call

    spies = _patch_isolation_layer(
        monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
    )

    async with billable_call(
        user_id=uuid4(),
        search_space_id=42,
        billing_tier="free",
        base_model="openai/gpt-image-1",
        usage_type="image_generation",
    ) as acc:
        # The LiteLLM callback normally feeds the accumulator; simulate a
        # captured cost by hand.
        acc.add(
            model="openai/gpt-image-1",
            prompt_tokens=0,
            completion_tokens=0,
            total_tokens=0,
            cost_micros=37_000,
            call_kind="image_generation",
        )

    # None of the quota ops ran for a free call...
    for op in ("reserve", "finalize", "release"):
        assert spies[op] == []
    # ...but the audit row was recorded nonetheless.
    assert len(spies["record"]) == 1
    audit_row = spies["record"][0]
    assert audit_row["usage_type"] == "image_generation"
    assert audit_row["cost_micros"] == 37_000
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_premium_reserve_denied_raises_quota_insufficient(monkeypatch):
    """A denied reserve surfaces as ``QuotaInsufficientError`` (HTTP 402 at
    the route layer); the managed body never runs and nothing is debited,
    released, or audited."""
    from app.services.billable_calls import (
        QuotaInsufficientError,
        billable_call,
    )

    denied = _FakeQuotaResult(
        allowed=False, used=5_000_000, limit=5_000_000, remaining=0
    )
    spies = _patch_isolation_layer(monkeypatch, reserve_result=denied)

    with pytest.raises(QuotaInsufficientError) as exc_info:
        async with billable_call(
            user_id=uuid4(),
            search_space_id=42,
            billing_tier="premium",
            base_model="openai/gpt-image-1",
            quota_reserve_micros_override=50_000,
            usage_type="image_generation",
        ):
            pytest.fail("body should not run when reserve is denied")

    err = exc_info.value
    assert err.usage_type == "image_generation"
    assert (err.used_micros, err.limit_micros, err.remaining_micros) == (
        5_000_000,
        5_000_000,
        0,
    )
    # Reserve was attempted, but no finalize/release on a denied reserve
    # — the reservation never actually held credit.
    assert len(spies["reserve"]) == 1
    assert spies["finalize"] == []
    assert spies["release"] == []
    # Denied premium calls do NOT create an audit row (no work happened).
    assert spies["record"] == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_premium_success_finalizes_with_actual_cost(monkeypatch):
    """Happy premium path: reserve the override amount, run the body, then
    finalize with the actual LiteLLM-reported cost and write an audit row.

    Also asserts session isolation: every quota op runs inside its own
    ``shielded_async_session`` rather than the caller's transaction.
    """
    from app.services.billable_calls import billable_call

    spies = _patch_isolation_layer(
        monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
    )
    user_id = uuid4()

    async with billable_call(
        user_id=user_id,
        search_space_id=42,
        billing_tier="premium",
        base_model="openai/gpt-image-1",
        quota_reserve_micros_override=50_000,
        usage_type="image_generation",
    ) as acc:
        # LiteLLM callback would normally fill this — simulate a $0.04 image.
        acc.add(
            model="openai/gpt-image-1",
            prompt_tokens=0,
            completion_tokens=0,
            total_tokens=0,
            cost_micros=40_000,
            call_kind="image_generation",
        )

    assert len(spies["reserve"]) == 1
    assert spies["reserve"][0]["reserve_micros"] == 50_000
    assert len(spies["finalize"]) == 1
    assert spies["finalize"][0]["actual_micros"] == 40_000
    assert spies["finalize"][0]["reserved_micros"] == 50_000
    assert spies["release"] == []
    # And audit row written with the actual debited cost.
    assert spies["record"][0]["cost_micros"] == 40_000
    # Each quota op opened its OWN session — proves session isolation.
    assert len(_SESSIONS_USED) >= 3
    # reserve/finalize go through stubbed TokenQuotaService.* methods that
    # never commit our fake sessions, but the audit write does commit —
    # so at least one of the isolated sessions must have committed.
    assert any(s.committed for s in _SESSIONS_USED)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_premium_failure_releases_reservation(monkeypatch):
    """A provider error raised inside the billable block must release the
    held reservation: reserve once, finalize never, release once at the
    reserved amount — and the original exception propagates to the caller.
    """
    from app.services.billable_calls import billable_call

    spies = _patch_isolation_layer(
        monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
    )
    user_id = uuid4()

    class _ProviderError(Exception):
        pass

    with pytest.raises(_ProviderError):
        async with billable_call(
            user_id=user_id,
            search_space_id=42,
            billing_tier="premium",
            base_model="openai/gpt-image-1",
            quota_reserve_micros_override=50_000,
            usage_type="image_generation",
        ):
            raise _ProviderError("OpenRouter 503")

    assert len(spies["reserve"]) == 1
    assert spies["finalize"] == []
    # Failure path: release the held reservation.
    assert len(spies["release"]) == 1
    assert spies["release"][0]["reserved_micros"] == 50_000
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_premium_uses_estimator_when_no_micros_override(monkeypatch):
    """When ``quota_reserve_micros_override`` is None we fall back to
    ``estimate_call_reserve_micros(base_model, quota_reserve_tokens)``.

    Vision LLM calls take this path (token-priced models).
    """
    from app.services.billable_calls import billable_call

    spies = _patch_isolation_layer(
        monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
    )

    # Capture the estimator's inputs so we can assert the fallback path
    # forwarded base_model / quota_reserve_tokens verbatim.
    captured_estimator_calls: list[dict[str, Any]] = []

    def _fake_estimate(*, base_model, quota_reserve_tokens):
        captured_estimator_calls.append(
            {"base_model": base_model, "quota_reserve_tokens": quota_reserve_tokens}
        )
        return 12_345

    monkeypatch.setattr(
        "app.services.billable_calls.estimate_call_reserve_micros",
        _fake_estimate,
        raising=False,
    )

    user_id = uuid4()
    async with billable_call(
        user_id=user_id,
        search_space_id=1,
        billing_tier="premium",
        base_model="openai/gpt-4o",
        quota_reserve_tokens=4000,
        usage_type="vision_extraction",
    ):
        pass

    assert captured_estimator_calls == [
        {"base_model": "openai/gpt-4o", "quota_reserve_tokens": 4000}
    ]
    # The reservation uses the estimator's output, not some default.
    assert spies["reserve"][0]["reserve_micros"] == 12_345
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_premium_finalize_failure_propagates_and_releases(monkeypatch):
    """If finalize itself blows up, the caller must see
    ``BillingSettlementError``, the reservation must be released, and no
    audit row may be written (nothing was settled)."""
    from app.services.billable_calls import BillingSettlementError, billable_call

    class _FinalizeError(RuntimeError):
        pass

    spies = _patch_isolation_layer(
        monkeypatch,
        reserve_result=_FakeQuotaResult(allowed=True),
        finalize_exc=_FinalizeError("db finalize failed"),
    )
    user_id = uuid4()

    with pytest.raises(BillingSettlementError):
        async with billable_call(
            user_id=user_id,
            search_space_id=42,
            billing_tier="premium",
            base_model="openai/gpt-image-1",
            quota_reserve_micros_override=50_000,
            usage_type="image_generation",
        ) as acc:
            acc.add(
                model="openai/gpt-image-1",
                prompt_tokens=0,
                completion_tokens=0,
                total_tokens=0,
                cost_micros=40_000,
                call_kind="image_generation",
            )

    assert len(spies["reserve"]) == 1
    assert len(spies["release"]) == 1
    # No audit row on settlement failure.
    assert spies["record"] == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_premium_audit_commit_hang_times_out_after_finalize(monkeypatch):
    """A hung audit commit must not wedge the billable block: with a tiny
    ``audit_timeout_seconds`` the call still reserves, finalizes and
    records, and never takes the release path."""
    from app.services.billable_calls import billable_call

    spies = _patch_isolation_layer(
        monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
    )
    user_id = uuid4()

    class _HangingCommitSession(_FakeSession):
        # Simulate a DB commit that never returns within the timeout.
        async def commit(self) -> None:
            await asyncio.sleep(60)

    @contextlib.asynccontextmanager
    async def _hanging_session_factory():
        s = _HangingCommitSession()
        _SESSIONS_USED.append(s)
        yield s

    async with billable_call(
        user_id=user_id,
        search_space_id=42,
        billing_tier="premium",
        base_model="openai/gpt-image-1",
        quota_reserve_micros_override=50_000,
        usage_type="image_generation",
        billable_session_factory=_hanging_session_factory,
        audit_timeout_seconds=0.01,
    ) as acc:
        acc.add(
            model="openai/gpt-image-1",
            prompt_tokens=0,
            completion_tokens=0,
            total_tokens=0,
            cost_micros=40_000,
            call_kind="image_generation",
        )

    # Settlement completed despite the hanging audit commit.
    assert len(spies["reserve"]) == 1
    assert len(spies["finalize"]) == 1
    assert len(spies["record"]) == 1
    assert spies["release"] == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_free_audit_failure_is_best_effort(monkeypatch):
    """Free-tier audit writes are best-effort: a failing
    ``record_token_usage`` must not raise out of the billable block, and
    the free path must never touch reserve/finalize."""
    from app.services.billable_calls import billable_call

    spies = _patch_isolation_layer(
        monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
    )

    async def _failing_record(_session, **_kwargs):
        raise RuntimeError("audit insert failed")

    monkeypatch.setattr(
        "app.services.billable_calls.record_token_usage",
        _failing_record,
        raising=False,
    )

    # Must complete without raising even though the audit insert fails.
    async with billable_call(
        user_id=uuid4(),
        search_space_id=42,
        billing_tier="free",
        base_model="openai/gpt-image-1",
        usage_type="image_generation",
        audit_timeout_seconds=0.01,
    ) as acc:
        acc.add(
            model="openai/gpt-image-1",
            prompt_tokens=0,
            completion_tokens=0,
            total_tokens=0,
            cost_micros=37_000,
            call_kind="image_generation",
        )

    # Free tier: no quota primitives involved at all.
    assert spies["reserve"] == []
    assert spies["finalize"] == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Podcast / video-presentation usage_type coverage
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_free_podcast_path_audits_with_podcast_usage_type(monkeypatch):
    """Free podcast configs must skip reserve/finalize but still emit a
    ``TokenUsage`` row tagged ``usage_type='podcast_generation'`` so we
    have full audit coverage of free-tier agent runs."""
    from app.services.billable_calls import billable_call

    spies = _patch_isolation_layer(
        monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
    )
    user_id = uuid4()

    async with billable_call(
        user_id=user_id,
        search_space_id=42,
        billing_tier="free",
        base_model="openrouter/some-free-model",
        quota_reserve_micros_override=200_000,
        usage_type="podcast_generation",
        thread_id=99,
        call_details={"podcast_id": 7, "title": "Test Podcast"},
    ) as acc:
        # Two transcript LLM calls aggregated into one accumulator.
        acc.add(
            model="openrouter/some-free-model",
            prompt_tokens=1500,
            completion_tokens=8000,
            total_tokens=9500,
            cost_micros=0,
            call_kind="chat",
        )

    # Free tier never touches quota primitives, even with an override set.
    assert spies["reserve"] == []
    assert spies["finalize"] == []
    assert spies["release"] == []

    assert len(spies["record"]) == 1
    row = spies["record"][0]
    assert row["usage_type"] == "podcast_generation"
    # NOTE(review): thread_id=99 is passed in above yet the audit row is
    # expected to carry None — presumably this path deliberately drops
    # thread linkage; confirm against billable_call's implementation.
    assert row["thread_id"] is None
    assert row["search_space_id"] == 42
    assert row["call_details"] == {"podcast_id": 7, "title": "Test Podcast"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_premium_video_denial_raises_quota_insufficient(monkeypatch):
    """Premium video-presentation runs that hit a denied reservation must
    raise ``QuotaInsufficientError`` *before* the graph runs and must not
    emit an audit row (no work happened)."""
    from app.services.billable_calls import (
        QuotaInsufficientError,
        billable_call,
    )

    spies = _patch_isolation_layer(
        monkeypatch,
        reserve_result=_FakeQuotaResult(
            allowed=False, used=4_500_000, limit=5_000_000, remaining=500_000
        ),
    )
    user_id = uuid4()

    with pytest.raises(QuotaInsufficientError) as exc_info:
        async with billable_call(
            user_id=user_id,
            search_space_id=42,
            billing_tier="premium",
            base_model="gpt-5.4",
            quota_reserve_micros_override=1_000_000,
            usage_type="video_presentation_generation",
            thread_id=99,
            call_details={"video_presentation_id": 12, "title": "Test Video"},
        ):
            pytest.fail("body should not run when reserve is denied")

    # The error carries enough context for the caller to message the user.
    err = exc_info.value
    assert err.usage_type == "video_presentation_generation"
    assert err.remaining_micros == 500_000
    assert spies["reserve"][0]["reserve_micros"] == 1_000_000
    assert spies["finalize"] == []
    assert spies["release"] == []
    assert spies["record"] == []
|
||||
|
|
@ -0,0 +1,177 @@
|
|||
"""Defense-in-depth: image-gen call sites must not let an empty
|
||||
``api_base`` fall through to LiteLLM's module-global ``litellm.api_base``.
|
||||
|
||||
The bug repro: an OpenRouter image-gen config ships
|
||||
``api_base=""``. The pre-fix call site in
|
||||
``image_generation_routes._execute_image_generation`` did
|
||||
``if cfg.get("api_base"): kwargs["api_base"] = cfg["api_base"]`` which
|
||||
silently dropped the empty string. LiteLLM then fell back to
|
||||
``litellm.api_base`` (commonly inherited from ``AZURE_OPENAI_ENDPOINT``)
|
||||
and OpenRouter's ``image_generation/transformation`` appended
|
||||
``/chat/completions`` to it → 404 ``Resource not found``.
|
||||
|
||||
This test pins the post-fix behaviour: with an empty ``api_base`` in
|
||||
the config, the call site MUST set ``api_base`` to OpenRouter's public
|
||||
URL instead of leaving it unset.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_global_openrouter_image_gen_sets_api_base_when_config_empty():
    """The global-config branch (``config_id < 0``) of
    ``_execute_image_generation`` must apply the resolver and pin
    ``api_base`` to OpenRouter when the config ships an empty string.
    """
    from app.routes import image_generation_routes

    # Minimal global image-gen config reproducing the bug shape.
    cfg = {
        "id": -20_001,
        "name": "GPT Image 1 (OpenRouter)",
        "provider": "OPENROUTER",
        "model_name": "openai/gpt-image-1",
        "api_key": "sk-or-test",
        "api_base": "",  # the original bug shape
        "api_version": None,
        "litellm_params": {},
    }

    # Captures the kwargs the route forwards to LiteLLM.
    captured: dict = {}

    async def fake_aimage_generation(**kwargs):
        captured.update(kwargs)
        return MagicMock(model_dump=lambda: {"data": []}, _hidden_params={})

    # Minimal request/search-space doubles — only the attributes the
    # route actually reads.
    image_gen = MagicMock()
    image_gen.image_generation_config_id = cfg["id"]
    image_gen.prompt = "test"
    image_gen.n = 1
    image_gen.quality = None
    image_gen.size = None
    image_gen.style = None
    image_gen.response_format = None
    image_gen.model = None

    search_space = MagicMock()
    search_space.image_generation_config_id = cfg["id"]
    session = MagicMock()

    with (
        patch.object(
            image_generation_routes,
            "_get_global_image_gen_config",
            return_value=cfg,
        ),
        patch.object(
            image_generation_routes,
            "aimage_generation",
            side_effect=fake_aimage_generation,
        ),
    ):
        await image_generation_routes._execute_image_generation(
            session=session, image_gen=image_gen, search_space=search_space
        )

    # The whole point of the fix: even with empty ``api_base`` in the
    # config, we forward OpenRouter's public URL so the call doesn't
    # inherit an Azure endpoint.
    assert captured.get("api_base") == "https://openrouter.ai/api/v1"
    assert captured["model"] == "openrouter/openai/gpt-image-1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_generate_image_tool_global_sets_api_base_when_config_empty():
    """Same defense at the agent tool entry point — both surfaces share
    the same OpenRouter config payloads."""
    from app.agents.new_chat.tools import generate_image as gi_module

    cfg = {
        "id": -20_001,
        "name": "GPT Image 1 (OpenRouter)",
        "provider": "OPENROUTER",
        "model_name": "openai/gpt-image-1",
        "api_key": "sk-or-test",
        "api_base": "",
        "api_version": None,
        "litellm_params": {},
    }

    captured: dict = {}

    async def fake_aimage_generation(**kwargs):
        captured.update(kwargs)
        response = MagicMock()
        response.model_dump.return_value = {
            "data": [{"url": "https://example.com/x.png"}]
        }
        response._hidden_params = {"model": "openrouter/openai/gpt-image-1"}
        return response

    search_space = MagicMock()
    search_space.id = 1
    search_space.image_generation_config_id = cfg["id"]

    # Async-session plumbing: the tool opens shielded_async_session() and
    # expects execute()/scalars()/first() to resolve the search space.
    session_cm = AsyncMock()
    session = AsyncMock()
    session_cm.__aenter__.return_value = session

    scalars = MagicMock()
    scalars.first.return_value = search_space
    exec_result = MagicMock()
    exec_result.scalars.return_value = scalars
    session.execute.return_value = exec_result
    session.add = MagicMock()
    session.commit = AsyncMock()
    session.refresh = AsyncMock()

    # ``refresh(db_image_gen)`` needs to populate ``id`` for token URL fallback.
    async def _refresh(obj):
        obj.id = 1

    session.refresh.side_effect = _refresh

    with (
        patch.object(gi_module, "shielded_async_session", return_value=session_cm),
        patch.object(gi_module, "_get_global_image_gen_config", return_value=cfg),
        patch.object(
            gi_module, "aimage_generation", side_effect=fake_aimage_generation
        ),
        patch.object(
            gi_module, "is_image_gen_auto_mode", side_effect=lambda cid: cid == 0
        ),
    ):
        tool = gi_module.create_generate_image_tool(
            search_space_id=1, db_session=MagicMock()
        )
        await tool.ainvoke({"prompt": "a cat", "n": 1})

    # Same invariant as the route-level test: OpenRouter URL forwarded.
    assert captured.get("api_base") == "https://openrouter.ai/api/v1"
    assert captured["model"] == "openrouter/openai/gpt-image-1"
|
||||
|
||||
|
||||
def test_image_gen_router_deployment_sets_api_base_when_config_empty():
    """The Auto-mode router pool must also resolve ``api_base`` when an
    OpenRouter config ships an empty string. The deployment dict is fed
    straight to ``litellm.Router``, so a missing ``api_base`` would
    leak the same way as the direct call sites.
    """
    from app.services.image_gen_router_service import ImageGenRouterService

    empty_base_cfg = {
        "model_name": "openai/gpt-image-1",
        "provider": "OPENROUTER",
        "api_key": "sk-or-test",
        "api_base": "",
    }
    deployment = ImageGenRouterService._config_to_deployment(empty_base_cfg)

    assert deployment is not None
    params = deployment["litellm_params"]
    assert params["api_base"] == "https://openrouter.ai/api/v1"
    assert params["model"] == "openrouter/openai/gpt-image-1"
|
||||
|
|
@ -214,3 +214,167 @@ def test_generate_configs_drops_non_text_and_non_tool_models():
|
|||
assert "openai/gpt-4o" in model_names
|
||||
assert "openai/dall-e" not in model_names
|
||||
assert "openai/completion-only" not in model_names
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _generate_image_gen_configs / _generate_vision_llm_configs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_generate_image_gen_configs_filters_by_image_output():
    """Only models with ``output_modalities`` containing ``image`` are emitted.
    Tool-calling and context filters are intentionally NOT applied — image
    generation has nothing to do with tool calls and context windows.
    """
    from app.services.openrouter_integration_service import (
        _generate_image_gen_configs,
    )

    raw = [
        # Pure image-gen model (small context, no tools — should still emit).
        {
            "id": "openai/gpt-image-1",
            "architecture": {"output_modalities": ["image"]},
            "context_length": 4_000,
            "pricing": {"prompt": "0", "completion": "0"},
        },
        # Multi-modal: text+image output (should still emit).
        {
            "id": "google/gemini-2.5-flash-image",
            "architecture": {"output_modalities": ["text", "image"]},
            "context_length": 1_000_000,
            "pricing": {"prompt": "0.000001", "completion": "0.000004"},
        },
        # Pure text model — must NOT emit.
        {
            "id": "openai/gpt-4o",
            "architecture": {"output_modalities": ["text"]},
            "context_length": 128_000,
            "pricing": {"prompt": "0.000005", "completion": "0.000015"},
        },
    ]

    cfgs = _generate_image_gen_configs(raw, dict(_SETTINGS_BASE))
    model_names = {c["model_name"] for c in cfgs}
    assert "openai/gpt-image-1" in model_names
    assert "google/gemini-2.5-flash-image" in model_names
    assert "openai/gpt-4o" not in model_names

    # Each config must carry ``billing_tier`` for routing in image_generation_routes.
    for c in cfgs:
        assert c["billing_tier"] in {"free", "premium"}
        assert c["provider"] == "OPENROUTER"
        assert c[_OPENROUTER_DYNAMIC_MARKER] is True
        # Defense-in-depth: emit the OpenRouter base URL at source so a
        # downstream call site that forgets ``resolve_api_base`` still
        # doesn't 404 against an inherited Azure endpoint.
        assert c["api_base"] == "https://openrouter.ai/api/v1"
|
||||
|
||||
|
||||
def test_generate_image_gen_configs_assigns_image_id_offset():
    """Image configs use a different id_offset (-20000) so their negative
    IDs don't collide with chat configs (-10000) or vision configs (-30000).
    """
    from app.services.openrouter_integration_service import (
        _generate_image_gen_configs,
    )

    raw = [
        {
            "id": "openai/gpt-image-1",
            "architecture": {"output_modalities": ["image"]},
            "context_length": 4_000,
            "pricing": {"prompt": "0", "completion": "0"},
        }
    ]
    # Don't pass image_id_offset → use the module default (-20000).
    cfgs = _generate_image_gen_configs(raw, dict(_SETTINGS_BASE))
    for cfg in cfgs:
        # Every id sits in the image band: at or below -20000 but well
        # above the next band's floor.
        assert cfg["id"] <= -20_000
        assert cfg["id"] > -29_000_000
|
||||
|
||||
|
||||
def test_generate_vision_llm_configs_filters_by_image_input_text_output():
    """Vision LLMs must accept image input AND emit text — pure image-gen
    (no text out) and text-only (no image in) models are excluded.
    """
    from app.services.openrouter_integration_service import (
        _generate_vision_llm_configs,
    )

    raw = [
        # GPT-4o: vision LLM (image in, text out) — must emit.
        {
            "id": "openai/gpt-4o",
            "architecture": {
                "input_modalities": ["text", "image"],
                "output_modalities": ["text"],
            },
            "context_length": 128_000,
            "pricing": {"prompt": "0.000005", "completion": "0.000015"},
        },
        # Pure image generator — image *output*, no text out. Must NOT emit.
        {
            "id": "openai/gpt-image-1",
            "architecture": {
                "input_modalities": ["text"],
                "output_modalities": ["image"],
            },
            "context_length": 4_000,
            "pricing": {"prompt": "0", "completion": "0"},
        },
        # Pure text model (no image in). Must NOT emit.
        {
            "id": "anthropic/claude-3-haiku",
            "architecture": {
                "input_modalities": ["text"],
                "output_modalities": ["text"],
            },
            "context_length": 200_000,
            "pricing": {"prompt": "0.000001", "completion": "0.000005"},
        },
    ]

    cfgs = _generate_vision_llm_configs(raw, dict(_SETTINGS_BASE))
    names = {c["model_name"] for c in cfgs}
    assert names == {"openai/gpt-4o"}

    cfg = cfgs[0]
    assert cfg["billing_tier"] == "premium"
    # Pricing carried inline so pricing_registration can register vision
    # under ``openrouter/openai/gpt-4o`` even if the chat catalogue cache
    # is cleared.
    assert cfg["input_cost_per_token"] == pytest.approx(5e-6)
    assert cfg["output_cost_per_token"] == pytest.approx(15e-6)
    assert cfg[_OPENROUTER_DYNAMIC_MARKER] is True
    # Defense-in-depth: emit the OpenRouter base URL at source so a
    # downstream call site that forgets ``resolve_api_base`` still
    # doesn't inherit an Azure endpoint.
    assert cfg["api_base"] == "https://openrouter.ai/api/v1"
|
||||
|
||||
|
||||
def test_generate_vision_llm_configs_drops_chat_only_filters():
    """A small-context vision model that doesn't advertise tool calling is
    still a valid vision LLM for "describe this image" prompts. The chat
    filters (``supports_tool_calling``, ``has_sufficient_context``) must
    NOT be applied to vision emission.
    """
    from app.services.openrouter_integration_service import (
        _generate_vision_llm_configs,
    )

    tiny_vision_model = {
        "id": "tiny/vision-mini",
        "architecture": {
            "input_modalities": ["text", "image"],
            "output_modalities": ["text"],
        },
        "supported_parameters": [],  # no tools
        "context_length": 4_000,  # well below MIN_CONTEXT_LENGTH
        "pricing": {"prompt": "0.0000001", "completion": "0.0000005"},
    }

    cfgs = _generate_vision_llm_configs([tiny_vision_model], dict(_SETTINGS_BASE))

    # The model survives emission despite failing every chat-only filter.
    assert [c["model_name"] for c in cfgs] == ["tiny/vision-mini"]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,447 @@
|
|||
"""Pricing registration unit tests.
|
||||
|
||||
The pricing-registration module is what makes ``response_cost`` populate
|
||||
correctly for OpenRouter dynamic models and operator-defined Azure
|
||||
deployments — both of which LiteLLM doesn't natively know about. The tests
|
||||
exercise:
|
||||
|
||||
* The alias generators emit every shape that LiteLLM's cost-callback might
|
||||
use (``openrouter/X`` and bare ``X``; YAML-defined ``base_model``,
|
||||
``provider/base_model``, ``provider/model_name``, plus the special
|
||||
``azure_openai`` → ``azure`` normalisation).
|
||||
* ``register_pricing_from_global_configs`` calls ``litellm.register_model``
|
||||
with the right alias set and pricing values per provider.
|
||||
* Configs without a resolvable pair of cost values are skipped — never
|
||||
registered as zero, since that would override pricing LiteLLM might
|
||||
already know natively.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Alias generators
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_openrouter_alias_set_includes_prefixed_and_bare():
    """The OpenRouter alias generator emits the ``openrouter/``-prefixed
    key first, then the bare model id — both shapes surface in LiteLLM's
    cost callbacks."""
    from app.services.pricing_registration import _alias_set_for_openrouter

    model_id = "anthropic/claude-3-5-sonnet"
    aliases = _alias_set_for_openrouter(model_id)
    assert aliases == [f"openrouter/{model_id}", model_id]
|
||||
|
||||
|
||||
def test_openrouter_alias_set_dedupes():
    """If the model id is already prefixed with ``openrouter/``, the alias
    set must not contain duplicates that would re-register the same key
    twice.
    """
    from app.services.pricing_registration import _alias_set_for_openrouter

    aliases = _alias_set_for_openrouter("openrouter/foo")
    # Prefixed and bare variants collapse to the same string here, so the
    # generator must yield each alias at most once.
    seen: set = set()
    for alias in aliases:
        assert alias not in seen
        seen.add(alias)
|
||||
|
||||
|
||||
def test_yaml_alias_set_for_azure_openai_normalises_to_azure():
    """``azure_openai`` (our YAML provider slug) must register under
    ``azure/<name>`` so the LiteLLM Router's deployment-resolution path
    (which uses provider ``azure``) finds the pricing too.
    """
    from app.services.pricing_registration import _alias_set_for_yaml

    aliases = _alias_set_for_yaml(
        provider="AZURE_OPENAI", model_name="gpt-5.4", base_model="gpt-5.4"
    )
    # Bare name, our slug, and LiteLLM's normalised slug must all appear.
    required = {"gpt-5.4", "azure_openai/gpt-5.4", "azure/gpt-5.4"}
    assert required.issubset(set(aliases))
|
||||
|
||||
|
||||
def test_yaml_alias_set_distinguishes_model_name_and_base_model():
    """When ``model_name`` differs from ``base_model`` (operator labelled a
    deployment), both must appear in the alias set since either may surface
    in callbacks depending on the call path.
    """
    from app.services.pricing_registration import _alias_set_for_yaml

    alias_set = set(
        _alias_set_for_yaml(
            provider="OPENAI",
            model_name="my-deployment-label",
            base_model="gpt-4o",
        )
    )
    for expected in (
        "gpt-4o",
        "openai/gpt-4o",
        "my-deployment-label",
        "openai/my-deployment-label",
    ):
        assert expected in alias_set
|
||||
|
||||
|
||||
def test_yaml_alias_set_omits_provider_prefix_when_provider_blank():
    """A blank provider slug yields only bare aliases — never a
    ``<provider>/``-prefixed variant."""
    from app.services.pricing_registration import _alias_set_for_yaml

    aliases = _alias_set_for_yaml(provider="", model_name="foo", base_model="bar")
    assert {"bar", "foo"} <= set(aliases)
    for alias in aliases:
        assert "/" not in alias
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# register_pricing_from_global_configs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _RegistrationSpy:
|
||||
"""Captures the dicts passed to ``litellm.register_model``.
|
||||
|
||||
Many calls may go through; we just record them all and let tests assert
|
||||
against the union.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.calls: list[dict[str, Any]] = []
|
||||
|
||||
def __call__(self, payload: dict[str, Any]) -> None:
|
||||
self.calls.append(payload)
|
||||
|
||||
@property
|
||||
def all_keys(self) -> set[str]:
|
||||
keys: set[str] = set()
|
||||
for payload in self.calls:
|
||||
keys.update(payload.keys())
|
||||
return keys
|
||||
|
||||
|
||||
def _patch_register(monkeypatch: pytest.MonkeyPatch) -> _RegistrationSpy:
    """Swap ``litellm.register_model`` (as reached through the pricing
    module) for a recording spy and return the spy."""
    spy = _RegistrationSpy()
    # Patch via the module-under-test's dotted path so the spy intercepts
    # exactly what pricing_registration calls at runtime.
    monkeypatch.setattr(
        "app.services.pricing_registration.litellm.register_model",
        spy,
        raising=False,
    )
    return spy
|
||||
|
||||
|
||||
def _patch_openrouter_pricing(
    monkeypatch: pytest.MonkeyPatch, mapping: dict[str, dict[str, str]]
) -> None:
    """Pretend the OpenRouter integration is initialised with ``mapping``."""

    class _Stub:
        # Mimics the initialised integration's raw-pricing accessor.
        def get_raw_pricing(self) -> dict[str, dict[str, str]]:
            return mapping

    class _StubService:
        @classmethod
        def is_initialized(cls) -> bool:
            return True

        @classmethod
        def get_instance(cls) -> _Stub:
            return _Stub()

    # Replace the service class pricing_registration looks up at call time.
    monkeypatch.setattr(
        "app.services.openrouter_integration_service.OpenRouterIntegrationService",
        _StubService,
        raising=False,
    )
|
||||
|
||||
|
||||
def test_openrouter_models_register_under_aliases(monkeypatch):
    """An OpenRouter config whose ``model_name`` is in the cached raw
    pricing map is registered under both ``openrouter/X`` and bare ``X``.
    """
    from app.config import config
    from app.services.pricing_registration import register_pricing_from_global_configs

    spy = _patch_register(monkeypatch)
    _patch_openrouter_pricing(
        monkeypatch,
        {
            "anthropic/claude-3-5-sonnet": {
                "prompt": "0.000003",
                "completion": "0.000015",
            }
        },
    )

    monkeypatch.setattr(
        config,
        "GLOBAL_LLM_CONFIGS",
        [
            {
                "id": 1,
                "provider": "OPENROUTER",
                "model_name": "anthropic/claude-3-5-sonnet",
            }
        ],
    )

    register_pricing_from_global_configs()

    # Both alias shapes must be registered (see the alias-set tests above).
    assert "openrouter/anthropic/claude-3-5-sonnet" in spy.all_keys
    assert "anthropic/claude-3-5-sonnet" in spy.all_keys
    # Costs are float-converted from the raw OpenRouter strings.
    payload = spy.calls[0]
    assert payload["openrouter/anthropic/claude-3-5-sonnet"][
        "input_cost_per_token"
    ] == pytest.approx(3e-6)
    assert payload["openrouter/anthropic/claude-3-5-sonnet"][
        "output_cost_per_token"
    ] == pytest.approx(15e-6)
    assert (
        payload["openrouter/anthropic/claude-3-5-sonnet"]["litellm_provider"]
        == "openrouter"
    )
|
||||
|
||||
|
||||
def test_yaml_override_registers_under_alias_set(monkeypatch):
    """Operator-declared ``input_cost_per_token`` /
    ``output_cost_per_token`` on a YAML config registers under every
    alias the YAML alias generator produces — including the ``azure/``
    normalisation for ``azure_openai`` providers.
    """
    from app.config import config
    from app.services.pricing_registration import register_pricing_from_global_configs

    spy = _patch_register(monkeypatch)
    # No OpenRouter pricing involved — the override comes from the YAML.
    _patch_openrouter_pricing(monkeypatch, {})

    monkeypatch.setattr(
        config,
        "GLOBAL_LLM_CONFIGS",
        [
            {
                "id": 1,
                "provider": "AZURE_OPENAI",
                "model_name": "gpt-5.4",
                "litellm_params": {
                    "base_model": "gpt-5.4",
                    "input_cost_per_token": 2e-6,
                    "output_cost_per_token": 8e-6,
                },
            }
        ],
    )

    register_pricing_from_global_configs()

    keys = spy.all_keys
    assert "gpt-5.4" in keys
    assert "azure_openai/gpt-5.4" in keys
    assert "azure/gpt-5.4" in keys

    # The registered entry carries the YAML-declared costs and the
    # normalised ``azure`` provider slug.
    payload = spy.calls[0]
    entry = payload["gpt-5.4"]
    assert entry["input_cost_per_token"] == pytest.approx(2e-6)
    assert entry["output_cost_per_token"] == pytest.approx(8e-6)
    assert entry["litellm_provider"] == "azure"
|
||||
|
||||
|
||||
def test_no_override_means_no_registration(monkeypatch):
    """A YAML config that *omits* both pricing fields must NOT be registered
    — registering as zero would override LiteLLM's native pricing for the
    ``base_model`` key (e.g. ``gpt-4o``) and silently make every user's
    bill drop to $0. Fail-safe is "skip and warn", not "register zero".
    """
    from app.config import config
    from app.services.pricing_registration import register_pricing_from_global_configs

    spy = _patch_register(monkeypatch)
    _patch_openrouter_pricing(monkeypatch, {})

    monkeypatch.setattr(
        config,
        "GLOBAL_LLM_CONFIGS",
        [
            {
                "id": 1,
                "provider": "OPENAI",
                "model_name": "gpt-4o",
                # base_model present but no cost overrides → skip.
                "litellm_params": {"base_model": "gpt-4o"},
            }
        ],
    )

    register_pricing_from_global_configs()

    assert spy.calls == []
|
||||
|
||||
|
||||
def test_openrouter_skipped_when_pricing_missing(monkeypatch):
    """If the OpenRouter raw-pricing cache doesn't carry an entry for a
    configured model (network blip during refresh, model added later, etc.),
    we skip it rather than registering zero pricing.
    """
    from app.config import config
    from app.services.pricing_registration import register_pricing_from_global_configs

    register_spy = _patch_register(monkeypatch)
    # The cache knows about a *different* model only.
    _patch_openrouter_pricing(
        monkeypatch, {"some/other-model": {"prompt": "1", "completion": "1"}}
    )

    monkeypatch.setattr(
        config,
        "GLOBAL_LLM_CONFIGS",
        [
            {
                "id": 1,
                "provider": "OPENROUTER",
                "model_name": "anthropic/claude-3-5-sonnet",
            }
        ],
    )

    register_pricing_from_global_configs()

    # The configured model had no cache entry, so nothing was registered.
    assert register_spy.calls == []
||||
def test_register_continues_after_individual_failure(monkeypatch, caplog):
    """A single bad ``register_model`` call (e.g. raising LiteLLM error)
    must not abort registration of the remaining configs.
    """
    from app.config import config
    from app.services.pricing_registration import register_pricing_from_global_configs

    failing_keys: set[str] = {"anthropic/claude-3-5-sonnet"}
    successful_calls: list[dict[str, Any]] = []

    def _maybe_fail(payload: dict[str, Any]) -> None:
        # Blow up only when the payload touches a key in failing_keys;
        # everything else records normally.
        if any(k in failing_keys for k in payload):
            raise RuntimeError("boom")
        successful_calls.append(payload)

    monkeypatch.setattr(
        "app.services.pricing_registration.litellm.register_model",
        _maybe_fail,
        raising=False,
    )
    _patch_openrouter_pricing(
        monkeypatch,
        {
            "anthropic/claude-3-5-sonnet": {
                "prompt": "0.000003",
                "completion": "0.000015",
            }
        },
    )

    # First config triggers the failure; second is a healthy inline-priced one.
    monkeypatch.setattr(
        config,
        "GLOBAL_LLM_CONFIGS",
        [
            {
                "id": 1,
                "provider": "OPENROUTER",
                "model_name": "anthropic/claude-3-5-sonnet",
            },
            {
                "id": 2,
                "provider": "OPENAI",
                "model_name": "custom-deployment",
                "litellm_params": {
                    "base_model": "custom-deployment",
                    "input_cost_per_token": 1e-6,
                    "output_cost_per_token": 2e-6,
                },
            },
        ],
    )

    register_pricing_from_global_configs()

    # The good config still registered.
    assert any("custom-deployment" in payload for payload in successful_calls)
||||
def test_vision_configs_registered_with_chat_shape(monkeypatch):
    """``register_pricing_from_global_configs`` walks
    ``GLOBAL_VISION_LLM_CONFIGS`` in addition to the chat configs so vision
    calls (during indexing) bill correctly. Vision configs use the same
    chat-shape token prices, but image-gen pricing is intentionally NOT
    registered here (handled via ``response_cost`` in LiteLLM).
    """
    from app.config import config
    from app.services.pricing_registration import register_pricing_from_global_configs

    register_spy = _patch_register(monkeypatch)
    _patch_openrouter_pricing(
        monkeypatch,
        {"openai/gpt-4o": {"prompt": "0.000005", "completion": "0.000015"}},
    )

    # No chat configs — only vision. Proves the vision walk is a separate
    # iteration, not piggy-backed on the chat list.
    monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [])
    monkeypatch.setattr(
        config,
        "GLOBAL_VISION_LLM_CONFIGS",
        [
            {
                "id": -1,
                "provider": "OPENROUTER",
                "model_name": "openai/gpt-4o",
                "billing_tier": "premium",
                "input_cost_per_token": 5e-6,
                "output_cost_per_token": 15e-6,
            }
        ],
    )

    register_pricing_from_global_configs()

    assert "openrouter/openai/gpt-4o" in register_spy.all_keys
    registered_entry = register_spy.calls[0]["openrouter/openai/gpt-4o"]
    assert registered_entry["mode"] == "chat"
    assert registered_entry["litellm_provider"] == "openrouter"
    assert registered_entry["input_cost_per_token"] == pytest.approx(5e-6)
    assert registered_entry["output_cost_per_token"] == pytest.approx(15e-6)
||||
def test_vision_with_inline_pricing_when_or_cache_missing(monkeypatch):
    """If the OpenRouter pricing cache misses a vision model (different
    catalogue surface), the vision walk falls back to inline
    ``input_cost_per_token``/``output_cost_per_token`` on the cfg itself.
    """
    from app.config import config
    from app.services.pricing_registration import register_pricing_from_global_configs

    register_spy = _patch_register(monkeypatch)
    # Empty cache → the OpenRouter lookup cannot succeed.
    _patch_openrouter_pricing(monkeypatch, {})

    monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [])
    monkeypatch.setattr(
        config,
        "GLOBAL_VISION_LLM_CONFIGS",
        [
            {
                "id": -1,
                "provider": "OPENROUTER",
                "model_name": "google/gemini-2.5-flash",
                "billing_tier": "premium",
                "input_cost_per_token": 1e-6,
                "output_cost_per_token": 4e-6,
            }
        ],
    )

    register_pricing_from_global_configs()

    # Inline pricing carried the registration through.
    assert "openrouter/google/gemini-2.5-flash" in register_spy.all_keys
107
surfsense_backend/tests/unit/services/test_provider_api_base.py
Normal file
107
surfsense_backend/tests/unit/services/test_provider_api_base.py
Normal file
|
|
@ -0,0 +1,107 @@
|
|||
"""Unit tests for the shared ``api_base`` resolver.
|
||||
|
||||
The cascade exists so vision and image-gen call sites can't silently
|
||||
inherit ``litellm.api_base`` (commonly set by ``AZURE_OPENAI_ENDPOINT``)
|
||||
when an OpenRouter / Groq / etc. config ships an empty string. See
|
||||
``provider_api_base`` module docstring for the original repro
|
||||
(OpenRouter image-gen 404-ing against an Azure endpoint).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.provider_api_base import (
|
||||
PROVIDER_DEFAULT_API_BASE,
|
||||
PROVIDER_KEY_DEFAULT_API_BASE,
|
||||
resolve_api_base,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def test_config_value_wins_over_defaults():
    """A non-empty config value is always returned verbatim, even when the
    provider has a default — the operator gets the last word."""
    operator_url = "https://my-openrouter-mirror.example.com/v1"
    resolved = resolve_api_base(
        provider="OPENROUTER",
        provider_prefix="openrouter",
        config_api_base=operator_url,
    )
    assert resolved == operator_url
||||
def test_provider_key_default_when_config_missing():
    """``DEEPSEEK`` shares the ``openai`` LiteLLM prefix but has its own
    base URL — the provider-key map must take precedence over the prefix
    map so DeepSeek requests don't go to OpenAI."""
    resolved = resolve_api_base(
        provider="DEEPSEEK",
        provider_prefix="openai",
        config_api_base=None,
    )
    assert resolved == PROVIDER_KEY_DEFAULT_API_BASE["DEEPSEEK"]
||||
def test_provider_prefix_default_when_no_key_default():
    """With no provider-key override, the prefix map supplies the base."""
    resolved = resolve_api_base(
        provider="OPENROUTER",
        provider_prefix="openrouter",
        config_api_base=None,
    )
    assert resolved == PROVIDER_DEFAULT_API_BASE["openrouter"]
||||
def test_unknown_provider_returns_none():
    """When neither map matches we return ``None`` so the caller can let
    LiteLLM apply its own provider-integration default (Azure deployment
    URL, custom-provider URL, etc.)."""
    resolved = resolve_api_base(
        provider="SOMETHING_NEW",
        provider_prefix="something_new",
        config_api_base=None,
    )
    assert resolved is None
||||
def test_empty_string_config_treated_as_missing():
    """The original bug: OpenRouter dynamic configs ship ``api_base=""``
    and downstream call sites use ``if cfg.get("api_base"):`` — empty
    strings are falsy in Python but the cascade has to step in anyway."""
    resolved = resolve_api_base(
        provider="OPENROUTER",
        provider_prefix="openrouter",
        config_api_base="",
    )
    assert resolved == PROVIDER_DEFAULT_API_BASE["openrouter"]
||||
def test_whitespace_only_config_treated_as_missing():
    """A config value of ``" "`` is a configuration mistake — treat it
    as missing instead of forwarding whitespace to LiteLLM (which would
    almost certainly 404)."""
    resolved = resolve_api_base(
        provider="OPENROUTER",
        provider_prefix="openrouter",
        config_api_base="   ",
    )
    assert resolved == PROVIDER_DEFAULT_API_BASE["openrouter"]
||||
def test_provider_case_insensitive():
    """Some call sites pass the provider lowercase (DB enum value), others
    uppercase (YAML key). Both must resolve."""
    resolved_per_case = [
        resolve_api_base(
            provider=spelling, provider_prefix="openai", config_api_base=None
        )
        for spelling in ("DEEPSEEK", "deepseek")
    ]
    expected = PROVIDER_KEY_DEFAULT_API_BASE["DEEPSEEK"]
    assert resolved_per_case == [expected, expected]
||||
def test_all_inputs_none_returns_none():
    """All-None inputs resolve to None — nothing to cascade from."""
    resolved = resolve_api_base(
        provider=None, provider_prefix=None, config_api_base=None
    )
    assert resolved is None
|
||||
|
|
@ -0,0 +1,244 @@
|
|||
"""Unit tests for the shared chat-image capability resolver.
|
||||
|
||||
Two resolvers, two intents:
|
||||
|
||||
- ``derive_supports_image_input`` — best-effort True for the catalog and
|
||||
selector. Default-allow on unknown / unmapped models. The streaming
|
||||
task safety net never sees this value directly.
|
||||
|
||||
- ``is_known_text_only_chat_model`` — strict opt-out for the safety net.
|
||||
Returns True only when LiteLLM's model map *explicitly* sets
|
||||
``supports_vision=False``. Anything else (missing key, exception,
|
||||
True) returns False so the request flows through to the provider.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.provider_capabilities import (
|
||||
derive_supports_image_input,
|
||||
is_known_text_only_chat_model,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# derive_supports_image_input — OpenRouter modalities path (authoritative)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_or_modalities_with_image_returns_true():
    """OpenRouter modalities listing "image" are authoritative → True."""
    verdict = derive_supports_image_input(
        provider="OPENROUTER",
        model_name="openai/gpt-4o",
        openrouter_input_modalities=["text", "image"],
    )
    assert verdict is True
||||
def test_or_modalities_text_only_returns_false():
    """OpenRouter modalities without "image" are authoritative → False."""
    verdict = derive_supports_image_input(
        provider="OPENROUTER",
        model_name="deepseek/deepseek-v3.2-exp",
        openrouter_input_modalities=["text"],
    )
    assert verdict is False
||||
def test_or_modalities_empty_list_returns_false():
    """OR explicitly publishing an empty modality list is a definitive
    'no inputs at all' signal — treat as False rather than falling back
    to LiteLLM."""
    verdict = derive_supports_image_input(
        provider="OPENROUTER",
        model_name="weird/empty-modalities",
        openrouter_input_modalities=[],
    )
    assert verdict is False
|
||||
def test_or_modalities_none_falls_through_to_litellm():
    """``None`` (missing key) is *not* a definitive signal — fall through
    to LiteLLM. Using ``openai/gpt-4o`` which is in LiteLLM's map."""
    verdict = derive_supports_image_input(
        provider="OPENAI",
        model_name="gpt-4o",
        openrouter_input_modalities=None,
    )
    assert verdict is True
|
||||
# ---------------------------------------------------------------------------
|
||||
# derive_supports_image_input — LiteLLM model-map path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_litellm_known_vision_model_returns_true():
    """LiteLLM's map marks gpt-4o as vision-capable → True."""
    verdict = derive_supports_image_input(
        provider="OPENAI",
        model_name="gpt-4o",
    )
    assert verdict is True
|
||||
def test_litellm_base_model_wins_over_model_name():
    """Azure-style entries pass model_name=deployment_id and put the
    canonical sku in litellm_params.base_model. The resolver must
    consult base_model first or the deployment id (which LiteLLM
    doesn't know) would shadow the real capability."""
    verdict = derive_supports_image_input(
        provider="AZURE_OPENAI",
        model_name="my-azure-deployment-id",
        base_model="gpt-4o",
    )
    assert verdict is True
||||
|
||||
def test_litellm_unknown_model_default_allows():
    """Default-allow on unknown — the safety net is the actual block."""
    verdict = derive_supports_image_input(
        provider="CUSTOM",
        model_name="brand-new-model-x9-unmapped",
        custom_provider="brand_new_proxy",
    )
    assert verdict is True
||||
|
||||
def test_litellm_known_text_only_returns_false():
    """A model that LiteLLM explicitly knows is text-only resolves to
    False even via the catalog resolver. ``deepseek-chat`` (the
    DeepSeek-V3 chat sku) is in the map without supports_vision and
    LiteLLM's `supports_vision` returns False."""
    # Sanity: confirm the helper's negative path. We use a small model
    # known not to support vision per the map.
    verdict = derive_supports_image_input(
        provider="DEEPSEEK",
        model_name="deepseek-chat",
    )
    # We accept either False (LiteLLM said explicit no) or True
    # (default-allow if the entry isn't mapped on this version) — the
    # invariant is that the resolver never *raises* on a known-text-only
    # provider/model. The behaviour-binding assertion lives in
    # ``test_is_known_text_only_chat_model_explicit_false`` below.
    assert isinstance(verdict, bool)
|
||||
# ---------------------------------------------------------------------------
|
||||
# is_known_text_only_chat_model — strict opt-out semantics
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_is_known_text_only_returns_false_for_vision_model():
    """A mapped vision model must never trip the text-only safety net."""
    verdict = is_known_text_only_chat_model(
        provider="OPENAI",
        model_name="gpt-4o",
    )
    assert verdict is False
||||
|
||||
def test_is_known_text_only_returns_false_for_unknown_model():
    """Strict opt-out: missing from the map ≠ text-only. The safety net
    must NOT fire for an unmapped model — that's the regression we're
    fixing."""
    verdict = is_known_text_only_chat_model(
        provider="CUSTOM",
        model_name="brand-new-model-x9-unmapped",
        custom_provider="brand_new_proxy",
    )
    assert verdict is False
||||
|
||||
def test_is_known_text_only_returns_false_when_lookup_raises(monkeypatch):
    """LiteLLM's ``get_model_info`` raises freely on parse errors. The
    helper swallows the exception and returns False so the safety net
    doesn't fire on a transient lookup failure."""
    import app.services.provider_capabilities as pc

    def _raise(**_kwargs):
        raise ValueError("intentional test failure")

    monkeypatch.setattr(pc.litellm, "get_model_info", _raise)

    verdict = is_known_text_only_chat_model(
        provider="OPENAI",
        model_name="gpt-4o",
    )
    assert verdict is False
|
||||
|
||||
def test_is_known_text_only_returns_true_on_explicit_false(monkeypatch):
    """Stub LiteLLM's ``get_model_info`` to return an explicit False so
    we exercise the opt-out path deterministically. Using a stub keeps
    the test stable across LiteLLM map updates."""
    import app.services.provider_capabilities as pc

    def _info(**_kwargs):
        return {"supports_vision": False, "max_input_tokens": 8192}

    monkeypatch.setattr(pc.litellm, "get_model_info", _info)

    verdict = is_known_text_only_chat_model(
        provider="OPENAI",
        model_name="any-model",
    )
    assert verdict is True
|
||||
|
||||
def test_is_known_text_only_returns_false_on_supports_vision_true(monkeypatch):
    """An explicit ``supports_vision: True`` must not trip the opt-out."""
    import app.services.provider_capabilities as pc

    def _info(**_kwargs):
        return {"supports_vision": True}

    monkeypatch.setattr(pc.litellm, "get_model_info", _info)

    verdict = is_known_text_only_chat_model(
        provider="OPENAI",
        model_name="any-model",
    )
    assert verdict is False
|
||||
|
||||
def test_is_known_text_only_returns_false_on_missing_key(monkeypatch):
    """A model entry without ``supports_vision`` at all is treated as
    'unknown' — strict opt-out means False."""
    import app.services.provider_capabilities as pc

    def _info(**_kwargs):
        return {"max_input_tokens": 8192}  # no supports_vision

    monkeypatch.setattr(pc.litellm, "get_model_info", _info)

    verdict = is_known_text_only_chat_model(
        provider="OPENAI",
        model_name="any-model",
    )
    assert verdict is False
||||
|
|
@ -0,0 +1,157 @@
|
|||
"""Unit tests for ``QuotaCheckedVisionLLM``.
|
||||
|
||||
Validates that:
|
||||
|
||||
* Calling ``ainvoke`` routes through ``billable_call`` (premium credit
|
||||
enforcement) and forwards the inner LLM's response on success.
|
||||
* The wrapper proxies non-overridden attributes to the inner LLM
|
||||
(``__getattr__``) so ``invoke`` / ``astream`` / ``with_structured_output``
|
||||
still work without quota gating (they're not used in indexing today).
|
||||
* When ``billable_call`` raises ``QuotaInsufficientError`` the wrapper
|
||||
bubbles it up — the ETL pipeline catches that and falls back to OCR.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
class _FakeInnerLLM:
|
||||
"""Stand-in for ``langchain_litellm.ChatLiteLLM``."""
|
||||
|
||||
def __init__(self, response: Any = "OCR'd content") -> None:
|
||||
self._response = response
|
||||
self.ainvoke_calls: list[Any] = []
|
||||
|
||||
async def ainvoke(self, input: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
self.ainvoke_calls.append(input)
|
||||
return self._response
|
||||
|
||||
def some_other_method(self, x: int) -> int:
|
||||
return x * 2
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def _passthrough_billable_call(**_kwargs):
|
||||
"""Stand-in for billable_call that always allows the call to run."""
|
||||
|
||||
class _Acc:
|
||||
total_cost_micros = 0
|
||||
total_prompt_tokens = 0
|
||||
total_completion_tokens = 0
|
||||
grand_total = 0
|
||||
calls: list[Any] = []
|
||||
|
||||
def per_message_summary(self) -> dict[str, dict[str, int]]:
|
||||
return {}
|
||||
|
||||
yield _Acc()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_ainvoke_routes_through_billable_call(monkeypatch):
    """``ainvoke`` must enter ``billable_call`` with the wrapper's billing
    identity and forward the inner LLM's response on success."""
    from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM

    captured_kwargs: list[dict[str, Any]] = []

    @contextlib.asynccontextmanager
    async def _spy_billable_call(**kwargs):
        # Record what the wrapper passed, then behave like the passthrough.
        captured_kwargs.append(kwargs)
        async with _passthrough_billable_call() as acc:
            yield acc

    monkeypatch.setattr(
        "app.services.quota_checked_vision_llm.billable_call",
        _spy_billable_call,
        raising=False,
    )

    inner = _FakeInnerLLM(response="A red apple on a white table")
    user_id = uuid4()
    wrapper = QuotaCheckedVisionLLM(
        inner,
        user_id=user_id,
        search_space_id=99,
        billing_tier="premium",
        base_model="openai/gpt-4o",
        quota_reserve_tokens=4000,
    )

    result = await wrapper.ainvoke([{"text": "what is this?"}])
    assert result == "A red apple on a white table"
    assert len(inner.ainvoke_calls) == 1
    assert len(captured_kwargs) == 1

    forwarded = captured_kwargs[0]
    assert forwarded["user_id"] == user_id
    assert forwarded["search_space_id"] == 99
    assert forwarded["billing_tier"] == "premium"
    assert forwarded["base_model"] == "openai/gpt-4o"
    assert forwarded["quota_reserve_tokens"] == 4000
    assert forwarded["usage_type"] == "vision_extraction"
||||
|
||||
@pytest.mark.asyncio
async def test_ainvoke_propagates_quota_insufficient_error(monkeypatch):
    """A denied reservation bubbles up and the inner LLM never runs."""
    from app.services.billable_calls import QuotaInsufficientError
    from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM

    @contextlib.asynccontextmanager
    async def _denying_billable_call(**_kwargs):
        raise QuotaInsufficientError(
            usage_type="vision_extraction",
            used_micros=5_000_000,
            limit_micros=5_000_000,
            remaining_micros=0,
        )
        yield  # unreachable but required for asynccontextmanager type

    monkeypatch.setattr(
        "app.services.quota_checked_vision_llm.billable_call",
        _denying_billable_call,
        raising=False,
    )

    inner = _FakeInnerLLM()
    wrapper = QuotaCheckedVisionLLM(
        inner,
        user_id=uuid4(),
        search_space_id=1,
        billing_tier="premium",
        base_model="openai/gpt-4o",
        quota_reserve_tokens=4000,
    )

    with pytest.raises(QuotaInsufficientError):
        await wrapper.ainvoke([{"text": "x"}])

    # Inner LLM never ran on a denied reservation.
    assert inner.ainvoke_calls == []
|
||||
|
||||
@pytest.mark.asyncio
async def test_proxies_non_overridden_attributes_to_inner():
    """``__getattr__`` forwards anything not on the proxy itself, so any
    method we didn't explicitly override (``invoke``, ``astream``,
    ``with_structured_output``, etc.) still works — just without quota
    gating, which is fine because the indexer only ever calls ainvoke.
    """
    from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM

    inner = _FakeInnerLLM()
    wrapper = QuotaCheckedVisionLLM(
        inner,
        user_id=uuid4(),
        search_space_id=1,
        billing_tier="premium",
        base_model="openai/gpt-4o",
        quota_reserve_tokens=4000,
    )

    # ``some_other_method`` is on the inner only.
    assert wrapper.some_other_method(7) == 14
|
||||
|
|
@ -0,0 +1,281 @@
|
|||
"""Unit tests for the chat-catalog ``supports_image_input`` capability flag.
|
||||
|
||||
Capability is sourced from two places, in order of preference:
|
||||
|
||||
1. ``architecture.input_modalities`` for dynamic OpenRouter chat configs
|
||||
(authoritative — OpenRouter publishes per-model modalities directly).
|
||||
2. LiteLLM's authoritative model map (``litellm.supports_vision``) for
|
||||
YAML / BYOK configs that don't carry an explicit operator override.
|
||||
|
||||
The catalog default is *True* (conservative-allow): an unknown / unmapped
|
||||
model is not pre-judged. The streaming-task safety net
|
||||
(``is_known_text_only_chat_model``) is the only place a False actually
|
||||
blocks a request — and it requires LiteLLM to *explicitly* mark the model
|
||||
as text-only.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.openrouter_integration_service import (
|
||||
_OPENROUTER_DYNAMIC_MARKER,
|
||||
_generate_configs,
|
||||
_supports_image_input,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
_SETTINGS_BASE: dict = {
|
||||
"api_key": "sk-or-test",
|
||||
"id_offset": -10_000,
|
||||
"rpm": 200,
|
||||
"tpm": 1_000_000,
|
||||
"free_rpm": 20,
|
||||
"free_tpm": 100_000,
|
||||
"anonymous_enabled_paid": False,
|
||||
"anonymous_enabled_free": True,
|
||||
"quota_reserve_tokens": 4000,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _supports_image_input helper (OpenRouter modalities)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_supports_image_input_true_for_multimodal():
    """A model whose input modalities include "image" supports images."""
    multimodal_entry = {
        "id": "openai/gpt-4o",
        "architecture": {
            "input_modalities": ["text", "image"],
            "output_modalities": ["text"],
        },
    }
    assert _supports_image_input(multimodal_entry) is True
|
||||
|
||||
def test_supports_image_input_false_for_text_only():
    """The exact failure mode the safety net guards against — DeepSeek V3
    is a text-in/text-out model and would 404 if forwarded image_url."""
    text_only_entry = {
        "id": "deepseek/deepseek-v3.2-exp",
        "architecture": {
            "input_modalities": ["text"],
            "output_modalities": ["text"],
        },
    }
    assert _supports_image_input(text_only_entry) is False
|
||||
|
||||
def test_supports_image_input_false_when_modalities_missing():
    """Defensive: missing architecture is treated as text-only at the
    OpenRouter helper level. The wider catalog resolver
    (`derive_supports_image_input`) only consults modalities when they
    are non-empty, otherwise it falls back to LiteLLM."""
    degenerate_entries = [
        {"id": "weird/model"},
        {"id": "weird/model", "architecture": {}},
        {"id": "weird/model", "architecture": {"input_modalities": None}},
    ]
    for entry in degenerate_entries:
        assert _supports_image_input(entry) is False
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _generate_configs threads the flag onto every emitted chat config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_generate_configs_emits_supports_image_input():
    """Every emitted chat config carries the capability flag plus the
    dynamic-OpenRouter marker."""
    raw_catalog = [
        {
            "id": "openai/gpt-4o",
            "architecture": {
                "input_modalities": ["text", "image"],
                "output_modalities": ["text"],
            },
            "supported_parameters": ["tools"],
            "context_length": 200_000,
            "pricing": {"prompt": "0.000005", "completion": "0.000015"},
        },
        {
            "id": "deepseek/deepseek-v3.2-exp",
            "architecture": {
                "input_modalities": ["text"],
                "output_modalities": ["text"],
            },
            "supported_parameters": ["tools"],
            "context_length": 200_000,
            "pricing": {"prompt": "0.000003", "completion": "0.000015"},
        },
    ]
    emitted = _generate_configs(raw_catalog, dict(_SETTINGS_BASE))
    by_model = {cfg["model_name"]: cfg for cfg in emitted}

    gpt_cfg = by_model["openai/gpt-4o"]
    assert gpt_cfg["supports_image_input"] is True
    assert gpt_cfg[_OPENROUTER_DYNAMIC_MARKER] is True

    deepseek_cfg = by_model["deepseek/deepseek-v3.2-exp"]
    assert deepseek_cfg["supports_image_input"] is False
    assert deepseek_cfg[_OPENROUTER_DYNAMIC_MARKER] is True
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# YAML loader: defer to derive_supports_image_input on unannotated entries
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_yaml_loader_resolves_unannotated_vision_model_to_true(tmp_path, monkeypatch):
    """The regression case: an Azure GPT-5.x YAML entry without a
    ``supports_image_input`` override should resolve to True via LiteLLM's
    model map (which says ``supports_vision: true``). Previously this
    defaulted to False, blocking every image turn for vision-capable
    YAML configs."""
    yaml_dir = tmp_path / "app" / "config"
    yaml_dir.mkdir(parents=True)
    # NOTE(review): the YAML body's indentation was reconstructed from a
    # whitespace-mangled source — confirm against the repository copy.
    (yaml_dir / "global_llm_config.yaml").write_text(
        """
global_llm_configs:
  - id: -2
    name: Azure GPT-4o
    provider: AZURE_OPENAI
    model_name: gpt-4o
    api_key: sk-test
""",
        encoding="utf-8",
    )

    from app import config as config_module

    monkeypatch.setattr(config_module, "BASE_DIR", tmp_path)

    loaded = config_module.load_global_llm_configs()
    assert len(loaded) == 1
    assert loaded[0]["supports_image_input"] is True
||||
|
||||
def test_yaml_loader_respects_explicit_supports_image_input(tmp_path, monkeypatch):
    """An explicit ``supports_image_input: false`` in YAML beats the
    LiteLLM-derived capability."""
    yaml_dir = tmp_path / "app" / "config"
    yaml_dir.mkdir(parents=True)
    # NOTE(review): the YAML body's indentation was reconstructed from a
    # whitespace-mangled source — confirm against the repository copy.
    (yaml_dir / "global_llm_config.yaml").write_text(
        """
global_llm_configs:
  - id: -1
    name: GPT-4o
    provider: OPENAI
    model_name: gpt-4o
    api_key: sk-test
    supports_image_input: false
""",
        encoding="utf-8",
    )

    from app import config as config_module

    monkeypatch.setattr(config_module, "BASE_DIR", tmp_path)

    loaded = config_module.load_global_llm_configs()
    assert len(loaded) == 1
    # Operator override always wins, even against LiteLLM's True.
    assert loaded[0]["supports_image_input"] is False
|
||||
def test_yaml_loader_unknown_model_default_allows(tmp_path, monkeypatch):
    """Unknown / unmapped model in YAML: default-allow. The streaming
    safety net (which requires an explicit-False from LiteLLM) is the
    only place a real block happens, so we don't lock the user out of
    a freshly added third-party entry the catalog can't introspect."""
    yaml_dir = tmp_path / "app" / "config"
    yaml_dir.mkdir(parents=True)
    # NOTE(review): the YAML body's indentation was reconstructed from a
    # whitespace-mangled source — confirm against the repository copy.
    (yaml_dir / "global_llm_config.yaml").write_text(
        """
global_llm_configs:
  - id: -1
    name: Some Brand New Model
    provider: CUSTOM
    custom_provider: brand_new_proxy
    model_name: brand-new-model-x9
    api_key: sk-test
""",
        encoding="utf-8",
    )

    from app import config as config_module

    monkeypatch.setattr(config_module, "BASE_DIR", tmp_path)

    loaded = config_module.load_global_llm_configs()
    assert len(loaded) == 1
    assert loaded[0]["supports_image_input"] is True
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AgentConfig threads the flag through both YAML and Auto / BYOK
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_agent_config_from_yaml_explicit_overrides_resolver():
    """Explicit ``supports_image_input`` in a YAML config dict overrides
    whatever the capability resolver would derive — in both directions."""
    from app.agents.new_chat.llm_config import AgentConfig

    forced_text_only = AgentConfig.from_yaml_config(
        {
            "id": -1,
            "name": "Text Only Override",
            "provider": "openai",
            "model_name": "gpt-4o",  # Capable per LiteLLM, but operator says no.
            "api_key": "sk-test",
            "supports_image_input": False,
        }
    )
    forced_vision = AgentConfig.from_yaml_config(
        {
            "id": -2,
            "name": "GPT-4o",
            "provider": "openai",
            "model_name": "gpt-4o",
            "api_key": "sk-test",
            "supports_image_input": True,
        }
    )
    assert forced_text_only.supports_image_input is False
    assert forced_vision.supports_image_input is True
||||
|
||||
def test_agent_config_from_yaml_unannotated_uses_resolver():
    """With no explicit YAML key, the catalog resolver decides — LiteLLM's
    model map reports ``supports_vision=True`` for ``gpt-4o``."""
    from app.agents.new_chat.llm_config import AgentConfig

    yaml_entry = {
        "id": -1,
        "name": "GPT-4o (no override)",
        "provider": "openai",
        "model_name": "gpt-4o",
        "api_key": "sk-test",
    }
    resolved = AgentConfig.from_yaml_config(yaml_entry)
    assert resolved.supports_image_input is True
|
||||
|
||||
|
||||
def test_agent_config_auto_mode_supports_image_input():
    """Auto mode is optimistic about image input: the pool may contain a
    vision-capable deployment, and the router's ``allowed_fails`` fallback
    absorbs the deployments that are not."""
    from app.agents.new_chat.llm_config import AgentConfig

    auto_config = AgentConfig.from_auto_mode()
    assert auto_config.supports_image_input is True
|
||||
|
|
@ -0,0 +1,515 @@
|
|||
"""Cost-based premium quota unit tests.
|
||||
|
||||
Covers the USD-micro behaviour added in migration 140:
|
||||
|
||||
* ``TurnTokenAccumulator.total_cost_micros`` sums ``cost_micros`` across all
|
||||
calls in a turn — used as the debit amount when ``agent_config.is_premium``
|
||||
is true, regardless of which underlying model produced each call. This
|
||||
preserves the prior "premium turn → all calls in turn count" rule from the
|
||||
token-based system.
|
||||
* ``estimate_call_reserve_micros`` scales linearly with model pricing,
|
||||
clamps to a sane floor when pricing is unknown, and respects the
|
||||
``QUOTA_MAX_RESERVE_MICROS`` ceiling so a misconfigured "$1000/M" entry
|
||||
can't lock the whole balance on one call.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TurnTokenAccumulator — premium-turn debit semantics
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_total_cost_micros_sums_premium_and_free_calls():
    """A premium turn that also hit a free sub-agent debits the union.

    Per-call premium filtering was deliberately not attempted —
    ``LLMRouterService._premium_model_strings`` only knows router-pool
    deployments — so the pre-existing "premium turn → every call counts"
    rule stands. ``total_cost_micros`` must therefore fold in the free
    calls (usually ``cost_micros == 0``) alongside the premium ones.
    """
    from app.services.token_tracking_service import TurnTokenAccumulator

    accumulator = TurnTokenAccumulator()
    # (model, prompt, completion, total, cost_micros)
    turn_calls = [
        # Premium model: non-zero provider cost.
        ("anthropic/claude-3-5-sonnet", 1200, 400, 1600, 12_345),
        # Free sub-agent (e.g. title-gen): zero cost.
        ("gpt-4o-mini", 120, 20, 140, 0),
        # Second premium-priced call in the same turn.
        ("anthropic/claude-3-5-sonnet", 800, 200, 1000, 7_500),
    ]
    for model, prompt, completion, total, cost in turn_calls:
        accumulator.add(
            model=model,
            prompt_tokens=prompt,
            completion_tokens=completion,
            total_tokens=total,
            cost_micros=cost,
        )

    assert accumulator.total_cost_micros == 12_345 + 0 + 7_500
    # Token totals stay correct so the FE display path still works.
    assert accumulator.grand_total == 1600 + 140 + 1000
|
||||
|
||||
|
||||
def test_total_cost_micros_zero_when_no_calls():
    """A fresh accumulator reports zero cost and zero tokens — never ``None``,
    never a division error."""
    from app.services.token_tracking_service import TurnTokenAccumulator

    empty = TurnTokenAccumulator()
    assert empty.total_cost_micros == 0
    assert empty.grand_total == 0
|
||||
|
||||
|
||||
def test_per_message_summary_groups_cost_by_model():
    """``per_message_summary`` rolls ``cost_micros`` up per model so the SSE
    ``model_breakdown`` payload can report real USD spend per provider."""
    from app.services.token_tracking_service import TurnTokenAccumulator

    acc = TurnTokenAccumulator()
    for model, prompt, completion, cost in (
        ("claude-3-5-sonnet", 100, 50, 4_000),
        ("claude-3-5-sonnet", 200, 100, 8_000),
        ("gpt-4o-mini", 50, 10, 200),
    ):
        acc.add(
            model=model,
            prompt_tokens=prompt,
            completion_tokens=completion,
            total_tokens=prompt + completion,
            cost_micros=cost,
        )

    summary = acc.per_message_summary()
    sonnet_bucket = summary["claude-3-5-sonnet"]
    assert sonnet_bucket["cost_micros"] == 12_000
    assert sonnet_bucket["total_tokens"] == 450
    assert summary["gpt-4o-mini"]["cost_micros"] == 200
|
||||
|
||||
|
||||
def test_serialized_calls_includes_cost_micros():
    """Each ``serialized_calls`` entry (the SSE ``call_details`` payload)
    must carry ``cost_micros`` so the FE message-info dropdown can render
    per-call USD."""
    from app.services.token_tracking_service import TurnTokenAccumulator

    acc = TurnTokenAccumulator()
    acc.add(
        model="m",
        prompt_tokens=1,
        completion_tokens=1,
        total_tokens=2,
        cost_micros=42,
    )

    expected_entry = {
        "model": "m",
        "prompt_tokens": 1,
        "completion_tokens": 1,
        "total_tokens": 2,
        "cost_micros": 42,
        "call_kind": "chat",
    }
    assert acc.serialized_calls() == [expected_entry]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# estimate_call_reserve_micros — sizing and clamping
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_reserve_returns_floor_when_model_unknown(monkeypatch):
    """Unknown model → ``get_model_info`` raises → fall back to the 100-micro
    floor: small enough that ~$0.0001 of remaining balance still buys a tiny
    request, non-zero so an empty balance is still gated."""
    import litellm

    from app.services import token_quota_service

    def _unknown_model(_name):
        raise KeyError("unknown")

    monkeypatch.setattr(litellm, "get_model_info", _unknown_model, raising=False)

    reserved = token_quota_service.estimate_call_reserve_micros(
        base_model="nonexistent-model",
        quota_reserve_tokens=4000,
    )
    assert reserved == token_quota_service._QUOTA_MIN_RESERVE_MICROS
    assert reserved == 100
|
||||
|
||||
|
||||
def test_reserve_returns_floor_when_pricing_is_zero(monkeypatch):
    """A model LiteLLM *returns* with both cost-per-token fields at 0
    (pricing not yet registered) must clamp to the floor rather than
    reserving ``0 * tokens == 0``."""
    import litellm

    from app.services import token_quota_service

    zero_priced = {"input_cost_per_token": 0, "output_cost_per_token": 0}
    monkeypatch.setattr(
        litellm, "get_model_info", lambda _name: zero_priced, raising=False
    )

    reserved = token_quota_service.estimate_call_reserve_micros(
        base_model="some-pending-model",
        quota_reserve_tokens=4000,
    )
    assert reserved == token_quota_service._QUOTA_MIN_RESERVE_MICROS
|
||||
|
||||
|
||||
def test_reserve_scales_with_model_cost(monkeypatch):
    """Opus-class pricing with 4000 reserve tokens works out to ~$0.36
    (360_000 micros) and must arrive *unclamped* — an artificial small cap
    here was exactly the bug the plan called out."""
    import litellm

    from app.config import config
    from app.services import token_quota_service

    opus_pricing = {
        "input_cost_per_token": 15e-6,
        "output_cost_per_token": 75e-6,
    }
    monkeypatch.setattr(
        litellm, "get_model_info", lambda _name: opus_pricing, raising=False
    )
    monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_MICROS", 1_000_000, raising=False)

    reserved = token_quota_service.estimate_call_reserve_micros(
        base_model="claude-3-opus",
        quota_reserve_tokens=4000,
    )
    # 4000 tokens * (15e-6 + 75e-6) USD/token = 0.36 USD = 360_000 micros.
    assert reserved == 360_000
|
||||
|
||||
|
||||
def test_reserve_clamps_to_max_ceiling(monkeypatch):
    """A misconfigured "$1000 / M" pricing entry on 4000 reserve tokens would
    nominally reserve $4 (4_000_000 micros). ``QUOTA_MAX_RESERVE_MICROS``
    must cap that so one bad entry can't freeze the user's entire balance
    on a single call."""
    import litellm

    from app.config import config
    from app.services import token_quota_service

    runaway_pricing = {
        "input_cost_per_token": 1e-3,
        "output_cost_per_token": 0,
    }
    monkeypatch.setattr(
        litellm, "get_model_info", lambda _name: runaway_pricing, raising=False
    )
    monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_MICROS", 1_000_000, raising=False)

    reserved = token_quota_service.estimate_call_reserve_micros(
        base_model="oops-misconfigured",
        quota_reserve_tokens=4000,
    )
    assert reserved == 1_000_000
|
||||
|
||||
|
||||
def test_reserve_uses_default_when_quota_reserve_tokens_missing(monkeypatch):
    """``quota_reserve_tokens`` is optional per config: ``None`` or ``0``
    falls back to the global ``QUOTA_MAX_RESERVE_PER_CALL`` so
    anonymous-style configs still reserve the operator-tunable default."""
    import litellm

    from app.config import config
    from app.services import token_quota_service

    cheap_pricing = {
        "input_cost_per_token": 1e-6,
        "output_cost_per_token": 1e-6,
    }
    monkeypatch.setattr(
        litellm, "get_model_info", lambda _name: cheap_pricing, raising=False
    )
    monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_PER_CALL", 2000, raising=False)
    monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_MICROS", 1_000_000, raising=False)

    # 2000 tokens * (1e-6 + 1e-6) USD/token = 4e-3 USD = 4000 micros.
    for missing_value in (None, 0):
        reserved = token_quota_service.estimate_call_reserve_micros(
            base_model="cheap", quota_reserve_tokens=missing_value
        )
        assert reserved == 4000
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TokenTrackingCallback — image vs chat usage shape
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _FakeImageUsage:
    """Duck-types LiteLLM's ``ImageUsage`` (``input_tokens``/``output_tokens``)."""

    def __init__(
        self,
        input_tokens: int = 0,
        output_tokens: int = 0,
        total_tokens: int | None = None,
    ) -> None:
        self.input_tokens = input_tokens
        self.output_tokens = output_tokens
        # Only materialize ``total_tokens`` when supplied, mirroring LiteLLM
        # objects where the attribute may be absent entirely.
        if total_tokens is None:
            return
        self.total_tokens = total_tokens
|
||||
|
||||
|
||||
class _FakeImageResponse:
    """Stands in for LiteLLM's ``ImageResponse``; the class is retagged to
    that name below so the callback's ``type(...).__name__`` probe routes
    to the image branch.
    """

    def __init__(self, usage: _FakeImageUsage, response_cost: float | None = None):
        self.usage = usage
        # ``_hidden_params`` only exists when a cost was supplied, matching
        # responses where LiteLLM omits the cost entirely.
        if response_cost is None:
            return
        self._hidden_params = {"response_cost": response_cost}
|
||||
|
||||
|
||||
# Rebrand the helper so the callback's ``type(...).__name__`` probe sees
# "ImageResponse". We retag rather than name the class that way directly:
# test modules can be imported in surprising ways, and the explicit alias
# documents why the name matters.
_FakeImageResponse.__name__ = "ImageResponse"
|
||||
|
||||
|
||||
class _FakeChatUsage:
    """Chat-shaped usage: prompt/completion counts plus a derived total."""

    def __init__(self, prompt: int, completion: int):
        self.prompt_tokens = prompt
        self.completion_tokens = completion
        self.total_tokens = self.prompt_tokens + self.completion_tokens
|
||||
|
||||
|
||||
class _FakeChatResponse:
    """Minimal chat response carrying only a ``usage`` attribute."""

    def __init__(self, usage: _FakeChatUsage):
        self.usage = usage
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_callback_reads_image_usage_input_output_tokens():
    """For ``ImageResponse`` the callback must read the ImageUsage shape
    (``input_tokens``/``output_tokens``) — NOT the chat shape
    (``prompt_tokens``/``completion_tokens``)."""
    from app.services.token_tracking_service import (
        TokenTrackingCallback,
        scoped_turn,
    )

    callback = TokenTrackingCallback()
    image_response = _FakeImageResponse(
        usage=_FakeImageUsage(input_tokens=42, output_tokens=8, total_tokens=50),
        response_cost=0.04,  # $0.04 per image
    )

    async with scoped_turn() as acc:
        await callback.async_log_success_event(
            kwargs={"model": "openai/gpt-image-1", "response_cost": 0.04},
            response_obj=image_response,
            start_time=None,
            end_time=None,
        )

    assert len(acc.calls) == 1
    only_call = acc.calls[0]
    assert only_call.prompt_tokens == 42
    assert only_call.completion_tokens == 8
    assert only_call.total_tokens == 50
    # 0.04 USD converts to 40_000 micros.
    assert only_call.cost_micros == 40_000
    assert only_call.call_kind == "image_generation"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_callback_chat_path_unchanged():
    """The chat path keeps reading ``prompt_tokens``/``completion_tokens``."""
    from app.services.token_tracking_service import (
        TokenTrackingCallback,
        scoped_turn,
    )

    callback = TokenTrackingCallback()
    chat_response = _FakeChatResponse(_FakeChatUsage(prompt=120, completion=30))

    async with scoped_turn() as acc:
        await callback.async_log_success_event(
            kwargs={
                "model": "openrouter/anthropic/claude-3-5-sonnet",
                "response_cost": 0.0036,
            },
            response_obj=chat_response,
            start_time=None,
            end_time=None,
        )

    assert len(acc.calls) == 1
    only_call = acc.calls[0]
    assert only_call.prompt_tokens == 120
    assert only_call.completion_tokens == 30
    assert only_call.total_tokens == 150
    assert only_call.cost_micros == 3_600
    assert only_call.call_kind == "chat"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_callback_image_missing_response_cost_falls_back_to_zero(monkeypatch):
    """OpenRouter image-gen without ``usage.cost`` makes LiteLLM's
    ``default_image_cost_calculator`` raise. The defensive image branch in
    ``_extract_cost_usd`` must return 0 (with a WARNING log) and must NOT
    fall through to chat-shaped ``cost_per_token`` — that would raise too.
    """
    import litellm

    from app.services.token_tracking_service import (
        TokenTrackingCallback,
        scoped_turn,
    )

    def _explode(*_args, **_kwargs):
        raise ValueError("model_cost: missing entry for openrouter image model")

    # Reproduce the OpenRouter image-gen failure mode.
    monkeypatch.setattr(litellm, "completion_cost", _explode, raising=False)

    # Spy on cost_per_token — if the image path ever reaches it, the
    # ``is_image=True`` branch is broken.
    chat_cost_invocations: list = []

    def _spy_cost_per_token(**kwargs):
        chat_cost_invocations.append(kwargs)
        return (0.0, 0.0)

    monkeypatch.setattr(
        litellm, "cost_per_token", _spy_cost_per_token, raising=False
    )

    callback = TokenTrackingCallback()
    image_response = _FakeImageResponse(
        usage=_FakeImageUsage(input_tokens=7, output_tokens=0)
    )

    async with scoped_turn() as acc:
        await callback.async_log_success_event(
            kwargs={"model": "openrouter/google/gemini-2.5-flash-image"},
            response_obj=image_response,
            start_time=None,
            end_time=None,
        )

    assert len(acc.calls) == 1
    first_call = acc.calls[0]
    assert first_call.cost_micros == 0
    assert first_call.call_kind == "image_generation"
    # The image branch short-circuited before cost_per_token.
    assert chat_cost_invocations == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# scoped_turn — ContextVar reset semantics (issue B)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_scoped_turn_restores_outer_accumulator():
    """On exit ``scoped_turn`` must put the previous ContextVar value back,
    so a nested per-call wrapper inside an outer chat turn cannot leak its
    accumulator outward (which would double-debit at chat-turn exit)."""
    from app.services.token_tracking_service import (
        get_current_accumulator,
        scoped_turn,
        start_turn,
    )

    outer_acc = start_turn()
    assert get_current_accumulator() is outer_acc

    async with scoped_turn() as nested_acc:
        assert get_current_accumulator() is nested_acc
        assert nested_acc is not outer_acc
        nested_acc.add(
            model="x",
            prompt_tokens=1,
            completion_tokens=1,
            total_tokens=2,
            cost_micros=5,
        )

    # The outer accumulator is back in place and untouched ...
    assert get_current_accumulator() is outer_acc
    assert outer_acc.total_cost_micros == 0
    assert len(outer_acc.calls) == 0
    # ... while the nested one kept its own call without bleeding outward.
    assert nested_acc.total_cost_micros == 5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_scoped_turn_resets_to_none_when_no_outer():
    """Outside any chat turn (e.g. a background indexing job) the ContextVar
    must be ``None`` again once ``scoped_turn`` exits, so the next
    *unrelated* request starts from a clean slate."""
    from app.services.token_tracking_service import (
        _turn_accumulator,
        get_current_accumulator,
        scoped_turn,
    )

    # Explicitly simulate "no outer turn" so the test is robust against
    # whatever earlier tests may have left in the context.
    reset_token = _turn_accumulator.set(None)
    try:
        assert get_current_accumulator() is None
        async with scoped_turn() as scoped_acc:
            assert get_current_accumulator() is scoped_acc
        assert get_current_accumulator() is None
    finally:
        _turn_accumulator.reset(reset_token)
|
||||
|
|
@ -0,0 +1,89 @@
|
|||
"""Defense-in-depth: vision-LLM resolution must not leak ``api_base``
|
||||
defaults from ``litellm.api_base`` either.
|
||||
|
||||
Vision shares the same shape as image-gen — global YAML / OpenRouter
|
||||
dynamic configs ship ``api_base=""`` and the pre-fix ``get_vision_llm``
|
||||
call sites would silently drop the empty string and inherit
|
||||
``AZURE_OPENAI_ENDPOINT``. ``ChatLiteLLM(...)`` doesn't 404 on
|
||||
construction so we test the kwargs we hand to it instead.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_vision_llm_global_openrouter_sets_api_base():
    """Global negative-ID branch: an OpenRouter vision config shipping
    ``api_base=""`` must reach ``SanitizedChatLiteLLM`` with
    ``api_base="https://openrouter.ai/api/v1"`` — neither an empty string
    nor a silently missing kwarg."""
    from app.services import llm_service

    vision_cfg = {
        "id": -30_001,
        "name": "GPT-4o Vision (OpenRouter)",
        "provider": "OPENROUTER",
        "model_name": "openai/gpt-4o",
        "api_key": "sk-or-test",
        "api_base": "",
        "api_version": None,
        "litellm_params": {},
        "billing_tier": "free",
    }

    search_space = MagicMock()
    search_space.id = 1
    search_space.user_id = "user-x"
    search_space.vision_llm_config_id = vision_cfg["id"]

    # Wire a fake SQLAlchemy result chain: execute → scalars() → first().
    scalars = MagicMock()
    scalars.first.return_value = search_space
    query_result = MagicMock()
    query_result.scalars.return_value = scalars
    session = AsyncMock()
    session.execute.return_value = query_result

    seen_kwargs: dict = {}

    class _RecordingSanitized:
        # Captures the constructor kwargs instead of building a real LLM.
        def __init__(self, **kwargs):
            seen_kwargs.update(kwargs)

    with (
        patch(
            "app.services.vision_llm_router_service.get_global_vision_llm_config",
            return_value=vision_cfg,
        ),
        patch(
            "app.agents.new_chat.llm_config.SanitizedChatLiteLLM",
            new=_RecordingSanitized,
        ),
    ):
        await llm_service.get_vision_llm(session=session, search_space_id=1)

    assert seen_kwargs.get("api_base") == "https://openrouter.ai/api/v1"
    assert seen_kwargs["model"] == "openrouter/openai/gpt-4o"
|
||||
|
||||
|
||||
def test_vision_router_deployment_sets_api_base_when_config_empty():
    """Auto-mode vision router hands deployments straight to
    ``litellm.Router``, so the api_base resolver must also fire when the
    deployment dict is built."""
    from app.services.vision_llm_router_service import VisionLLMRouterService

    built = VisionLLMRouterService._config_to_deployment(
        {
            "model_name": "openai/gpt-4o",
            "provider": "OPENROUTER",
            "api_key": "sk-or-test",
            "api_base": "",
        }
    )
    assert built is not None
    litellm_params = built["litellm_params"]
    assert litellm_params["api_base"] == "https://openrouter.ai/api/v1"
    assert litellm_params["model"] == "openrouter/openai/gpt-4o"
|
||||
|
|
@ -51,22 +51,34 @@ class _FakeToolMessage:
|
|||
tool_call_id: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class _FakeInterrupt:
|
||||
value: dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class _FakeTask:
|
||||
interrupts: tuple[_FakeInterrupt, ...] = ()
|
||||
|
||||
|
||||
class _FakeAgentState:
|
||||
"""Stand-in for ``StateSnapshot`` returned by ``aget_state``."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
def __init__(self, tasks: list[Any] | None = None) -> None:
|
||||
# Empty values keeps the cloud-fallback safety-net branch a no-op,
|
||||
# and an empty ``tasks`` list keeps the post-stream interrupt
|
||||
# check a no-op too.
|
||||
# and empty ``tasks`` keep the post-stream interrupt check a no-op too.
|
||||
self.values: dict[str, Any] = {}
|
||||
self.tasks: list[Any] = []
|
||||
self.tasks: list[Any] = tasks or []
|
||||
|
||||
|
||||
class _FakeAgent:
|
||||
"""Replays a list of ``astream_events`` events."""
|
||||
|
||||
def __init__(self, events: list[dict[str, Any]]) -> None:
|
||||
def __init__(
|
||||
self, events: list[dict[str, Any]], state: _FakeAgentState | None = None
|
||||
) -> None:
|
||||
self._events = events
|
||||
self._state = state or _FakeAgentState()
|
||||
|
||||
async def astream_events( # type: ignore[no-untyped-def]
|
||||
self, _input_data: Any, *, config: dict[str, Any], version: str
|
||||
|
|
@ -79,7 +91,7 @@ class _FakeAgent:
|
|||
# Called once after astream_events drains so the cloud-fallback
|
||||
# safety net can inspect staged filesystem work. The fake stays
|
||||
# empty so the safety net is a no-op.
|
||||
return _FakeAgentState()
|
||||
return self._state
|
||||
|
||||
|
||||
def _model_stream(
|
||||
|
|
@ -170,11 +182,13 @@ def parity_v2_off(monkeypatch: pytest.MonkeyPatch) -> None:
|
|||
)
|
||||
|
||||
|
||||
async def _drain(events: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
async def _drain(
|
||||
events: list[dict[str, Any]], state: _FakeAgentState | None = None
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Run ``_stream_agent_events`` against a fake agent and return the
|
||||
SSE payloads (parsed JSON) it yielded.
|
||||
"""
|
||||
agent = _FakeAgent(events)
|
||||
agent = _FakeAgent(events, state=state)
|
||||
service = VercelStreamingService()
|
||||
result = StreamResult()
|
||||
config = {"configurable": {"thread_id": "test-thread"}}
|
||||
|
|
@ -525,3 +539,31 @@ async def test_unmatched_fallback_still_attaches_lc_id(
|
|||
assert len(starts) == 1
|
||||
assert starts[0]["toolCallId"].startswith("call_run-1")
|
||||
assert starts[0]["langchainToolCallId"] == "lc-orphan"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_interrupt_request_uses_task_that_contains_interrupt(
|
||||
parity_v2_on: None,
|
||||
) -> None:
|
||||
interrupt_payload = {
|
||||
"type": "calendar_event_create",
|
||||
"action": {
|
||||
"tool": "create_calendar_event",
|
||||
"params": {"summary": "mom bday"},
|
||||
},
|
||||
"context": {},
|
||||
}
|
||||
state = _FakeAgentState(
|
||||
tasks=[
|
||||
_FakeTask(interrupts=()),
|
||||
_FakeTask(interrupts=(_FakeInterrupt(value=interrupt_payload),)),
|
||||
]
|
||||
)
|
||||
|
||||
payloads = await _drain([], state=state)
|
||||
|
||||
interrupts = _of_type(payloads, "data-interrupt-request")
|
||||
assert len(interrupts) == 1
|
||||
assert (
|
||||
interrupts[0]["data"]["action_requests"][0]["name"] == "create_calendar_event"
|
||||
)
|
||||
|
|
|
|||
318
surfsense_backend/tests/unit/tasks/test_celery_async_runner.py
Normal file
318
surfsense_backend/tests/unit/tasks/test_celery_async_runner.py
Normal file
|
|
@ -0,0 +1,318 @@
|
|||
"""Regression tests for ``run_async_celery_task``.
|
||||
|
||||
These tests pin down the production bug observed on 2026-05-02 where
|
||||
the video-presentation Celery task hung at ``[billable_call] finalize``
|
||||
because the shared ``app.db.engine`` had pooled asyncpg connections
|
||||
bound to a *previous* task's now-closed event loop. Reusing such a
|
||||
connection on a fresh loop crashes inside ``pool_pre_ping`` with::
|
||||
|
||||
AttributeError: 'NoneType' object has no attribute 'send'
|
||||
|
||||
(the proactor is None because the loop is gone) and can hang forever
|
||||
inside the asyncpg ``Connection._cancel`` cleanup coroutine.
|
||||
|
||||
The fix is ``run_async_celery_task``: a small helper that runs every
|
||||
async celery task body inside a fresh event loop and disposes the
|
||||
shared engine pool both before (defends against a previous task that
|
||||
crashed) and after (releases connections we opened on this loop).
|
||||
|
||||
Tests here exercise the helper with a stub engine that records
|
||||
``dispose()`` calls and panics if a coroutine produced by one loop is
|
||||
awaited on another — mirroring the real asyncpg behaviour.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import gc
|
||||
import sys
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Stub engine that emulates the asyncpg-on-stale-loop crash
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _StaleLoopEngine:
    """Minimal ``app.db.engine`` stand-in that records ``dispose()`` calls.

    ``dispose()`` mirrors ``AsyncEngine.dispose`` (it is async) and notes
    the id of the loop it ran on, letting tests assert it executed on each
    fresh loop.
    """

    def __init__(self) -> None:
        self.dispose_loop_ids: list[int] = []

    async def dispose(self) -> None:
        current_loop = asyncio.get_running_loop()
        self.dispose_loop_ids.append(id(current_loop))
||||
|
||||
|
||||
@contextmanager
def _patch_shared_engine(stub: _StaleLoopEngine) -> Iterator[None]:
    """Temporarily replace the ``engine`` attribute on ``app.db`` with *stub*.

    ``run_async_celery_task`` resolves ``from app.db import engine`` lazily
    inside its body, so patching the attribute on the already-loaded
    ``app.db`` module is sufficient.

    Bug fix: the previous version never removed the stub when ``app.db``
    had no ``engine`` attribute to begin with — it instead asserted that
    reading ``app_db.engine`` raises ``AttributeError``, which cannot hold
    while the stub is still set, so cleanup itself failed AND leaked the
    stub into later tests. We now delete the attribute on that path.
    """
    import app.db as app_db

    original = getattr(app_db, "engine", None)
    app_db.engine = stub  # type: ignore[attr-defined]
    try:
        yield
    finally:
        if original is None:
            # No engine existed before — remove the stub so the module is
            # restored to its pristine (attribute-less) state.
            del app_db.engine  # type: ignore[attr-defined]
        else:
            app_db.engine = original  # type: ignore[attr-defined]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_runner_returns_value_and_disposes_engine_around_call() -> None:
    """Happy path: the coroutine's result passes through, and the shared
    engine is disposed once before and once after the body — both times on
    the same fresh loop."""
    from app.tasks.celery_tasks import run_async_celery_task

    engine_stub = _StaleLoopEngine()

    async def _body() -> str:
        # The pre-body dispose has already happened by the time we run.
        assert len(engine_stub.dispose_loop_ids) == 1
        return "ok"

    with _patch_shared_engine(engine_stub):
        outcome = run_async_celery_task(_body)

    assert outcome == "ok"
    # Once before the body, once after (in finally).
    assert len(engine_stub.dispose_loop_ids) == 2
    # Both disposals ran on the one fresh loop the task body used.
    first_loop_id, second_loop_id = engine_stub.dispose_loop_ids
    assert first_loop_id == second_loop_id
|
||||
|
||||
|
||||
def test_runner_creates_fresh_loop_per_invocation() -> None:
    """Every call must build (and then close) its own event loop. Reusing a
    prior task's loop is exactly what reintroduces the asyncpg
    stale-loop crash the helper exists to prevent."""
    import app.tasks.celery_tasks as celery_tasks_pkg

    engine_stub = _StaleLoopEngine()
    loops_created = 0
    close_events: list[bool] = []

    original_new_event_loop = asyncio.new_event_loop

    def _instrumented_new_loop() -> asyncio.AbstractEventLoop:
        nonlocal loops_created
        loops_created += 1
        fresh_loop = original_new_event_loop()
        # Wrap close() so we can confirm each loop was shut down properly
        # before the next one came into existence.
        wrapped_close = fresh_loop.close

        def _noting_close() -> None:
            close_events.append(True)
            wrapped_close()

        fresh_loop.close = _noting_close  # type: ignore[method-assign]
        return fresh_loop

    async def _body() -> None:
        # The current loop must be alive while the body executes.
        assert not asyncio.get_running_loop().is_closed()

    with (
        _patch_shared_engine(engine_stub),
        patch.object(asyncio, "new_event_loop", _instrumented_new_loop),
    ):
        for _ in range(3):
            celery_tasks_pkg.run_async_celery_task(_body)

    assert loops_created == 3
    assert close_events == [True, True, True]
    # Two disposals (before + after) per invocation.
    assert len(engine_stub.dispose_loop_ids) == 6
|
||||
|
||||
|
||||
def test_runner_disposes_engine_even_when_body_raises() -> None:
    """Cleanup MUST run on the failure path too — otherwise stale
    connections leak into the next task and cause the original hang.
    """
    from app.tasks.celery_tasks import run_async_celery_task

    engine_stub = _StaleLoopEngine()

    class _BoomError(RuntimeError):
        pass

    async def _exploding_body() -> None:
        raise _BoomError("kaboom")

    with _patch_shared_engine(engine_stub), pytest.raises(_BoomError):
        run_async_celery_task(_exploding_body)

    # Both disposes (pre-body and the finally-path one) still happened.
    assert len(engine_stub.dispose_loop_ids) == 2
|
||||
|
||||
|
||||
def test_runner_swallows_dispose_errors() -> None:
    """A flaky engine.dispose() must NEVER take down a celery task.

    Production scenario: the very first dispose (before the body runs)
    might hit a partially-initialised engine; the helper logs and
    moves on. The task body still runs; the result is still returned.
    """
    from app.tasks.celery_tasks import run_async_celery_task

    class _AngryEngine:
        # Engine fake whose dispose() always raises.
        def __init__(self) -> None:
            self.calls = 0

        async def dispose(self) -> None:
            self.calls += 1
            raise RuntimeError("dispose() blew up")

    angry_engine = _AngryEngine()

    async def _body() -> int:
        return 42

    with _patch_shared_engine(angry_engine):
        assert run_async_celery_task(_body) == 42

    # Both the before- and after-body disposes were still attempted.
    assert angry_engine.calls == 2
|
||||
|
||||
|
||||
def test_runner_propagates_value_from_async_body() -> None:
    """Sanity: pass-through of any pickleable celery return value."""
    from app.tasks.celery_tasks import run_async_celery_task

    engine_stub = _StaleLoopEngine()
    expected = {"status": "ready", "video_presentation_id": 19}

    async def _body() -> dict[str, object]:
        return {"status": "ready", "video_presentation_id": 19}

    with _patch_shared_engine(engine_stub):
        actual = run_async_celery_task(_body)

    assert actual == expected
|
||||
|
||||
|
||||
def test_video_presentation_task_uses_runner_helper() -> None:
    """Defence-in-depth: confirm the celery task module imports
    ``run_async_celery_task``. If a future refactor inlines a
    ``loop = asyncio.new_event_loop(); ... loop.close()`` block again,
    the original hang will return.
    """
    import inspect

    from app.tasks.celery_tasks import video_presentation_tasks

    # The task module must route through the helper — a manual
    # new_event_loop call is exactly what the helper centralises away.
    module_src = inspect.getsource(video_presentation_tasks)

    assert "run_async_celery_task" in module_src, (
        "video_presentation_tasks.py must use run_async_celery_task; "
        "manual asyncio.new_event_loop() in a celery task hangs on the "
        "shared SQLAlchemy pool when reused across tasks."
    )
    assert "asyncio.new_event_loop" not in module_src, (
        "video_presentation_tasks.py contains a raw asyncio.new_event_loop "
        "call — route every async task through run_async_celery_task to "
        "avoid the stale-pool hang."
    )
|
||||
|
||||
|
||||
def test_podcast_task_uses_runner_helper() -> None:
    """Symmetric assertion for the podcast task — same root cause, same
    fix, same regression risk.
    """
    import inspect

    from app.tasks.celery_tasks import podcast_tasks

    module_src = inspect.getsource(podcast_tasks)
    assert "run_async_celery_task" in module_src
    assert "asyncio.new_event_loop" not in module_src
|
||||
|
||||
|
||||
def test_runner_runs_shutdown_asyncgens_before_close() -> None:
    """If the task body created any async generators that didn't get
    fully iterated, we must still call ``loop.shutdown_asyncgens()``
    before closing — otherwise we leak event-loop bound resources
    that re-emerge as ``RuntimeError: Event loop is closed`` later.
    """
    from app.tasks.celery_tasks import run_async_celery_task

    engine_stub = _StaleLoopEngine()

    async def _agen():
        try:
            yield 1
            yield 2
        finally:
            pass

    async def _body() -> None:
        # Iterate the agen partially, then leave it dangling — exactly
        # the situation shutdown_asyncgens() is designed to clean up.
        async for item in _agen():
            if item == 1:
                break

    # We deliberately don't assert agen.closed (it depends on GC
    # ordering); the real contract is "no warnings, no
    # event-loop-closed errors". A successful second invocation proves
    # the first loop was cleaned up properly.
    for _ in range(2):
        with _patch_shared_engine(engine_stub):
            run_async_celery_task(_body)

    # Force a GC pass to surface any 'coroutine was never awaited'
    # warnings that would indicate the cleanup is broken.
    gc.collect()
|
||||
|
||||
|
||||
def test_runner_uses_proactor_loop_on_windows() -> None:
    """On Windows the celery worker preselects a Proactor policy so
    subprocess (ffmpeg) calls work. The helper must not silently fall
    back to a Selector loop and re-break video/podcast generation.
    """
    if not sys.platform.startswith("win"):
        pytest.skip("Windows-specific event-loop policy assertion")

    from app.tasks.celery_tasks import run_async_celery_task

    stub = _StaleLoopEngine()
    observed: list[str] = []

    async def _body() -> None:
        observed.append(type(asyncio.get_running_loop()).__name__)

    # Mirror the policy set at the top of every Windows celery task.
    # Fix: restore the previous policy afterwards — the original left a
    # process-global Proactor policy behind, polluting unrelated tests
    # that run later in the same interpreter.
    previous_policy = asyncio.get_event_loop_policy()
    asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
    try:
        with _patch_shared_engine(stub):
            run_async_celery_task(_body)
    finally:
        asyncio.set_event_loop_policy(previous_policy)

    assert observed == ["ProactorEventLoop"]
|
||||
# ==== New file (388 lines): surfsense_backend/tests/unit/tasks/test_podcast_billing.py ====
|
|||
"""Unit tests for podcast Celery task billing integration.
|
||||
|
||||
Validates ``_generate_content_podcast`` correctly wraps
|
||||
``podcaster_graph.ainvoke`` in a ``billable_call`` envelope, propagates the
|
||||
search-space owner's billing decision, and degrades cleanly when the
|
||||
resolver fails or premium credit is exhausted.
|
||||
|
||||
Coverage:
|
||||
|
||||
* Happy-path free config: resolver → ``billable_call`` enters with
|
||||
``usage_type='podcast_generation'`` and the configured reserve override,
|
||||
graph runs, podcast row flips to ``READY``.
|
||||
* Happy-path premium config: same wiring with ``billing_tier='premium'``.
|
||||
* Quota denial: ``billable_call`` raises ``QuotaInsufficientError`` →
|
||||
graph is *not* invoked, podcast row flips to ``FAILED``, return dict
|
||||
carries ``reason='premium_quota_exhausted'``.
|
||||
* Resolver failure: ``ValueError`` from the resolver → podcast row flips
|
||||
to ``FAILED``, return dict carries ``reason='billing_resolution_failed'``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fakes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _FakeExecResult:
|
||||
def __init__(self, obj):
|
||||
self._obj = obj
|
||||
|
||||
def scalars(self):
|
||||
return self
|
||||
|
||||
def first(self):
|
||||
return self._obj
|
||||
|
||||
def filter(self, *_args, **_kwargs):
|
||||
return self
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
def __init__(self, podcast):
|
||||
self._podcast = podcast
|
||||
self.commit_count = 0
|
||||
|
||||
async def execute(self, _stmt):
|
||||
return _FakeExecResult(self._podcast)
|
||||
|
||||
async def commit(self):
|
||||
self.commit_count += 1
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
return None
|
||||
|
||||
|
||||
class _FakeSessionMaker:
|
||||
def __init__(self, session: _FakeSession):
|
||||
self._session = session
|
||||
|
||||
def __call__(self):
|
||||
return self._session
|
||||
|
||||
|
||||
def _make_podcast(podcast_id: int = 7, thread_id: int = 99) -> SimpleNamespace:
|
||||
"""Stand-in for a ``Podcast`` row. Importing ``PodcastStatus`` lazily
|
||||
inside helpers keeps this fixture cheap."""
|
||||
return SimpleNamespace(
|
||||
id=podcast_id,
|
||||
title="Test Podcast",
|
||||
thread_id=thread_id,
|
||||
status=None,
|
||||
podcast_transcript=None,
|
||||
file_location=None,
|
||||
)
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
async def _ok_billable_call(**kwargs):
    """``billable_call`` stub for the happy path: records the kwargs it
    was entered with and yields a no-op accumulator-shaped object."""
    _CALL_LOG.append(kwargs)
    yield SimpleNamespace()
|
||||
|
||||
|
||||
# Kwargs captured from every stubbed billable_call invocation; cleared
# before and after each test by the autouse _reset_call_log fixture.
_CALL_LOG: list[dict[str, Any]] = []
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
async def _denying_billable_call(**kwargs):
    """``billable_call`` stub whose reservation is always denied: records
    kwargs, then raises before ever yielding control to the caller."""
    from app.services.billable_calls import QuotaInsufficientError

    _CALL_LOG.append(kwargs)
    raise QuotaInsufficientError(
        usage_type=kwargs.get("usage_type", "?"),
        used_micros=5_000_000,
        limit_micros=5_000_000,
        remaining_micros=0,
    )
    # Unreachable, but required so this stays an async generator.
    yield SimpleNamespace()  # pragma: no cover — for grammar only
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
async def _settlement_failing_billable_call(**kwargs):
    """``billable_call`` stub that admits the call but blows up at
    settlement time (i.e. when the context manager exits)."""
    from app.services.billable_calls import BillingSettlementError

    _CALL_LOG.append(kwargs)
    yield SimpleNamespace()
    raise BillingSettlementError(
        usage_type=kwargs.get("usage_type", "?"),
        user_id=kwargs["user_id"],
        cause=RuntimeError("finalize failed"),
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _reset_call_log():
    """Guarantee every test starts and ends with an empty kwargs log."""
    _CALL_LOG.clear()
    yield
    _CALL_LOG.clear()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_billable_call_invoked_with_correct_kwargs_for_free_config(monkeypatch):
    """Happy path: free billing tier still wraps the graph call so the
    audit row is recorded. Verifies kwargs threading."""
    from app.config import config as app_config
    from app.db import PodcastStatus
    from app.tasks.celery_tasks import podcast_tasks

    podcast_row = _make_podcast(podcast_id=7, thread_id=99)
    fake_session = _FakeSession(podcast_row)
    monkeypatch.setattr(
        podcast_tasks,
        "get_celery_session_maker",
        lambda: _FakeSessionMaker(fake_session),
    )

    owner_id = uuid4()

    async def _fake_resolver(sess, search_space_id, *, thread_id=None):
        assert search_space_id == 555
        assert thread_id == 99
        return owner_id, "free", "openrouter/some-free-model"

    monkeypatch.setattr(
        podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver
    )
    monkeypatch.setattr(podcast_tasks, "billable_call", _ok_billable_call)

    async def _fake_graph_invoke(state, config):
        return {
            "podcast_transcript": [
                SimpleNamespace(speaker_id=0, dialog="Hi"),
                SimpleNamespace(speaker_id=1, dialog="Hello"),
            ],
            "final_podcast_file_path": "/tmp/podcast.wav",
        }

    monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke)

    result = await podcast_tasks._generate_content_podcast(
        podcast_id=7,
        source_content="hello world",
        search_space_id=555,
        user_prompt="make it short",
    )

    assert result["status"] == "ready"
    assert result["podcast_id"] == 7
    assert podcast_row.status == PodcastStatus.READY
    assert podcast_row.file_location == "/tmp/podcast.wav"

    assert len(_CALL_LOG) == 1
    recorded = _CALL_LOG[0]
    assert recorded["user_id"] == owner_id
    assert recorded["search_space_id"] == 555
    assert recorded["billing_tier"] == "free"
    assert recorded["base_model"] == "openrouter/some-free-model"
    assert recorded["usage_type"] == "podcast_generation"
    assert (
        recorded["quota_reserve_micros_override"]
        == app_config.QUOTA_DEFAULT_PODCAST_RESERVE_MICROS
    )
    # Background artifact audit rows intentionally omit the TokenUsage.thread_id
    # FK to avoid coupling Celery audit commits to an active chat transaction.
    assert "thread_id" not in recorded
    assert recorded["call_details"] == {
        "podcast_id": 7,
        "title": "Test Podcast",
        "thread_id": 99,
    }
    assert callable(recorded["billable_session_factory"])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_billable_call_invoked_with_premium_tier(monkeypatch):
    """Premium resolution flows through to ``billable_call`` so the
    reserve/finalize path triggers."""
    from app.tasks.celery_tasks import podcast_tasks

    podcast_row = _make_podcast()
    fake_session = _FakeSession(podcast_row)
    monkeypatch.setattr(
        podcast_tasks,
        "get_celery_session_maker",
        lambda: _FakeSessionMaker(fake_session),
    )

    owner_id = uuid4()

    async def _fake_resolver(sess, search_space_id, *, thread_id=None):
        return owner_id, "premium", "gpt-5.4"

    monkeypatch.setattr(
        podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver
    )
    monkeypatch.setattr(podcast_tasks, "billable_call", _ok_billable_call)

    async def _fake_graph_invoke(state, config):
        return {"podcast_transcript": [], "final_podcast_file_path": "x.wav"}

    monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke)

    await podcast_tasks._generate_content_podcast(
        podcast_id=7,
        source_content="hi",
        search_space_id=555,
        user_prompt=None,
    )

    assert _CALL_LOG[0]["billing_tier"] == "premium"
    assert _CALL_LOG[0]["base_model"] == "gpt-5.4"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_quota_insufficient_marks_podcast_failed_and_skips_graph(monkeypatch):
    """When ``billable_call`` denies the reservation, the graph never
    runs and the podcast row flips to FAILED with the documented reason
    code."""
    from app.db import PodcastStatus
    from app.tasks.celery_tasks import podcast_tasks

    podcast_row = _make_podcast(podcast_id=8)
    fake_session = _FakeSession(podcast_row)
    monkeypatch.setattr(
        podcast_tasks,
        "get_celery_session_maker",
        lambda: _FakeSessionMaker(fake_session),
    )

    async def _fake_resolver(sess, search_space_id, *, thread_id=None):
        return uuid4(), "premium", "gpt-5.4"

    monkeypatch.setattr(
        podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver
    )
    monkeypatch.setattr(podcast_tasks, "billable_call", _denying_billable_call)

    graph_invocations: list[bool] = []

    async def _fake_graph_invoke(state, config):
        graph_invocations.append(True)
        return {}

    monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke)

    result = await podcast_tasks._generate_content_podcast(
        podcast_id=8,
        source_content="hi",
        search_space_id=555,
        user_prompt=None,
    )

    assert result == {
        "status": "failed",
        "podcast_id": 8,
        "reason": "premium_quota_exhausted",
    }
    assert podcast_row.status == PodcastStatus.FAILED
    # The graph must never run on a denied reservation.
    assert graph_invocations == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_billing_settlement_failure_marks_podcast_failed(monkeypatch):
    """A settlement failure after the graph ran must still flip the
    podcast row to FAILED with the ``billing_settlement_failed`` reason."""
    from app.db import PodcastStatus
    from app.tasks.celery_tasks import podcast_tasks

    podcast_row = _make_podcast(podcast_id=10)
    fake_session = _FakeSession(podcast_row)
    monkeypatch.setattr(
        podcast_tasks,
        "get_celery_session_maker",
        lambda: _FakeSessionMaker(fake_session),
    )

    async def _fake_resolver(sess, search_space_id, *, thread_id=None):
        return uuid4(), "premium", "gpt-5.4"

    monkeypatch.setattr(
        podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver
    )
    monkeypatch.setattr(
        podcast_tasks, "billable_call", _settlement_failing_billable_call
    )

    async def _fake_graph_invoke(state, config):
        return {"podcast_transcript": [], "final_podcast_file_path": "x.wav"}

    monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke)

    result = await podcast_tasks._generate_content_podcast(
        podcast_id=10,
        source_content="hi",
        search_space_id=555,
        user_prompt=None,
    )

    assert result == {
        "status": "failed",
        "podcast_id": 10,
        "reason": "billing_settlement_failed",
    }
    assert podcast_row.status == PodcastStatus.FAILED
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_resolver_failure_marks_podcast_failed(monkeypatch):
    """If the resolver raises (e.g. search-space deleted), the task fails
    cleanly without invoking the graph."""
    from app.db import PodcastStatus
    from app.tasks.celery_tasks import podcast_tasks

    podcast_row = _make_podcast(podcast_id=9)
    fake_session = _FakeSession(podcast_row)
    monkeypatch.setattr(
        podcast_tasks,
        "get_celery_session_maker",
        lambda: _FakeSessionMaker(fake_session),
    )

    async def _failing_resolver(sess, search_space_id, *, thread_id=None):
        raise ValueError("Search space 555 not found")

    monkeypatch.setattr(
        podcast_tasks, "_resolve_agent_billing_for_search_space", _failing_resolver
    )

    graph_invocations: list[bool] = []

    async def _fake_graph_invoke(state, config):
        graph_invocations.append(True)
        return {}

    monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke)

    result = await podcast_tasks._generate_content_podcast(
        podcast_id=9,
        source_content="hi",
        search_space_id=555,
        user_prompt=None,
    )

    assert result == {
        "status": "failed",
        "podcast_id": 9,
        "reason": "billing_resolution_failed",
    }
    assert podcast_row.status == PodcastStatus.FAILED
    assert graph_invocations == []
|
||||
# ==== New file (119 lines): chat image-input safety-net predicate tests ====
|
|||
"""Predicate-level test for the chat streaming safety net.
|
||||
|
||||
The safety net in ``stream_new_chat`` rejects an image turn early with
|
||||
a friendly ``MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT`` SSE error when the
|
||||
selected model is *known* to be text-only. The earlier round of this
|
||||
work used a strict opt-in flag (``supports_image_input`` defaulting to
|
||||
False on every YAML entry) which blocked vision-capable Azure GPT-5.x
|
||||
deployments — this is the regression we're fixing.
|
||||
|
||||
The new predicate is :func:`is_known_text_only_chat_model`, which
|
||||
returns True only when LiteLLM's authoritative model map *explicitly*
|
||||
sets ``supports_vision=False``. Anything else (vision True, missing
|
||||
key, exception) returns False so the request flows through to the
|
||||
provider.
|
||||
|
||||
We exercise the predicate directly here rather than driving the full
|
||||
``stream_new_chat`` generator — covering the gate in isolation keeps
|
||||
the test focused on the regression while the generator's wider behavior
|
||||
is exercised by the integration suite.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.provider_capabilities import is_known_text_only_chat_model
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def test_safety_net_does_not_fire_for_azure_gpt_4o():
    """Regression: ``azure/gpt-4o`` (and the GPT-5.x variants) is
    vision-capable per LiteLLM's model map. The previous round's
    blanket-False default blocked it; the new predicate must NOT mark
    it text-only."""
    verdict = is_known_text_only_chat_model(
        provider="AZURE_OPENAI",
        model_name="my-azure-deployment",
        base_model="gpt-4o",
    )
    assert verdict is False
|
||||
|
||||
|
||||
def test_safety_net_does_not_fire_for_unknown_model():
    """Default-pass on unknown — the safety net only blocks definitive
    text-only confirmations. A freshly added third-party model that
    LiteLLM doesn't know about must flow through to the provider."""
    verdict = is_known_text_only_chat_model(
        provider="CUSTOM",
        custom_provider="brand_new_proxy",
        model_name="brand-new-model-x9",
    )
    assert verdict is False
|
||||
|
||||
|
||||
def test_safety_net_does_not_fire_when_lookup_raises(monkeypatch):
    """Transient ``litellm.get_model_info`` exception ≠ block. The
    helper swallows the error and treats it as 'unknown' → False."""
    import app.services.provider_capabilities as pc

    def _raise(**_kwargs):
        raise RuntimeError("intentional test failure")

    monkeypatch.setattr(pc.litellm, "get_model_info", _raise)

    verdict = is_known_text_only_chat_model(
        provider="OPENAI",
        model_name="gpt-4o",
    )
    assert verdict is False
|
||||
|
||||
|
||||
def test_safety_net_fires_only_on_explicit_false(monkeypatch):
    """Stub LiteLLM to assert the only path that returns True is the
    explicit ``supports_vision=False`` case. Anything else (True,
    None, missing key) returns False from the predicate."""
    import app.services.provider_capabilities as pc

    def _probe(model_name, info):
        # Point get_model_info at a canned payload, then run the predicate.
        monkeypatch.setattr(pc.litellm, "get_model_info", lambda **_kwargs: info)
        return is_known_text_only_chat_model(
            provider="OPENAI",
            model_name=model_name,
        )

    # Explicit False → the safety net fires.
    explicit_false = {"supports_vision": False, "max_input_tokens": 8192}
    assert _probe("text-only-stub", explicit_false) is True

    # Explicit True → vision-capable, never blocked.
    assert _probe("vision-stub", {"supports_vision": True}) is False

    # Key missing entirely → treated as unknown, default-pass.
    assert _probe("missing-key-stub", {"max_input_tokens": 8192}) is False
|
||||
|
# ==== New file (398 lines): video-presentation Celery task billing tests ====
|||
"""Unit tests for video-presentation Celery task billing integration.
|
||||
|
||||
Mirrors ``test_podcast_billing.py`` for the video-presentation task.
|
||||
Validates the same wrap-graph-in-billable_call pattern and ensures the
|
||||
larger ``QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS`` reservation is
|
||||
threaded through.
|
||||
|
||||
Coverage:
|
||||
|
||||
* Free config: graph runs, ``billable_call`` invoked with the video
|
||||
reserve override.
|
||||
* Premium config: same wiring with ``billing_tier='premium'``.
|
||||
* Quota denial: graph not invoked, row → FAILED, reason code surfaced.
|
||||
* Resolver failure: row → FAILED with ``billing_resolution_failed``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fakes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _FakeExecResult:
|
||||
def __init__(self, obj):
|
||||
self._obj = obj
|
||||
|
||||
def scalars(self):
|
||||
return self
|
||||
|
||||
def first(self):
|
||||
return self._obj
|
||||
|
||||
def filter(self, *_args, **_kwargs):
|
||||
return self
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
def __init__(self, video):
|
||||
self._video = video
|
||||
self.commit_count = 0
|
||||
|
||||
async def execute(self, _stmt):
|
||||
return _FakeExecResult(self._video)
|
||||
|
||||
async def commit(self):
|
||||
self.commit_count += 1
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
return None
|
||||
|
||||
|
||||
class _FakeSessionMaker:
|
||||
def __init__(self, session: _FakeSession):
|
||||
self._session = session
|
||||
|
||||
def __call__(self):
|
||||
return self._session
|
||||
|
||||
|
||||
def _make_video(video_id: int = 11, thread_id: int = 99) -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
id=video_id,
|
||||
title="Test Presentation",
|
||||
thread_id=thread_id,
|
||||
status=None,
|
||||
slides=None,
|
||||
scene_codes=None,
|
||||
)
|
||||
|
||||
|
||||
# Kwargs captured from every stubbed billable_call invocation; cleared
# before and after each test by the autouse _reset_call_log fixture.
_CALL_LOG: list[dict[str, Any]] = []
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
async def _ok_billable_call(**kwargs):
    """Happy-path ``billable_call`` stub: records its kwargs and yields a
    no-op accumulator-shaped object."""
    _CALL_LOG.append(kwargs)
    yield SimpleNamespace()
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
async def _denying_billable_call(**kwargs):
    """``billable_call`` stub whose reservation is always denied: records
    kwargs, then raises before ever yielding control."""
    from app.services.billable_calls import QuotaInsufficientError

    _CALL_LOG.append(kwargs)
    raise QuotaInsufficientError(
        usage_type=kwargs.get("usage_type", "?"),
        used_micros=5_000_000,
        limit_micros=5_000_000,
        remaining_micros=0,
    )
    # Unreachable, but required so this stays an async generator.
    yield SimpleNamespace()  # pragma: no cover
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
async def _settlement_failing_billable_call(**kwargs):
    """``billable_call`` stub that admits the call but blows up at
    settlement time (i.e. when the context manager exits)."""
    from app.services.billable_calls import BillingSettlementError

    _CALL_LOG.append(kwargs)
    yield SimpleNamespace()
    raise BillingSettlementError(
        usage_type=kwargs.get("usage_type", "?"),
        user_id=kwargs["user_id"],
        cause=RuntimeError("finalize failed"),
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _reset_call_log():
    """Guarantee every test starts and ends with an empty kwargs log."""
    _CALL_LOG.clear()
    yield
    _CALL_LOG.clear()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_billable_call_invoked_with_correct_kwargs_for_free_config(monkeypatch):
    """Free tier: the graph call is still wrapped in ``billable_call``
    with the video-specific reserve override and audit metadata."""
    from app.config import config as app_config
    from app.db import VideoPresentationStatus
    from app.tasks.celery_tasks import video_presentation_tasks

    video_row = _make_video(video_id=11, thread_id=99)
    fake_session = _FakeSession(video_row)
    monkeypatch.setattr(
        video_presentation_tasks,
        "get_celery_session_maker",
        lambda: _FakeSessionMaker(fake_session),
    )

    owner_id = uuid4()

    async def _fake_resolver(sess, search_space_id, *, thread_id=None):
        assert search_space_id == 777
        assert thread_id == 99
        return owner_id, "free", "openrouter/some-free-model"

    monkeypatch.setattr(
        video_presentation_tasks,
        "_resolve_agent_billing_for_search_space",
        _fake_resolver,
    )
    monkeypatch.setattr(video_presentation_tasks, "billable_call", _ok_billable_call)

    async def _fake_graph_invoke(state, config):
        return {"slides": [], "slide_audio_results": [], "slide_scene_codes": []}

    monkeypatch.setattr(
        video_presentation_tasks.video_presentation_graph,
        "ainvoke",
        _fake_graph_invoke,
    )

    result = await video_presentation_tasks._generate_video_presentation(
        video_presentation_id=11,
        source_content="content",
        search_space_id=777,
        user_prompt=None,
    )

    assert result["status"] == "ready"
    assert result["video_presentation_id"] == 11
    assert video_row.status == VideoPresentationStatus.READY

    assert len(_CALL_LOG) == 1
    recorded = _CALL_LOG[0]
    assert recorded["user_id"] == owner_id
    assert recorded["search_space_id"] == 777
    assert recorded["billing_tier"] == "free"
    assert recorded["base_model"] == "openrouter/some-free-model"
    assert recorded["usage_type"] == "video_presentation_generation"
    assert (
        recorded["quota_reserve_micros_override"]
        == app_config.QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS
    )
    # Background artifact audit rows intentionally omit the TokenUsage.thread_id
    # FK to avoid coupling Celery audit commits to an active chat transaction.
    assert "thread_id" not in recorded
    assert recorded["call_details"] == {
        "video_presentation_id": 11,
        "title": "Test Presentation",
        "thread_id": 99,
    }
    assert callable(recorded["billable_session_factory"])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_billable_call_invoked_with_premium_tier(monkeypatch):
    """Premium resolution must flow through to ``billable_call`` so the
    reserve/finalize path triggers."""
    from app.tasks.celery_tasks import video_presentation_tasks

    video_row = _make_video()
    fake_session = _FakeSession(video_row)
    monkeypatch.setattr(
        video_presentation_tasks,
        "get_celery_session_maker",
        lambda: _FakeSessionMaker(fake_session),
    )

    owner_id = uuid4()

    async def _fake_resolver(sess, search_space_id, *, thread_id=None):
        return owner_id, "premium", "gpt-5.4"

    monkeypatch.setattr(
        video_presentation_tasks,
        "_resolve_agent_billing_for_search_space",
        _fake_resolver,
    )
    monkeypatch.setattr(video_presentation_tasks, "billable_call", _ok_billable_call)

    async def _fake_graph_invoke(state, config):
        return {"slides": [], "slide_audio_results": [], "slide_scene_codes": []}

    monkeypatch.setattr(
        video_presentation_tasks.video_presentation_graph,
        "ainvoke",
        _fake_graph_invoke,
    )

    await video_presentation_tasks._generate_video_presentation(
        video_presentation_id=11,
        source_content="content",
        search_space_id=777,
        user_prompt=None,
    )

    assert _CALL_LOG[0]["billing_tier"] == "premium"
    assert _CALL_LOG[0]["base_model"] == "gpt-5.4"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_quota_insufficient_marks_video_failed_and_skips_graph(monkeypatch):
    """A denied reservation must mark the row FAILED with the documented
    reason code and never invoke the generation graph."""
    from app.db import VideoPresentationStatus
    from app.tasks.celery_tasks import video_presentation_tasks

    video_row = _make_video(video_id=12)
    fake_session = _FakeSession(video_row)
    monkeypatch.setattr(
        video_presentation_tasks,
        "get_celery_session_maker",
        lambda: _FakeSessionMaker(fake_session),
    )

    async def _fake_resolver(sess, search_space_id, *, thread_id=None):
        return uuid4(), "premium", "gpt-5.4"

    monkeypatch.setattr(
        video_presentation_tasks,
        "_resolve_agent_billing_for_search_space",
        _fake_resolver,
    )
    monkeypatch.setattr(
        video_presentation_tasks, "billable_call", _denying_billable_call
    )

    graph_invocations: list[bool] = []

    async def _fake_graph_invoke(state, config):
        graph_invocations.append(True)
        return {}

    monkeypatch.setattr(
        video_presentation_tasks.video_presentation_graph,
        "ainvoke",
        _fake_graph_invoke,
    )

    result = await video_presentation_tasks._generate_video_presentation(
        video_presentation_id=12,
        source_content="content",
        search_space_id=777,
        user_prompt=None,
    )

    assert result == {
        "status": "failed",
        "video_presentation_id": 12,
        "reason": "premium_quota_exhausted",
    }
    assert video_row.status == VideoPresentationStatus.FAILED
    assert graph_invocations == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_billing_settlement_failure_marks_video_failed(monkeypatch):
    """If billing settlement fails after generation, the video record must
    end up FAILED with the settlement-specific reason."""
    from app.db import VideoPresentationStatus
    from app.tasks.celery_tasks import video_presentation_tasks

    fake_video = _make_video(video_id=14)
    fake_session = _FakeSession(fake_video)
    monkeypatch.setattr(
        video_presentation_tasks,
        "get_celery_session_maker",
        lambda: _FakeSessionMaker(fake_session),
    )

    async def _premium_resolver(sess, search_space_id, *, thread_id=None):
        return uuid4(), "premium", "gpt-5.4"

    monkeypatch.setattr(
        video_presentation_tasks,
        "_resolve_agent_billing_for_search_space",
        _premium_resolver,
    )
    monkeypatch.setattr(
        video_presentation_tasks,
        "billable_call",
        _settlement_failing_billable_call,
    )

    async def _stub_graph_ainvoke(state, config):
        return {"slides": [], "slide_audio_results": [], "slide_scene_codes": []}

    monkeypatch.setattr(
        video_presentation_tasks.video_presentation_graph,
        "ainvoke",
        _stub_graph_ainvoke,
    )

    outcome = await video_presentation_tasks._generate_video_presentation(
        video_presentation_id=14,
        source_content="content",
        search_space_id=777,
        user_prompt=None,
    )

    assert outcome == {
        "status": "failed",
        "video_presentation_id": 14,
        "reason": "billing_settlement_failed",
    }
    assert fake_video.status == VideoPresentationStatus.FAILED
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_resolver_failure_marks_video_failed(monkeypatch):
    """A billing-resolution error must fail the video record up front and
    skip the generation graph entirely."""
    from app.db import VideoPresentationStatus
    from app.tasks.celery_tasks import video_presentation_tasks

    fake_video = _make_video(video_id=13)
    fake_session = _FakeSession(fake_video)
    monkeypatch.setattr(
        video_presentation_tasks,
        "get_celery_session_maker",
        lambda: _FakeSessionMaker(fake_session),
    )

    async def _failing_resolver(sess, search_space_id, *, thread_id=None):
        raise ValueError("Search space 777 not found")

    monkeypatch.setattr(
        video_presentation_tasks,
        "_resolve_agent_billing_for_search_space",
        _failing_resolver,
    )

    graph_calls = []

    async def _tracking_graph_ainvoke(state, config):
        graph_calls.append(True)
        return {}

    monkeypatch.setattr(
        video_presentation_tasks.video_presentation_graph,
        "ainvoke",
        _tracking_graph_ainvoke,
    )

    outcome = await video_presentation_tasks._generate_video_presentation(
        video_presentation_id=13,
        source_content="content",
        search_space_id=777,
        user_prompt=None,
    )

    assert outcome == {
        "status": "failed",
        "video_presentation_id": 13,
        "reason": "billing_resolution_failed",
    }
    assert fake_video.status == VideoPresentationStatus.FAILED
    # Resolution failed before generation, so the graph never ran.
    assert graph_calls == []
|
||||
|
|
@ -271,6 +271,66 @@ async def test_preflight_skipped_for_auto_router_model():
|
|||
await _preflight_llm(fake_llm)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_settle_speculative_agent_build_swallows_exceptions():
    """``_settle_speculative_agent_build`` MUST always return cleanly so the
    caller can safely re-touch the request-scoped session afterwards.

    The helper guards the parallel preflight + agent-build path: when the
    speculative build is discarded (429 or non-429 preflight failure) it is
    awaited only to release any in-flight ``AsyncSession`` usage — the
    build's outcome is irrelevant. If any exception (including
    ``CancelledError``) escaped, the caller's recovery flow would be skipped
    and the session-concurrency hazard the helper prevents would return.
    """
    import asyncio

    from app.tasks.chat.stream_new_chat import _settle_speculative_agent_build

    async def _boom() -> None:
        raise RuntimeError("speculative build crashed")

    async def _value() -> str:
        return "agent"

    async def _dawdle() -> None:
        await asyncio.sleep(0.05)

    # Failing, succeeding, and still-running builds must all settle cleanly.
    for pending in (_boom(), _value(), _dawdle()):
        build_task = asyncio.create_task(pending)
        await _settle_speculative_agent_build(build_task)
        assert build_task.done()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_settle_speculative_agent_build_handles_already_done_task():
    """Settling a task that already finished — success or failure — must
    not raise."""
    import asyncio

    from app.tasks.chat.stream_new_chat import _settle_speculative_agent_build

    async def _succeeding() -> str:
        return "ok"

    async def _failing() -> None:
        raise ValueError("nope")

    good_task = asyncio.create_task(_succeeding())
    failed_task = asyncio.create_task(_failing())
    # Let the event loop run both coroutines to completion before settling.
    await asyncio.sleep(0)
    await asyncio.sleep(0)

    await _settle_speculative_agent_build(good_task)
    await _settle_speculative_agent_build(failed_task)
    assert good_task.result() == "ok"
    # The settle helper observes (but does not clear) the failure, so the
    # original exception is still retrievable afterwards.
    assert isinstance(failed_task.exception(), ValueError)
|
||||
|
||||
|
||||
def test_stream_exception_classifies_thread_busy():
|
||||
exc = BusyError(request_id="thread-123")
|
||||
kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue