Merge upstream/dev into feature/multi-agent
commit 5119915f4f
278 changed files with 34669 additions and 8970 deletions
@@ -226,6 +226,31 @@ class TestCompose:
        # Default block should NOT be present
        assert "<knowledge_base_only_policy>" not in prompt

    def test_provider_hints_render_with_custom_system_instructions(
        self, fixed_today: datetime
    ) -> None:
        """Regression guard for the always-append decision: provider hints
        append AFTER a custom system prompt.

        Provider hints are stylistic nudges (parallel tool-call rules,
        formatting guidance, etc.) that help the model regardless of
        what the system instructions say. Suppressing them when a
        custom prompt is set would partially defeat the per-family
        prompt machinery.
        """
        prompt = compose_system_prompt(
            today=fixed_today,
            custom_system_instructions="You are a custom assistant.",
            model_name="anthropic/claude-3-5-sonnet",
        )
        assert "You are a custom assistant." in prompt
        assert "<provider_hints>" in prompt
        # The custom prompt must come BEFORE the provider hints so the
        # user's framing isn't drowned out by the stylistic nudges.
        assert prompt.index("You are a custom assistant.") < prompt.index(
            "<provider_hints>"
        )

    def test_use_default_false_with_no_custom_yields_no_system_block(
        self, fixed_today: datetime
    ) -> None:
268 surfsense_backend/tests/unit/agents/new_chat/test_agent_cache.py (Normal file)

@@ -0,0 +1,268 @@
"""Regression tests for the compiled-agent cache.
|
||||
|
||||
Covers the cache primitive itself (TTL, LRU, in-flight de-duplication,
|
||||
build-failure non-caching) and the cache-key signature helpers that
|
||||
``create_surfsense_deep_agent`` relies on. The integration with
|
||||
``create_surfsense_deep_agent`` is covered separately by the streaming
|
||||
contract tests; this module focuses on the primitives so a regression
|
||||
in the cache implementation is caught before it reaches the agent
|
||||
factory.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.new_chat.agent_cache import (
|
||||
flags_signature,
|
||||
reload_for_tests,
|
||||
stable_hash,
|
||||
system_prompt_hash,
|
||||
tools_signature,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
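
# For orientation while reading the tests below, a minimal sketch of the
# primitive they exercise, assuming an OrderedDict-backed LRU, a
# monotonic-clock TTL, and per-key in-flight futures for miss coalescing.
# The real ``agent_cache`` implementation may differ in detail; only the
# contract these tests pin (same instance on hit, one build per concurrent
# miss, failures never cached) is authoritative.
#
#     import time
#     from collections import OrderedDict
#
#     class _SketchCache:
#         def __init__(self, maxsize: int, ttl_seconds: float) -> None:
#             self._maxsize, self._ttl = maxsize, ttl_seconds
#             self._entries: OrderedDict[str, tuple[float, object]] = OrderedDict()
#             self._inflight: dict[str, asyncio.Future] = {}
#
#         async def get_or_build(self, key: str, *, builder):
#             hit = self._entries.get(key)
#             if hit is not None and (time.monotonic() - hit[0]) < self._ttl:
#                 self._entries.move_to_end(key)  # refresh LRU recency
#                 return hit[1]
#             if key in self._inflight:  # coalesce concurrent cold misses
#                 return await self._inflight[key]
#             fut = asyncio.get_running_loop().create_future()
#             self._inflight[key] = fut
#             try:
#                 value = await builder()
#             except BaseException as exc:
#                 fut.set_exception(exc)  # waiters see the error;
#                 raise                   # nothing is stored, so retries rebuild
#             else:
#                 fut.set_result(value)
#                 self._entries[key] = (time.monotonic(), value)
#                 self._entries.move_to_end(key)
#                 while len(self._entries) > self._maxsize:
#                     self._entries.popitem(last=False)  # evict least-recent
#                 return value
#             finally:
#                 self._inflight.pop(key, None)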

# ---------------------------------------------------------------------------
# stable_hash + signature helpers
# ---------------------------------------------------------------------------


def test_stable_hash_is_deterministic_across_calls() -> None:
    a = stable_hash("v1", 42, "thread-9", None, ["x", "y"])
    b = stable_hash("v1", 42, "thread-9", None, ["x", "y"])
    assert a == b


def test_stable_hash_changes_when_any_part_changes() -> None:
    base = stable_hash("v1", 42, "thread-9")
    assert stable_hash("v1", 42, "thread-10") != base
    assert stable_hash("v2", 42, "thread-9") != base
    assert stable_hash("v1", 43, "thread-9") != base


def test_tools_signature_keys_on_name_and_description_not_identity() -> None:
    """Two tool lists with the same surface must hash identically.

    The cache key MUST NOT change when the underlying ``BaseTool``
    instances are different Python objects (a fresh request constructs
    fresh tool instances every time). Hashing on ``(name, description)``
    keeps the cache hot across requests with identical tool surfaces.
    """

    @dataclass
    class FakeTool:
        name: str
        description: str

    tools_a = [FakeTool("alpha", "does alpha"), FakeTool("beta", "does beta")]
    tools_b = [FakeTool("beta", "does beta"), FakeTool("alpha", "does alpha")]
    sig_a = tools_signature(
        tools_a, available_connectors=["NOTION"], available_document_types=["FILE"]
    )
    sig_b = tools_signature(
        tools_b, available_connectors=["NOTION"], available_document_types=["FILE"]
    )
    assert sig_a == sig_b, "tool order must not affect the signature"

    # Adding a tool rotates the key.
    tools_c = [*tools_a, FakeTool("gamma", "does gamma")]
    sig_c = tools_signature(
        tools_c, available_connectors=["NOTION"], available_document_types=["FILE"]
    )
    assert sig_c != sig_a
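
# A sketch of how such an order-insensitive signature can be built, assuming
# ``stable_hash`` accepts arbitrary JSON-serialisable arguments; the helper
# name and exact composition are illustrative, not the shipped code:
#
#     def _tools_signature_sketch(tools, *, available_connectors,
#                                 available_document_types) -> str:
#         surface = sorted((t.name, t.description) for t in tools)
#         return stable_hash(
#             surface,  # identity-free: equal surfaces hash equally
#             sorted(available_connectors),
#             sorted(available_document_types),
#         )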

def test_tools_signature_rotates_when_connector_set_changes() -> None:
    @dataclass
    class FakeTool:
        name: str
        description: str

    tools = [FakeTool("a", "x")]
    base = tools_signature(
        tools, available_connectors=["NOTION"], available_document_types=["FILE"]
    )
    added = tools_signature(
        tools,
        available_connectors=["NOTION", "SLACK"],
        available_document_types=["FILE"],
    )
    assert base != added, "adding a connector must rotate the cache key"


def test_flags_signature_changes_when_flag_flips() -> None:
    @dataclass(frozen=True)
    class Flags:
        a: bool = True
        b: bool = False

    base = flags_signature(Flags())
    flipped = flags_signature(Flags(b=True))
    assert base != flipped


def test_system_prompt_hash_is_stable_and_distinct() -> None:
    p1 = "You are a helpful assistant."
    p2 = "You are a helpful assistant!"  # one-character delta
    assert system_prompt_hash(p1) == system_prompt_hash(p1)
    assert system_prompt_hash(p1) != system_prompt_hash(p2)


# ---------------------------------------------------------------------------
# _AgentCache: hit / miss / TTL / LRU / coalescing / failure-not-cached
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_cache_hit_returns_same_instance_on_second_call() -> None:
    cache = reload_for_tests(maxsize=8, ttl_seconds=60.0)
    builds = 0

    async def builder() -> object:
        nonlocal builds
        builds += 1
        return object()

    a = await cache.get_or_build("k", builder=builder)
    b = await cache.get_or_build("k", builder=builder)
    assert a is b, "cache must return the SAME object across hits"
    assert builds == 1, "builder must run exactly once"


@pytest.mark.asyncio
async def test_cache_different_keys_get_different_instances() -> None:
    cache = reload_for_tests(maxsize=8, ttl_seconds=60.0)

    async def builder() -> object:
        return object()

    a = await cache.get_or_build("k1", builder=builder)
    b = await cache.get_or_build("k2", builder=builder)
    assert a is not b


@pytest.mark.asyncio
async def test_cache_stale_entries_get_rebuilt() -> None:
    # ttl=0 means every read sees the entry as immediately stale.
    cache = reload_for_tests(maxsize=8, ttl_seconds=0.0)
    builds = 0

    async def builder() -> object:
        nonlocal builds
        builds += 1
        return object()

    a = await cache.get_or_build("k", builder=builder)
    b = await cache.get_or_build("k", builder=builder)
    assert a is not b, "stale entry must rebuild a fresh instance"
    assert builds == 2


@pytest.mark.asyncio
async def test_cache_evicts_lru_when_full() -> None:
    cache = reload_for_tests(maxsize=2, ttl_seconds=60.0)

    async def builder() -> object:
        return object()

    a = await cache.get_or_build("a", builder=builder)
    _ = await cache.get_or_build("b", builder=builder)
    # Re-touch "a" so "b" is now the LRU victim.
    a_again = await cache.get_or_build("a", builder=builder)
    assert a_again is a
    # Inserting "c" should evict "b" (LRU), not "a".
    _ = await cache.get_or_build("c", builder=builder)
    assert cache.stats()["size"] == 2

    # Confirm "a" is still hot (no rebuild) and "b" is gone (rebuild).
    a_hit = await cache.get_or_build("a", builder=builder)
    assert a_hit is a, "LRU must keep the most-recently-used 'a' entry"


@pytest.mark.asyncio
async def test_cache_concurrent_misses_coalesce_to_single_build() -> None:
    """Two concurrent get_or_build calls on the same key must share one builder."""
    cache = reload_for_tests(maxsize=8, ttl_seconds=60.0)
    build_started = asyncio.Event()
    builds = 0

    async def slow_builder() -> object:
        nonlocal builds
        builds += 1
        build_started.set()
        # Yield control so the second waiter can race against us.
        await asyncio.sleep(0.05)
        return object()

    task_a = asyncio.create_task(cache.get_or_build("k", builder=slow_builder))
    # Wait until the first builder has started, then race a second waiter.
    await build_started.wait()
    task_b = asyncio.create_task(cache.get_or_build("k", builder=slow_builder))

    a, b = await asyncio.gather(task_a, task_b)
    assert a is b, "coalesced waiters must observe the same value"
    assert builds == 1, "concurrent cold misses must collapse to ONE build"


@pytest.mark.asyncio
async def test_cache_does_not_store_failed_builds() -> None:
    """A builder that raises must NOT poison the cache.

    The next caller for the same key must run the builder again (not
    re-raise the cached exception).
    """
    cache = reload_for_tests(maxsize=8, ttl_seconds=60.0)
    attempts = 0

    async def flaky_builder() -> object:
        nonlocal attempts
        attempts += 1
        if attempts == 1:
            raise RuntimeError("transient")
        return object()

    with pytest.raises(RuntimeError, match="transient"):
        await cache.get_or_build("k", builder=flaky_builder)

    # Second call must retry — not re-raise the cached exception.
    value = await cache.get_or_build("k", builder=flaky_builder)
    assert value is not None
    assert attempts == 2


@pytest.mark.asyncio
async def test_cache_invalidate_drops_entry() -> None:
    cache = reload_for_tests(maxsize=8, ttl_seconds=60.0)

    async def builder() -> object:
        return object()

    a = await cache.get_or_build("k", builder=builder)
    assert cache.invalidate("k") is True
    b = await cache.get_or_build("k", builder=builder)
    assert a is not b, "post-invalidation lookup must rebuild"


@pytest.mark.asyncio
async def test_cache_invalidate_prefix_drops_matching_entries() -> None:
    cache = reload_for_tests(maxsize=16, ttl_seconds=60.0)

    async def builder() -> object:
        return object()

    await cache.get_or_build("user:1:thread:1", builder=builder)
    await cache.get_or_build("user:1:thread:2", builder=builder)
    await cache.get_or_build("user:2:thread:1", builder=builder)

    removed = cache.invalidate_prefix("user:1:")
    assert removed == 2
    assert cache.stats()["size"] == 1

    # The user:2 entry must still be hot (no rebuild).
    survivor_value = await cache.get_or_build("user:2:thread:1", builder=builder)
    assert survivor_value is not None
@@ -7,7 +7,9 @@ import pytest
from app.agents.new_chat.errors import BusyError
from app.agents.new_chat.middleware.busy_mutex import (
    BusyMutexMiddleware,
    end_turn,
    get_cancel_event,
    is_cancel_requested,
    manager,
    request_cancel,
    reset_cancel,
@@ -88,3 +90,65 @@ async def test_no_thread_id_skipped_when_not_required() -> None:
def test_reset_cancel_idempotent() -> None:
    # Should not raise even if event was never created
    reset_cancel("never-seen")


def test_request_cancel_creates_event_for_unseen_thread() -> None:
    thread_id = "never-seen-cancel"
    reset_cancel(thread_id)

    assert request_cancel(thread_id) is True
    assert get_cancel_event(thread_id).is_set()
    assert is_cancel_requested(thread_id) is True


@pytest.mark.asyncio
async def test_end_turn_force_clears_lock_and_cancel_state() -> None:
    thread_id = "forced-end-turn"
    mw = BusyMutexMiddleware()
    runtime = _Runtime(thread_id)

    await mw.abefore_agent({}, runtime)
    assert manager.lock_for(thread_id).locked()

    request_cancel(thread_id)
    assert is_cancel_requested(thread_id) is True

    end_turn(thread_id)

    assert not manager.lock_for(thread_id).locked()
    assert not get_cancel_event(thread_id).is_set()
    assert is_cancel_requested(thread_id) is False


@pytest.mark.asyncio
async def test_busy_mutex_stale_aafter_does_not_release_new_attempt_lock() -> None:
    """A stale aafter call from attempt A must not unlock attempt B.

    Repro flow:
    1) attempt A acquires thread lock
    2) forced end_turn clears A so retry can proceed
    3) attempt B acquires same thread lock
    4) stale attempt-A aafter runs late

    Expected: B lock remains held.
    """
    thread_id = "stale-aafter-lock"
    runtime = _Runtime(thread_id)
    attempt_a = BusyMutexMiddleware()
    attempt_b = BusyMutexMiddleware()

    await attempt_a.abefore_agent({}, runtime)
    lock = manager.lock_for(thread_id)
    assert lock.locked()

    end_turn(thread_id)
    assert not lock.locked()

    await attempt_b.abefore_agent({}, runtime)
    assert lock.locked()

    # Stale cleanup from attempt A must not release attempt B's lock.
    await attempt_a.aafter_agent({}, runtime)
    assert lock.locked()

    await attempt_b.aafter_agent({}, runtime)
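
# One plausible shape for that stale-release protection (a sketch, not
# necessarily the shipped middleware): each abefore_agent records an
# acquisition token, and aafter_agent releases only while its token is
# still current. ``bump_token``/``current_token`` are hypothetical names
# invented here for illustration.
#
#     class _SketchBusyMutex:
#         async def abefore_agent(self, state, runtime):
#             thread_id = runtime.context.thread_id  # assumed runtime shape
#             await manager.lock_for(thread_id).acquire()
#             self._token = manager.bump_token(thread_id)  # hypothetical
#
#         async def aafter_agent(self, state, runtime):
#             thread_id = runtime.context.thread_id
#             if getattr(self, "_token", None) == manager.current_token(thread_id):
#                 manager.lock_for(thread_id).release()  # only our own attempt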
@@ -31,18 +31,45 @@ def _clear_all(monkeypatch: pytest.MonkeyPatch) -> None:
        "SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE",
        "SURFSENSE_ENABLE_ACTION_LOG",
        "SURFSENSE_ENABLE_REVERT_ROUTE",
        "SURFSENSE_ENABLE_STREAM_PARITY_V2",
        "SURFSENSE_ENABLE_PLUGIN_LOADER",
        "SURFSENSE_ENABLE_OTEL",
        "SURFSENSE_ENABLE_AGENT_CACHE",
        "SURFSENSE_ENABLE_AGENT_CACHE_SHARE_GP_SUBAGENT",
    ]:
        monkeypatch.delenv(name, raising=False)


def test_defaults_match_shipped_agent_stack(monkeypatch: pytest.MonkeyPatch) -> None:
    _clear_all(monkeypatch)
    flags = reload_for_tests()
    assert isinstance(flags, AgentFeatureFlags)
    assert flags.disable_new_agent_stack is False
    assert flags.enable_context_editing is True
    assert flags.enable_compaction_v2 is True
    assert flags.enable_retry_after is True
    assert flags.enable_model_fallback is False
    assert flags.enable_model_call_limit is True
    assert flags.enable_tool_call_limit is True
    assert flags.enable_tool_call_repair is True
    assert flags.enable_doom_loop is True
    assert flags.enable_permission is True
    assert flags.enable_busy_mutex is True
    assert flags.enable_llm_tool_selector is False
    assert flags.enable_skills is True
    assert flags.enable_specialized_subagents is True
    assert flags.enable_kb_planner_runnable is True
    assert flags.enable_action_log is True
    assert flags.enable_revert_route is True
    assert flags.enable_stream_parity_v2 is True
    assert flags.enable_plugin_loader is False
    assert flags.enable_otel is False
    # Phase 2: agent cache is now default-on (the prerequisite tool
    # ``db_session`` refactor landed). The companion gp-subagent share
    # flag stays default-off pending data on cold-miss frequency.
    assert flags.enable_agent_cache is True
    assert flags.enable_agent_cache_share_gp_subagent is False
    assert flags.any_new_middleware_enabled() is True
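
# The env parsing these tests rely on is conventional bool-from-env; a
# minimal sketch (the helper name and accepted spellings are illustrative,
# not the shipped code; only the SURFSENSE_* variable names come from the
# tests above):
#
#     import os
#
#     def _env_flag(name: str, default: bool) -> bool:
#         raw = os.environ.get(name)
#         if raw is None:
#             return default
#         return raw.strip().lower() in {"1", "true", "yes", "on"}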

def test_master_kill_switch_overrides_individual_flags(


@@ -100,21 +127,13 @@ def test_each_flag_can_be_set_independently(monkeypatch: pytest.MonkeyPatch) ->
        "enable_kb_planner_runnable": "SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE",
        "enable_action_log": "SURFSENSE_ENABLE_ACTION_LOG",
        "enable_revert_route": "SURFSENSE_ENABLE_REVERT_ROUTE",
        "enable_stream_parity_v2": "SURFSENSE_ENABLE_STREAM_PARITY_V2",
        "enable_plugin_loader": "SURFSENSE_ENABLE_PLUGIN_LOADER",
        "enable_otel": "SURFSENSE_ENABLE_OTEL",
    }

    # `enable_otel` is intentionally orthogonal — it does NOT count toward
    # ``any_new_middleware_enabled`` because OTel is observability-only and
    # ships under its own ``OTEL_EXPORTER_OTLP_ENDPOINT`` requirement.
    counts_toward_middleware = {k for k in flag_to_env if k != "enable_otel"}

    for attr, env_name in flag_to_env.items():
        _clear_all(monkeypatch)
        monkeypatch.setenv(env_name, "true")
        flags = reload_for_tests()
        assert getattr(flags, attr) is True, f"{attr} did not flip on for {env_name}"
        if attr in counts_toward_middleware:
            assert flags.any_new_middleware_enabled() is True
        else:
            assert flags.any_new_middleware_enabled() is False

        monkeypatch.setenv(env_name, "false")
        flags = reload_for_tests()
        assert getattr(flags, attr) is False, f"{attr} did not flip off for {env_name}"
@@ -0,0 +1,344 @@
"""Tests for ``FlattenSystemMessageMiddleware``.

The middleware exists to defend against Anthropic's "Found 5 cache_control
blocks" 400 when our deepagent middleware stack stacks 5+ text blocks on
the system message and the OpenRouter→Anthropic adapter redistributes
``cache_control`` across all of them. The flattening collapses every
all-text system content list to a single string before the LLM call.
"""

from __future__ import annotations

from typing import Any
from unittest.mock import MagicMock

import pytest
from langchain_core.messages import HumanMessage, SystemMessage

from app.agents.new_chat.middleware.flatten_system import (
    FlattenSystemMessageMiddleware,
    _flatten_text_blocks,
    _flattened_request,
)

pytestmark = pytest.mark.unit
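
# The behaviour pinned below is small enough to restate as a sketch,
# assuming the real helper matches the tests (illustrative, not the source):
#
#     def _flatten_text_blocks_sketch(blocks):
#         parts: list[str] = []
#         for block in blocks:
#             if isinstance(block, str):
#                 parts.append(block)  # bare-string blocks pass through
#             elif (
#                 isinstance(block, dict)
#                 and block.get("type") == "text"
#                 and isinstance(block.get("text"), str)
#             ):
#                 parts.append(block["text"])  # inner cache_control is dropped
#             else:
#                 return None  # image/unknown block: refuse to flatten
#         return "\n\n".join(parts)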

# ---------------------------------------------------------------------------
# _flatten_text_blocks — pure helper, the heart of the middleware.
# ---------------------------------------------------------------------------


class TestFlattenTextBlocks:
    def test_joins_text_blocks_with_double_newline(self) -> None:
        blocks = [
            {"type": "text", "text": "<surfsense base>"},
            {"type": "text", "text": "<filesystem section>"},
            {"type": "text", "text": "<skills section>"},
        ]
        assert (
            _flatten_text_blocks(blocks)
            == "<surfsense base>\n\n<filesystem section>\n\n<skills section>"
        )

    def test_handles_single_text_block(self) -> None:
        blocks = [{"type": "text", "text": "only one"}]
        assert _flatten_text_blocks(blocks) == "only one"

    def test_handles_empty_list(self) -> None:
        assert _flatten_text_blocks([]) == ""

    def test_passes_through_bare_string_blocks(self) -> None:
        # LangChain content can mix bare strings and dict blocks.
        blocks = ["raw string", {"type": "text", "text": "dict block"}]
        assert _flatten_text_blocks(blocks) == "raw string\n\ndict block"

    def test_returns_none_for_image_block(self) -> None:
        # System messages with images are rare — but we never want to
        # silently lose the image payload by joining as text.
        blocks = [
            {"type": "text", "text": "look at this"},
            {"type": "image_url", "image_url": {"url": "data:image/png..."}},
        ]
        assert _flatten_text_blocks(blocks) is None

    def test_returns_none_for_non_dict_non_str_block(self) -> None:
        blocks = [{"type": "text", "text": "hi"}, 42]  # type: ignore[list-item]
        assert _flatten_text_blocks(blocks) is None

    def test_returns_none_when_text_field_missing(self) -> None:
        blocks = [{"type": "text"}]  # no ``text`` key
        assert _flatten_text_blocks(blocks) is None

    def test_returns_none_when_text_is_not_string(self) -> None:
        blocks = [{"type": "text", "text": ["nested", "list"]}]
        assert _flatten_text_blocks(blocks) is None

    def test_drops_cache_control_from_inner_blocks(self) -> None:
        # The whole point: existing cache_control on inner blocks is
        # discarded so LiteLLM's ``cache_control_injection_points`` can
        # re-attach exactly one breakpoint after flattening.
        blocks = [
            {"type": "text", "text": "first"},
            {
                "type": "text",
                "text": "second",
                "cache_control": {"type": "ephemeral"},
            },
        ]
        flattened = _flatten_text_blocks(blocks)
        assert flattened == "first\n\nsecond"
        assert "cache_control" not in flattened  # type: ignore[operator]


# ---------------------------------------------------------------------------
# _flattened_request — decides when to override and when to no-op.
# ---------------------------------------------------------------------------


def _make_request(system_message: SystemMessage | None) -> Any:
    """Build a minimal ModelRequest stub. We only need .system_message
    and .override(system_message=...) — the middleware never touches
    other fields.
    """
    request = MagicMock()
    request.system_message = system_message

    def override(**kwargs: Any) -> Any:
        new_request = MagicMock()
        new_request.system_message = kwargs.get(
            "system_message", request.system_message
        )
        new_request.messages = kwargs.get("messages", getattr(request, "messages", []))
        new_request.tools = kwargs.get("tools", getattr(request, "tools", []))
        return new_request

    request.override = override
    return request


class TestFlattenedRequest:
    def test_collapses_multi_block_system_to_string(self) -> None:
        sys = SystemMessage(
            content=[
                {"type": "text", "text": "<base>"},
                {"type": "text", "text": "<todo>"},
                {"type": "text", "text": "<filesystem>"},
                {"type": "text", "text": "<skills>"},
                {"type": "text", "text": "<subagents>"},
            ]
        )
        request = _make_request(sys)
        flattened = _flattened_request(request)

        assert flattened is not None
        assert isinstance(flattened.system_message, SystemMessage)
        assert flattened.system_message.content == (
            "<base>\n\n<todo>\n\n<filesystem>\n\n<skills>\n\n<subagents>"
        )

    def test_no_op_for_string_content(self) -> None:
        sys = SystemMessage(content="already a string")
        request = _make_request(sys)
        assert _flattened_request(request) is None

    def test_no_op_for_single_block_list(self) -> None:
        # One block already produces one breakpoint — no need to flatten.
        sys = SystemMessage(content=[{"type": "text", "text": "single"}])
        request = _make_request(sys)
        assert _flattened_request(request) is None

    def test_no_op_when_system_message_missing(self) -> None:
        request = _make_request(None)
        assert _flattened_request(request) is None

    def test_no_op_when_list_contains_non_text_block(self) -> None:
        sys = SystemMessage(
            content=[
                {"type": "text", "text": "look"},
                {"type": "image_url", "image_url": {"url": "data:..."}},
            ]
        )
        request = _make_request(sys)
        assert _flattened_request(request) is None

    def test_preserves_additional_kwargs_and_metadata(self) -> None:
        # Defensive: nothing in the current chain sets these on a system
        # message, but losing them silently when something does in the
        # future would be a regression. ``name`` in particular is the only
        # ``additional_kwargs`` field that ChatLiteLLM's
        # ``_convert_message_to_dict`` propagates onto the wire.
        sys = SystemMessage(
            content=[
                {"type": "text", "text": "a"},
                {"type": "text", "text": "b"},
            ],
            additional_kwargs={"name": "surfsense_system", "x": 1},
            response_metadata={"tokens": 42},
        )
        sys.id = "sys-msg-1"
        request = _make_request(sys)

        flattened = _flattened_request(request)
        assert flattened is not None
        assert flattened.system_message.content == "a\n\nb"
        assert flattened.system_message.additional_kwargs == {
            "name": "surfsense_system",
            "x": 1,
        }
        assert flattened.system_message.response_metadata == {"tokens": 42}
        assert flattened.system_message.id == "sys-msg-1"

    def test_idempotent_when_run_twice(self) -> None:
        sys = SystemMessage(
            content=[
                {"type": "text", "text": "a"},
                {"type": "text", "text": "b"},
            ]
        )
        request = _make_request(sys)
        first = _flattened_request(request)
        assert first is not None

        # Second pass on the already-flattened request should be a no-op.
        # We re-wrap in a request stub since the helper inspects
        # ``request.system_message.content``.
        second_request = _make_request(first.system_message)
        assert _flattened_request(second_request) is None


# ---------------------------------------------------------------------------
# Middleware integration — verify the handler sees a flattened request.
# ---------------------------------------------------------------------------


class TestMiddlewareWrap:
    @pytest.mark.asyncio
    async def test_async_passes_flattened_request_to_handler(self) -> None:
        sys = SystemMessage(
            content=[
                {"type": "text", "text": "alpha"},
                {"type": "text", "text": "beta"},
            ]
        )
        request = _make_request(sys)
        captured: dict[str, Any] = {}

        async def handler(req: Any) -> str:
            captured["request"] = req
            return "ok"

        mw = FlattenSystemMessageMiddleware()
        result = await mw.awrap_model_call(request, handler)

        assert result == "ok"
        assert isinstance(captured["request"].system_message, SystemMessage)
        assert captured["request"].system_message.content == "alpha\n\nbeta"

    @pytest.mark.asyncio
    async def test_async_passes_through_when_already_string(self) -> None:
        sys = SystemMessage(content="just a string")
        request = _make_request(sys)
        captured: dict[str, Any] = {}

        async def handler(req: Any) -> str:
            captured["request"] = req
            return "ok"

        mw = FlattenSystemMessageMiddleware()
        await mw.awrap_model_call(request, handler)

        # Same request object: no override happened.
        assert captured["request"] is request

    def test_sync_passes_flattened_request_to_handler(self) -> None:
        sys = SystemMessage(
            content=[
                {"type": "text", "text": "alpha"},
                {"type": "text", "text": "beta"},
            ]
        )
        request = _make_request(sys)
        captured: dict[str, Any] = {}

        def handler(req: Any) -> str:
            captured["request"] = req
            return "ok"

        mw = FlattenSystemMessageMiddleware()
        result = mw.wrap_model_call(request, handler)

        assert result == "ok"
        assert captured["request"].system_message.content == "alpha\n\nbeta"

    def test_sync_passes_through_when_no_system_message(self) -> None:
        request = _make_request(None)
        captured: dict[str, Any] = {}

        def handler(req: Any) -> str:
            captured["request"] = req
            return "ok"

        mw = FlattenSystemMessageMiddleware()
        mw.wrap_model_call(request, handler)
        assert captured["request"] is request


# ---------------------------------------------------------------------------
# Regression guard — pin the worst-case shape that triggered the
# "Found 5" 400 in production. Confirms we collapse 5 blocks to 1 so the
# downstream cache_control_injection_points can only place 1 breakpoint
# on the system message regardless of provider redistribution quirks.
# ---------------------------------------------------------------------------


def test_regression_five_block_system_collapses_to_one_block() -> None:
    sys = SystemMessage(
        content=[
            {"type": "text", "text": "<surfsense base + BASE_AGENT_PROMPT>"},
            {"type": "text", "text": "<TodoListMiddleware section>"},
            {"type": "text", "text": "<SurfSenseFilesystemMiddleware section>"},
            {"type": "text", "text": "<SkillsMiddleware section>"},
            {"type": "text", "text": "<SubAgentMiddleware section>"},
        ]
    )
    request = _make_request(sys)
    flattened = _flattened_request(request)

    assert flattened is not None
    assert isinstance(flattened.system_message.content, str)
    # The exact join doesn't matter for the cache_control accounting —
    # only that there is exactly ONE content block when LiteLLM's
    # AnthropicCacheControlHook later targets ``role: system``.
    assert "<surfsense base" in flattened.system_message.content
    assert "<SubAgentMiddleware" in flattened.system_message.content


def test_regression_human_message_not_modified() -> None:
    # Sanity: the middleware MUST NOT touch user messages — only the
    # system message. Multi-block user content is the path that carries
    # image attachments and would lose its image_url block on
    # accidental flatten.
    sys = SystemMessage(
        content=[
            {"type": "text", "text": "a"},
            {"type": "text", "text": "b"},
        ]
    )
    user = HumanMessage(
        content=[
            {"type": "text", "text": "look at this"},
            {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
        ]
    )
    request = _make_request(sys)
    request.messages = [user]

    flattened = _flattened_request(request)
    assert flattened is not None
    # System flattened to string …
    assert isinstance(flattened.system_message.content, str)
    # … user message is untouched (the helper does not even look at it).
    assert flattened.messages == [user]
    assert isinstance(user.content, list)
    assert len(user.content) == 2
@@ -27,6 +27,7 @@ class TestDefaultAutoApprovedToolsList:
        expected = {
            "create_gmail_draft",
            "update_gmail_draft",
            "create_calendar_event",
            "create_notion_page",
            "create_confluence_page",
            "create_google_drive_file",
@@ -41,13 +42,12 @@ class TestDefaultAutoApprovedToolsList:
        assert isinstance(DEFAULT_AUTO_APPROVED_TOOLS, frozenset)

    def test_send_tools_are_not_auto_approved(self) -> None:
        # External-broadcast / destructive tools must always prompt.
        for tool_name in (
            "send_gmail_email",
            "send_discord_message",
            "send_teams_message",
            "delete_notion_page",
            "delete_calendar_event",
        ):
            assert tool_name not in DEFAULT_AUTO_APPROVED_TOOLS, (
@@ -0,0 +1,370 @@
r"""Tests for ``apply_litellm_prompt_caching`` in
:mod:`app.agents.new_chat.prompt_caching`.

The helper replaces the legacy ``AnthropicPromptCachingMiddleware`` (which
never activated for our LiteLLM stack) with LiteLLM-native multi-provider
prompt caching. It mutates ``llm.model_kwargs`` so the kwargs flow to
``litellm.completion(...)``. The tests below pin its public contract:

1. Always sets BOTH ``index: 0`` and ``index: -1`` injection points so
   savings compound across multi-turn conversations on Anthropic-family
   providers. ``index: 0`` is used (rather than ``role: system``) because
   the deepagent stack accumulates multiple ``SystemMessage``\ s in
   ``state["messages"]`` and ``role: system`` would tag every one of
   them, blowing past Anthropic's 4-block ``cache_control`` cap.
2. Adds ``prompt_cache_key``/``prompt_cache_retention`` only for
   single-model OPENAI/DEEPSEEK/XAI configs (where OpenAI's automatic
   prompt-cache surface is available).
3. Treats ``ChatLiteLLMRouter`` (auto-mode) as universal-only — no
   OpenAI-only kwargs because the router fans out across providers.
4. Idempotent: user-supplied values in ``model_kwargs`` are preserved.
5. Defensive: LLMs without a writable ``model_kwargs`` are silently
   skipped rather than raising.
"""

from __future__ import annotations

from typing import Any

import pytest

from app.agents.new_chat.llm_config import AgentConfig
from app.agents.new_chat.prompt_caching import apply_litellm_prompt_caching

pytestmark = pytest.mark.unit
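
# Condensing that five-point contract into code, a sketch under the stated
# assumptions rather than the shipped implementation (``is_auto_mode`` in
# particular is a hypothetical attribute standing in for however the real
# config flags auto-mode):
#
#     _UNIVERSAL_POINTS = [
#         {"location": "message", "index": 0},
#         {"location": "message", "index": -1},
#     ]
#     _OPENAI_FAMILY = {"OPENAI", "DEEPSEEK", "XAI"}
#
#     def apply_litellm_prompt_caching_sketch(llm, agent_config=None, thread_id=None):
#         kwargs = getattr(llm, "model_kwargs", None)
#         try:
#             if kwargs is None:
#                 kwargs = {}
#                 llm.model_kwargs = kwargs  # initialise present-but-None dicts
#         except AttributeError:
#             return  # no writable model_kwargs: skip silently
#         kwargs.setdefault("cache_control_injection_points", list(_UNIVERSAL_POINTS))
#         if (
#             type(llm).__name__ != "ChatLiteLLMRouter"  # router: universal only
#             and agent_config is not None
#             and not getattr(agent_config, "is_auto_mode", False)  # hypothetical
#             and not getattr(agent_config, "custom_provider", None)
#             and agent_config.provider in _OPENAI_FAMILY
#         ):
#             if thread_id is not None:
#                 kwargs.setdefault("prompt_cache_key", f"surfsense-thread-{thread_id}")
#             kwargs.setdefault("prompt_cache_retention", "24h")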

# ---------------------------------------------------------------------------
# Test doubles
# ---------------------------------------------------------------------------


class _FakeLLM:
    """Stand-in for ``ChatLiteLLM``/``SanitizedChatLiteLLM``.

    The helper only inspects ``getattr(llm, "model_kwargs", None)``,
    ``getattr(llm, "model", None)``, and ``type(llm).__name__``. A simple
    object suffices — we don't need to spin up real LangChain/LiteLLM
    machinery for unit tests of the helper's logic.
    """

    def __init__(
        self,
        model: str = "openai/gpt-4o",
        model_kwargs: dict[str, Any] | None = None,
    ) -> None:
        self.model = model
        self.model_kwargs: dict[str, Any] = dict(model_kwargs) if model_kwargs else {}


class ChatLiteLLMRouter:
    """Class-name-only impostor of the real router.

    The helper's router gate is ``type(llm).__name__ == "ChatLiteLLMRouter"``
    (a deliberate stringly-typed check to avoid an import cycle with
    ``app.services.llm_router_service``). Reusing the same class name here
    triggers the same code path without instantiating a real ``Router``.
    """

    def __init__(self) -> None:
        self.model = "auto"
        self.model_kwargs: dict[str, Any] = {}


def _make_cfg(**overrides: Any) -> AgentConfig:
    """Build an ``AgentConfig`` with sensible defaults for the helper test."""
    defaults: dict[str, Any] = {
        "provider": "OPENAI",
        "model_name": "gpt-4o",
        "api_key": "k",
    }
    return AgentConfig(**{**defaults, **overrides})


# ---------------------------------------------------------------------------
# (a) Universal injection points
# ---------------------------------------------------------------------------


def test_sets_both_cache_control_injection_points_with_no_config() -> None:
    """Bare call (no agent_config, no thread_id) still sets the two
    universal breakpoints — these cost nothing on providers that don't
    consume them and unlock caching on every supported provider."""
    llm = _FakeLLM()

    apply_litellm_prompt_caching(llm)

    points = llm.model_kwargs["cache_control_injection_points"]
    assert {"location": "message", "index": 0} in points
    assert {"location": "message", "index": -1} in points
    assert len(points) == 2


def test_does_not_inject_role_system_breakpoint() -> None:
    """Regression: deliberately AVOID ``role: system`` so we don't tag
    every SystemMessage the deepagent ``before_agent`` injectors push
    into ``state["messages"]`` (priority, tree, memory, file-intent,
    anonymous-doc). Tagging all of them overflows Anthropic's 4-block
    ``cache_control`` cap and surfaces as
    ``OpenrouterException: A maximum of 4 blocks with cache_control may
    be provided. Found N`` 400s.
    """
    llm = _FakeLLM()
    apply_litellm_prompt_caching(llm)
    points = llm.model_kwargs["cache_control_injection_points"]
    assert all(p.get("role") != "system" for p in points), (
        f"Expected no role=system breakpoint, got: {points}"
    )


def test_injection_points_set_for_anthropic_config() -> None:
    """Anthropic-family configs need the marker — verify it lands."""
    cfg = _make_cfg(provider="ANTHROPIC", model_name="claude-3-5-sonnet")
    llm = _FakeLLM(model="anthropic/claude-3-5-sonnet")

    apply_litellm_prompt_caching(llm, agent_config=cfg)

    assert "cache_control_injection_points" in llm.model_kwargs


# ---------------------------------------------------------------------------
# (b) Idempotency / user override wins
# ---------------------------------------------------------------------------


def test_does_not_overwrite_user_supplied_cache_control_injection_points() -> None:
    """Users who set their own injection points (e.g. with ``ttl: "1h"``
    via ``litellm_params``) keep them — the helper merges, never
    clobbers."""
    user_points = [
        {"location": "message", "role": "system", "ttl": "1h"},
    ]
    llm = _FakeLLM(
        model_kwargs={"cache_control_injection_points": user_points},
    )

    apply_litellm_prompt_caching(llm)

    assert llm.model_kwargs["cache_control_injection_points"] is user_points


def test_idempotent_when_called_multiple_times() -> None:
    """Build-time + thread-time double-call must be a no-op the second time."""
    cfg = _make_cfg(provider="OPENAI")
    llm = _FakeLLM()

    apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=1)
    snapshot = {
        "cache_control_injection_points": list(
            llm.model_kwargs["cache_control_injection_points"]
        ),
        "prompt_cache_key": llm.model_kwargs["prompt_cache_key"],
        "prompt_cache_retention": llm.model_kwargs["prompt_cache_retention"],
    }
    apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=1)

    assert (
        llm.model_kwargs["cache_control_injection_points"]
        == snapshot["cache_control_injection_points"]
    )
    assert llm.model_kwargs["prompt_cache_key"] == snapshot["prompt_cache_key"]
    assert (
        llm.model_kwargs["prompt_cache_retention"] == snapshot["prompt_cache_retention"]
    )


def test_does_not_overwrite_user_supplied_prompt_cache_key() -> None:
    """A pre-set ``prompt_cache_key`` (e.g. tenant-aware override via
    ``litellm_params``) wins over our default per-thread key."""
    cfg = _make_cfg(provider="OPENAI")
    llm = _FakeLLM(model_kwargs={"prompt_cache_key": "tenant-abc"})

    apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=42)

    assert llm.model_kwargs["prompt_cache_key"] == "tenant-abc"


# ---------------------------------------------------------------------------
# (c) OpenAI-family extras (OPENAI / DEEPSEEK / XAI)
# ---------------------------------------------------------------------------


@pytest.mark.parametrize("provider", ["OPENAI", "DEEPSEEK", "XAI"])
def test_sets_openai_family_extras(provider: str) -> None:
    """OpenAI-style providers gain ``prompt_cache_key`` (raises hit rate
    via routing affinity) and ``prompt_cache_retention="24h"`` (extends
    cache TTL beyond the default 5-10 min)."""
    cfg = _make_cfg(provider=provider)
    llm = _FakeLLM()

    apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=42)

    assert llm.model_kwargs["prompt_cache_key"] == "surfsense-thread-42"
    assert llm.model_kwargs["prompt_cache_retention"] == "24h"


def test_skips_prompt_cache_key_when_no_thread_id() -> None:
    """Without a thread id we can't construct a per-thread key. Retention
    is still useful so we set it (it's free)."""
    cfg = _make_cfg(provider="OPENAI")
    llm = _FakeLLM()

    apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=None)

    assert "prompt_cache_key" not in llm.model_kwargs
    assert llm.model_kwargs["prompt_cache_retention"] == "24h"


@pytest.mark.parametrize(
    "provider",
    ["ANTHROPIC", "BEDROCK", "VERTEX_AI", "GOOGLE_AI_STUDIO", "GROQ", "MOONSHOT"],
)
def test_no_openai_extras_for_other_providers(provider: str) -> None:
    """Non-OpenAI-family providers don't expose ``prompt_cache_key`` —
    skip it. ``cache_control_injection_points`` is still set (universal)."""
    cfg = _make_cfg(provider=provider)
    llm = _FakeLLM()

    apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=42)

    assert "prompt_cache_key" not in llm.model_kwargs
    assert "prompt_cache_retention" not in llm.model_kwargs
    assert "cache_control_injection_points" in llm.model_kwargs


def test_no_openai_extras_in_auto_mode() -> None:
    """Auto-mode fans out across mixed providers — we can't statically
    target OpenAI-only kwargs."""
    cfg = AgentConfig.from_auto_mode()
    llm = _FakeLLM()

    apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=42)

    assert "prompt_cache_key" not in llm.model_kwargs
    assert "prompt_cache_retention" not in llm.model_kwargs
    assert "cache_control_injection_points" in llm.model_kwargs


def test_no_openai_extras_for_custom_provider() -> None:
    """Custom providers route through arbitrary user-supplied prefixes —
    we don't try to infer OpenAI-family compatibility."""
    cfg = _make_cfg(provider="OPENAI", custom_provider="my_proxy")
    llm = _FakeLLM()

    apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=42)

    assert "prompt_cache_key" not in llm.model_kwargs
    assert "prompt_cache_retention" not in llm.model_kwargs


# ---------------------------------------------------------------------------
# (d) ChatLiteLLMRouter — universal injection points only
# ---------------------------------------------------------------------------


def test_router_llm_gets_only_universal_injection_points() -> None:
    """Even with an OpenAI-flavoured config, a ``ChatLiteLLMRouter`` must
    receive only the universal injection points — its requests dispatch
    across provider deployments and OpenAI-only kwargs would be wasted
    (or stripped by ``drop_params``) on non-OpenAI legs."""
    router = ChatLiteLLMRouter()
    cfg = _make_cfg(provider="OPENAI")

    apply_litellm_prompt_caching(router, agent_config=cfg, thread_id=42)

    assert "cache_control_injection_points" in router.model_kwargs
    assert "prompt_cache_key" not in router.model_kwargs
    assert "prompt_cache_retention" not in router.model_kwargs


# ---------------------------------------------------------------------------
# (e) Defensive paths
# ---------------------------------------------------------------------------


def test_handles_llm_with_no_writable_model_kwargs() -> None:
    """Some LLM implementations (e.g. fakes / minimal subclasses) don't
    expose a writable ``model_kwargs``. The helper must skip silently —
    raising would crash the entire LLM build path on a non-critical
    optimisation."""

    class _ImmutableLLM:
        # ``__slots__`` blocks attribute creation, so ``setattr`` raises.
        __slots__ = ("model",)

        def __init__(self) -> None:
            self.model = "openai/gpt-4o"

    llm = _ImmutableLLM()

    apply_litellm_prompt_caching(llm)


def test_initialises_missing_model_kwargs_dict() -> None:
    """When ``model_kwargs`` is present-but-None (Pydantic v2 default
    pattern when no factory is set), the helper initialises it to an
    empty dict before mutating."""

    class _LazyLLM:
        def __init__(self) -> None:
            self.model = "openai/gpt-4o"
            self.model_kwargs: dict[str, Any] | None = None

    llm = _LazyLLM()

    apply_litellm_prompt_caching(llm)

    assert isinstance(llm.model_kwargs, dict)
    assert "cache_control_injection_points" in llm.model_kwargs


def test_falls_back_to_llm_model_prefix_when_no_agent_config() -> None:
    """Direct caller path (e.g. ``create_chat_litellm_from_config`` for
    YAML configs without a structured ``AgentConfig``): without
    ``agent_config`` the helper sets only the universal injection points
    — no OpenAI-family extras even if the prefix says ``openai/``.
    Conservative: we'd rather miss the speedup than silently misroute."""
    llm = _FakeLLM(model="openai/gpt-4o")

    apply_litellm_prompt_caching(llm, agent_config=None, thread_id=99)

    assert "cache_control_injection_points" in llm.model_kwargs
    assert "prompt_cache_key" not in llm.model_kwargs
    assert "prompt_cache_retention" not in llm.model_kwargs


# ---------------------------------------------------------------------------
# (f) drop_params safety net (regression guard for #19346)
# ---------------------------------------------------------------------------


def test_litellm_drop_params_is_globally_enabled() -> None:
    """``litellm.drop_params=True`` is set globally in
    :mod:`app.services.llm_service` so any ``prompt_cache_key`` /
    ``prompt_cache_retention`` we set on an OpenAI-family config is
    auto-stripped if the request later routes to a non-supporting
    provider (e.g. via auto-mode router fallback). This test pins that
    invariant — losing it would mean Bedrock/Vertex 400s on ``prompt_cache_key``.
    """
    import litellm

    import app.services.llm_service  # noqa: F401 (side-effect: sets globals)

    assert litellm.drop_params is True


# ---------------------------------------------------------------------------
# Regression note: LiteLLM #15696 (multi-content-block last message)
# ---------------------------------------------------------------------------
#
# Before LiteLLM 1.81 a list-form last message ``[block_a, block_b]``
# would get ``cache_control`` applied to *every* content block instead
# of only the last one — wasting cache breakpoints and triggering 400s
# on Anthropic when it exceeded the 4-breakpoint limit. Fixed in
# https://github.com/BerriAI/litellm/pull/15699.
#
# We pin ``litellm>=1.83.7`` in ``pyproject.toml`` (well past the fix).
# An end-to-end behavioural test would need to run ``litellm.completion``
# through the Anthropic transformer, which is integration territory and
# better covered by LiteLLM's own test suite. The unit guard here is the
# version pin plus the build-time ``model_kwargs`` shape we verify above.
@@ -0,0 +1,117 @@
"""Tests for ``_resolve_prompt_model_name`` in :mod:`app.agents.new_chat.chat_deepagent`.

The helper picks the model id fed to ``detect_provider_variant`` so the
right ``<provider_hints>`` block lands in the system prompt. The tests
below pin its preference order:

1. ``agent_config.litellm_params["base_model"]`` (Azure-correct).
2. ``agent_config.model_name``.
3. ``getattr(llm, "model", None)``.

Without (1) an Azure deployment named e.g. ``"prod-chat-001"`` would
silently miss every provider regex.
"""

from __future__ import annotations

import pytest

from app.agents.new_chat.chat_deepagent import _resolve_prompt_model_name
from app.agents.new_chat.llm_config import AgentConfig

pytestmark = pytest.mark.unit
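
# That preference order is small enough to restate as a sketch, assuming
# the helper matches the behaviour pinned below (illustrative, not the
# shipped code):
#
#     def _resolve_prompt_model_name_sketch(agent_config, llm):
#         if agent_config is not None:
#             params = agent_config.litellm_params or {}
#             base = params.get("base_model")
#             if isinstance(base, str) and base.strip():
#                 return base  # Azure-correct: family, not deployment slug
#             if agent_config.model_name:
#                 return agent_config.model_name
#         return getattr(llm, "model", None)  # defensive fallback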

def _make_cfg(**overrides) -> AgentConfig:
    """Build an ``AgentConfig`` with sensible defaults for the helper test."""
    defaults = {
        "provider": "OPENAI",
        "model_name": "x",
        "api_key": "k",
    }
    return AgentConfig(**{**defaults, **overrides})


class _FakeLLM:
    """Stand-in for a ``ChatLiteLLM`` / ``ChatLiteLLMRouter`` instance.

    The resolver only reads the ``.model`` attribute via ``getattr``,
    matching the established idiom in ``knowledge_search.py`` /
    ``stream_new_chat.py`` / ``document_summarizer.py``.
    """

    def __init__(self, model: str | None) -> None:
        self.model = model


def test_prefers_litellm_params_base_model_over_deployment_name() -> None:
    """Azure deployment slug must NOT shadow the underlying model family.

    This is the failure mode the helper exists to prevent: a deployment
    named ``"azure/prod-chat-001"`` would not match any provider regex
    on its own, but the family ``"gpt-4o"`` lives in
    ``litellm_params["base_model"]`` and routes to ``openai_classic``.
    """
    cfg = _make_cfg(
        model_name="azure/prod-chat-001",
        litellm_params={"base_model": "gpt-4o"},
    )
    assert _resolve_prompt_model_name(cfg, _FakeLLM("azure/prod-chat-001")) == "gpt-4o"


def test_falls_back_to_model_name_when_litellm_params_is_none() -> None:
    cfg = _make_cfg(
        model_name="anthropic/claude-3-5-sonnet",
        litellm_params=None,
    )
    got = _resolve_prompt_model_name(cfg, _FakeLLM("anthropic/claude-3-5-sonnet"))
    assert got == "anthropic/claude-3-5-sonnet"


def test_handles_litellm_params_without_base_model_key() -> None:
    cfg = _make_cfg(
        model_name="openai/gpt-4o",
        litellm_params={"temperature": 0.5},
    )
    assert _resolve_prompt_model_name(cfg, _FakeLLM("openai/gpt-4o")) == "openai/gpt-4o"


def test_ignores_blank_base_model() -> None:
    """Whitespace-only ``base_model`` must not shadow ``model_name``."""
    cfg = _make_cfg(
        model_name="openai/gpt-4o",
        litellm_params={"base_model": " "},
    )
    assert _resolve_prompt_model_name(cfg, _FakeLLM("openai/gpt-4o")) == "openai/gpt-4o"


def test_ignores_non_string_base_model() -> None:
    """Defensive: a non-string ``base_model`` should not crash the resolver."""
    cfg = _make_cfg(
        model_name="openai/gpt-4o",
        litellm_params={"base_model": 42},
    )
    assert _resolve_prompt_model_name(cfg, _FakeLLM("openai/gpt-4o")) == "openai/gpt-4o"


def test_falls_back_to_llm_model_when_no_agent_config() -> None:
    """No ``agent_config`` -> use ``llm.model`` directly. Defensive path
    for direct callers; production callers always supply a config."""
    assert (
        _resolve_prompt_model_name(None, _FakeLLM("openai/gpt-4o-mini"))
        == "openai/gpt-4o-mini"
    )


def test_returns_none_when_nothing_available() -> None:
    """``compose_system_prompt`` treats ``None`` as the ``"default"``
    variant and emits no provider block."""
    assert _resolve_prompt_model_name(None, _FakeLLM(None)) is None


def test_auto_mode_resolves_to_auto_string() -> None:
    """Auto mode -> ``"auto"``. ``detect_provider_variant("auto")``
    returns ``"default"``, which is correct: the child model isn't
    known until the LiteLLM Router dispatches."""
    cfg = AgentConfig.from_auto_mode()
    assert _resolve_prompt_model_name(cfg, _FakeLLM("auto")) == "auto"
@@ -475,3 +475,190 @@ class TestKBSearchPlanSchema:
            )
        )
        assert plan.is_recency_query is False


# ── mentioned_document_ids cross-turn drain ────────────────────────────


class TestKnowledgePriorityMentionDrain:
    """Regression tests for the cross-turn ``mentioned_document_ids`` drain.

    The compiled-agent cache reuses a single :class:`KnowledgePriorityMiddleware`
    instance across turns of the same thread. ``mentioned_document_ids``
    can therefore enter the middleware via two paths:

    1. The constructor closure (``__init__(mentioned_document_ids=...)``) —
       seeded by the cache-miss build on turn 1.
    2. ``runtime.context.mentioned_document_ids`` — supplied freshly per
       turn by the streaming task.

    Without the drain fix, an empty ``runtime.context.mentioned_document_ids``
    on turn 2 would fall through to the closure (because ``[]`` is falsy in
    Python) and replay turn 1's mentions. This class pins down the
    correct behaviour: the runtime path is authoritative even when empty,
    and the closure is drained the first time the runtime path fires so
    no later turn can ever resurrect stale state.
    """

    @staticmethod
    def _make_runtime(mention_ids: list[int]):
        """Minimal runtime stub exposing only ``runtime.context.mentioned_document_ids``."""
        from types import SimpleNamespace

        return SimpleNamespace(
            context=SimpleNamespace(mentioned_document_ids=mention_ids),
        )

    @staticmethod
    def _planner_llm() -> "FakeLLM":
        # Planner returns a stable, non-recency plan so we always land in
        # the hybrid-search branch (where ``fetch_mentioned_documents`` is
        # invoked alongside the main search).
        return FakeLLM(
            json.dumps(
                {
                    "optimized_query": "follow up question",
                    "start_date": None,
                    "end_date": None,
                    "is_recency_query": False,
                }
            )
        )

    async def test_runtime_context_overrides_closure_and_drains_it(self, monkeypatch):
        """Turn 1 with mentions in BOTH closure and runtime context: the
        runtime path wins AND the closure is drained so a future turn
        cannot replay it.
        """
        fetched_ids: list[list[int]] = []

        async def fake_fetch_mentioned_documents(*, document_ids, search_space_id):
            fetched_ids.append(list(document_ids))
            return []

        async def fake_search_knowledge_base(**_kwargs):
            return []

        monkeypatch.setattr(
            "app.agents.new_chat.middleware.knowledge_search.fetch_mentioned_documents",
            fake_fetch_mentioned_documents,
        )
        monkeypatch.setattr(
            "app.agents.new_chat.middleware.knowledge_search.search_knowledge_base",
            fake_search_knowledge_base,
        )

        middleware = KnowledgeBaseSearchMiddleware(
            llm=self._planner_llm(),
            search_space_id=42,
            mentioned_document_ids=[1, 2, 3],
        )

        await middleware.abefore_agent(
            {"messages": [HumanMessage(content="what is in those docs?")]},
            runtime=self._make_runtime([1, 2, 3]),
        )

        assert fetched_ids == [[1, 2, 3]], (
            "runtime.context mentions must be the source of truth on turn 1"
        )
        assert middleware.mentioned_document_ids == [], (
            "closure must be drained the first time the runtime path fires "
            "so no later turn can replay stale mentions"
        )

    async def test_empty_runtime_context_does_not_replay_closure_mentions(
        self, monkeypatch
    ):
        """Regression: turn 2 with NO mentions must not surface turn 1's
        mentions from the constructor closure.

        Before the fix, ``if ctx_mentions:`` treated an empty list as
        absent and fell through to ``elif self.mentioned_document_ids:``,
        replaying turn 1's mentions. This test pins down the corrected
        behaviour.
        """
        fetched_ids: list[list[int]] = []

        async def fake_fetch_mentioned_documents(*, document_ids, search_space_id):
            fetched_ids.append(list(document_ids))
            return []

        async def fake_search_knowledge_base(**_kwargs):
            return []

        monkeypatch.setattr(
            "app.agents.new_chat.middleware.knowledge_search.fetch_mentioned_documents",
            fake_fetch_mentioned_documents,
        )
        monkeypatch.setattr(
            "app.agents.new_chat.middleware.knowledge_search.search_knowledge_base",
            fake_search_knowledge_base,
        )

        # Simulate a cached middleware instance whose closure was seeded
        # by a previous turn's cache-miss build (mentions=[1,2,3]).
        middleware = KnowledgeBaseSearchMiddleware(
            llm=self._planner_llm(),
            search_space_id=42,
            mentioned_document_ids=[1, 2, 3],
        )

        # Turn 2: streaming task supplies an EMPTY mention list (no
        # mentions on this follow-up turn).
        await middleware.abefore_agent(
            {"messages": [HumanMessage(content="what about the next steps?")]},
            runtime=self._make_runtime([]),
        )

        assert fetched_ids == [], (
            "fetch_mentioned_documents must NOT be called when the runtime "
            "context says there are no mentions for this turn"
        )

    async def test_legacy_path_fires_only_when_runtime_context_absent(
        self, monkeypatch
    ):
        """Backward-compat: if a caller doesn't supply runtime.context (old
        non-streaming code path), the closure-injected mentions are still
        honoured exactly once and then drained.
        """
        fetched_ids: list[list[int]] = []

        async def fake_fetch_mentioned_documents(*, document_ids, search_space_id):
            fetched_ids.append(list(document_ids))
            return []

        async def fake_search_knowledge_base(**_kwargs):
            return []

        monkeypatch.setattr(
            "app.agents.new_chat.middleware.knowledge_search.fetch_mentioned_documents",
            fake_fetch_mentioned_documents,
        )
        monkeypatch.setattr(
            "app.agents.new_chat.middleware.knowledge_search.search_knowledge_base",
            fake_search_knowledge_base,
        )

        middleware = KnowledgeBaseSearchMiddleware(
            llm=self._planner_llm(),
            search_space_id=42,
            mentioned_document_ids=[7, 8],
        )

        # First call: no runtime → legacy path uses the closure.
        await middleware.abefore_agent(
            {"messages": [HumanMessage(content="initial question")]},
            runtime=None,
        )
        # Second call: still no runtime — closure already drained, so no replay.
        await middleware.abefore_agent(
            {"messages": [HumanMessage(content="follow up")]},
            runtime=None,
        )

        assert fetched_ids == [[7, 8]], (
|
||||
"legacy path must honour the closure exactly once and then drain it"
|
||||
)
|
||||
assert middleware.mentioned_document_ids == []
|
||||
|
|
|
|||
|
|
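
# Editor's sketch (hypothetical, not part of the diff): the drain decision
# the class above pins down, assuming the middleware keeps the constructor
# mentions on ``self.mentioned_document_ids``. The real logic lives in
# app/agents/new_chat/middleware/knowledge_search.py; names here are
# illustrative only.
def _resolve_turn_mentions(middleware, runtime) -> list[int]:
    """Return this turn's mention ids. The runtime context is authoritative."""
    ctx = getattr(runtime, "context", None) if runtime is not None else None
    if ctx is not None:
        # Runtime path: honoured even when the list is empty, and it drains
        # the closure so no later turn can replay turn 1's mentions.
        middleware.mentioned_document_ids = []
        return list(ctx.mentioned_document_ids)
    # Legacy path (no runtime): use the closure exactly once, then drain it.
    mentions = list(middleware.mentioned_document_ids)
    middleware.mentioned_document_ids = []
    return mentions
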
@@ -0,0 +1,110 @@
"""Unit tests for ``supports_image_input`` derivation on BYOK chat config
|
||||
endpoints (``GET /new-llm-configs`` list, ``GET /new-llm-configs/{id}``).
|
||||
|
||||
There is no DB column for ``supports_image_input`` on
|
||||
``NewLLMConfig`` — the value is resolved at the API boundary by
|
||||
``derive_supports_image_input`` so the new-chat selector / streaming
|
||||
task can read the same field shape regardless of source (BYOK vs YAML
|
||||
vs OpenRouter dynamic). Default-allow on unknown so we don't lock the
|
||||
user out of their own model choice.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from app.db import LiteLLMProvider
|
||||
from app.routes import new_llm_config_routes
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def _byok_row(
|
||||
*,
|
||||
id_: int,
|
||||
model_name: str,
|
||||
base_model: str | None = None,
|
||||
provider: LiteLLMProvider = LiteLLMProvider.OPENAI,
|
||||
custom_provider: str | None = None,
|
||||
) -> object:
|
||||
"""Mimic the SQLAlchemy row's attribute surface; ``model_validate``
|
||||
walks ``from_attributes=True`` so a ``SimpleNamespace`` is enough.
|
||||
|
||||
``provider`` is a real ``LiteLLMProvider`` enum value so Pydantic's
|
||||
enum validator accepts it — same as the ORM row would carry."""
|
||||
return SimpleNamespace(
|
||||
id=id_,
|
||||
name=f"BYOK-{id_}",
|
||||
description=None,
|
||||
provider=provider,
|
||||
custom_provider=custom_provider,
|
||||
model_name=model_name,
|
||||
api_key="sk-byok",
|
||||
api_base=None,
|
||||
litellm_params={"base_model": base_model} if base_model else None,
|
||||
system_instructions="",
|
||||
use_default_system_instructions=True,
|
||||
citations_enabled=True,
|
||||
created_at=datetime.now(tz=UTC),
|
||||
search_space_id=42,
|
||||
user_id=uuid4(),
|
||||
)
|
||||
|
||||
|
||||
def test_serialize_byok_known_vision_model_resolves_true():
|
||||
"""The catalog resolver consults LiteLLM's map for ``gpt-4o`` ->
|
||||
True. The serialized row carries that value through to the
|
||||
``NewLLMConfigRead`` schema."""
|
||||
row = _byok_row(id_=1, model_name="gpt-4o")
|
||||
serialized = new_llm_config_routes._serialize_byok_config(row)
|
||||
|
||||
assert serialized.supports_image_input is True
|
||||
assert serialized.id == 1
|
||||
assert serialized.model_name == "gpt-4o"
|
||||
|
||||
|
||||
def test_serialize_byok_unknown_model_default_allows():
|
||||
"""Unknown / unmapped: default-allow. The streaming-task safety net
|
||||
is the actual block, and it requires LiteLLM to *explicitly* say
|
||||
text-only — so a brand new BYOK model should not be pre-judged."""
|
||||
row = _byok_row(
|
||||
id_=2,
|
||||
model_name="brand-new-model-x9-unmapped",
|
||||
provider=LiteLLMProvider.CUSTOM,
|
||||
custom_provider="brand_new_proxy",
|
||||
)
|
||||
serialized = new_llm_config_routes._serialize_byok_config(row)
|
||||
|
||||
assert serialized.supports_image_input is True
|
||||
|
||||
|
||||
def test_serialize_byok_uses_base_model_when_present():
|
||||
"""Azure-style: ``model_name`` is the deployment id, ``base_model``
|
||||
inside ``litellm_params`` is the canonical sku LiteLLM knows. The
|
||||
helper must consult ``base_model`` first or unrecognised deployment
|
||||
ids would shadow the real capability."""
|
||||
row = _byok_row(
|
||||
id_=3,
|
||||
model_name="my-azure-deployment-id-no-litellm-knows-this",
|
||||
base_model="gpt-4o",
|
||||
provider=LiteLLMProvider.AZURE_OPENAI,
|
||||
)
|
||||
serialized = new_llm_config_routes._serialize_byok_config(row)
|
||||
|
||||
assert serialized.supports_image_input is True
|
||||
|
||||
|
||||
def test_serialize_byok_returns_pydantic_read_model():
|
||||
"""The route now returns ``NewLLMConfigRead`` (not the raw ORM) so
|
||||
the schema additions are guaranteed to be present in the API
|
||||
surface. This guards against a future regression where someone
|
||||
deletes the augmentation step and falls back to ORM passthrough."""
|
||||
from app.schemas import NewLLMConfigRead
|
||||
|
||||
row = _byok_row(id_=4, model_name="gpt-4o")
|
||||
serialized = new_llm_config_routes._serialize_byok_config(row)
|
||||
assert isinstance(serialized, NewLLMConfigRead)
|
||||
|
|
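
# Editor's sketch (assumptions labelled): the resolution order the tests
# above exercise. ``litellm.model_cost`` is LiteLLM's public model map; the
# helper name and the default-allow policy come from the docstrings above,
# while the body is illustrative rather than the real implementation.
import litellm


def derive_supports_image_input_sketch(
    model_name: str, base_model: str | None = None
) -> bool:
    # Azure-style rows: the canonical sku in ``base_model`` wins over the
    # opaque deployment id in ``model_name``.
    candidate = base_model or model_name
    info = litellm.model_cost.get(candidate)
    if info is None:
        # Unknown / unmapped model: default-allow, never pre-judge.
        return True
    # Only an explicit ``supports_vision: False`` counts as definitive; a
    # mapped entry without the flag keeps the default-allow policy.
    return bool(info.get("supports_vision", True))
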
@@ -0,0 +1,184 @@
"""Unit tests for ``is_premium`` derivation on the global image-gen and
|
||||
vision-LLM list endpoints.
|
||||
|
||||
Chat globals (``GET /global-llm-configs``) already emit
|
||||
``is_premium = (billing_tier == "premium")``. Image and vision did not,
|
||||
which made the new-chat ``model-selector`` render the Free/Premium badge
|
||||
on the Chat tab but skip it on the Image and Vision tabs (the selector
|
||||
keys its badge logic off ``is_premium``). These tests pin parity:
|
||||
|
||||
* YAML free entry → ``is_premium=False``
|
||||
* YAML premium entry → ``is_premium=True``
|
||||
* OpenRouter dynamic premium entry → ``is_premium=True``
|
||||
* Auto stub (always emitted when at least one config is present)
|
||||
→ ``is_premium=False``
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
_IMAGE_FIXTURE: list[dict] = [
|
||||
{
|
||||
"id": -1,
|
||||
"name": "DALL-E 3",
|
||||
"provider": "OPENAI",
|
||||
"model_name": "dall-e-3",
|
||||
"api_key": "sk-test",
|
||||
"billing_tier": "free",
|
||||
},
|
||||
{
|
||||
"id": -2,
|
||||
"name": "GPT-Image 1 (premium)",
|
||||
"provider": "OPENAI",
|
||||
"model_name": "gpt-image-1",
|
||||
"api_key": "sk-test",
|
||||
"billing_tier": "premium",
|
||||
},
|
||||
{
|
||||
"id": -20_001,
|
||||
"name": "google/gemini-2.5-flash-image (OpenRouter)",
|
||||
"provider": "OPENROUTER",
|
||||
"model_name": "google/gemini-2.5-flash-image",
|
||||
"api_key": "sk-or-test",
|
||||
"api_base": "https://openrouter.ai/api/v1",
|
||||
"billing_tier": "premium",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
_VISION_FIXTURE: list[dict] = [
|
||||
{
|
||||
"id": -1,
|
||||
"name": "GPT-4o Vision",
|
||||
"provider": "OPENAI",
|
||||
"model_name": "gpt-4o",
|
||||
"api_key": "sk-test",
|
||||
"billing_tier": "free",
|
||||
},
|
||||
{
|
||||
"id": -2,
|
||||
"name": "Claude 3.5 Sonnet (premium)",
|
||||
"provider": "ANTHROPIC",
|
||||
"model_name": "claude-3-5-sonnet",
|
||||
"api_key": "sk-ant-test",
|
||||
"billing_tier": "premium",
|
||||
},
|
||||
{
|
||||
"id": -30_001,
|
||||
"name": "openai/gpt-4o (OpenRouter)",
|
||||
"provider": "OPENROUTER",
|
||||
"model_name": "openai/gpt-4o",
|
||||
"api_key": "sk-or-test",
|
||||
"api_base": "https://openrouter.ai/api/v1",
|
||||
"billing_tier": "premium",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Image generation
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_global_image_gen_configs_emit_is_premium(monkeypatch):
|
||||
"""Each emitted config must carry ``is_premium`` derived server-side
|
||||
from ``billing_tier``. The Auto stub is always free.
|
||||
"""
|
||||
from app.config import config
|
||||
from app.routes import image_generation_routes
|
||||
|
||||
monkeypatch.setattr(
|
||||
config, "GLOBAL_IMAGE_GEN_CONFIGS", _IMAGE_FIXTURE, raising=False
|
||||
)
|
||||
|
||||
payload = await image_generation_routes.get_global_image_gen_configs(user=None)
|
||||
|
||||
by_id = {c["id"]: c for c in payload}
|
||||
|
||||
# Auto stub is always emitted when at least one global config exists,
|
||||
# and it must always declare itself free (Auto-mode billing-tier
|
||||
# surfacing is a separate follow-up).
|
||||
assert 0 in by_id, "Auto stub should be emitted when at least one config exists"
|
||||
assert by_id[0]["is_premium"] is False
|
||||
assert by_id[0]["billing_tier"] == "free"
|
||||
|
||||
# YAML free entry — ``is_premium=False``
|
||||
assert by_id[-1]["is_premium"] is False
|
||||
assert by_id[-1]["billing_tier"] == "free"
|
||||
|
||||
# YAML premium entry — ``is_premium=True``
|
||||
assert by_id[-2]["is_premium"] is True
|
||||
assert by_id[-2]["billing_tier"] == "premium"
|
||||
|
||||
# OpenRouter dynamic premium entry — same field, same derivation
|
||||
assert by_id[-20_001]["is_premium"] is True
|
||||
assert by_id[-20_001]["billing_tier"] == "premium"
|
||||
|
||||
# Every emitted dict (including Auto) must have the field — never missing.
|
||||
for cfg in payload:
|
||||
assert "is_premium" in cfg, f"is_premium missing from {cfg.get('id')}"
|
||||
assert isinstance(cfg["is_premium"], bool)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_global_image_gen_configs_no_globals_no_auto_stub(monkeypatch):
|
||||
"""When there are no global configs at all, the endpoint emits an
|
||||
empty list (no Auto stub) — Auto mode would have nothing to route to.
|
||||
"""
|
||||
from app.config import config
|
||||
from app.routes import image_generation_routes
|
||||
|
||||
monkeypatch.setattr(config, "GLOBAL_IMAGE_GEN_CONFIGS", [], raising=False)
|
||||
payload = await image_generation_routes.get_global_image_gen_configs(user=None)
|
||||
assert payload == []
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Vision LLM
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_global_vision_llm_configs_emit_is_premium(monkeypatch):
|
||||
from app.config import config
|
||||
from app.routes import vision_llm_routes
|
||||
|
||||
monkeypatch.setattr(
|
||||
config, "GLOBAL_VISION_LLM_CONFIGS", _VISION_FIXTURE, raising=False
|
||||
)
|
||||
|
||||
payload = await vision_llm_routes.get_global_vision_llm_configs(user=None)
|
||||
|
||||
by_id = {c["id"]: c for c in payload}
|
||||
|
||||
assert 0 in by_id, "Auto stub should be emitted when at least one config exists"
|
||||
assert by_id[0]["is_premium"] is False
|
||||
assert by_id[0]["billing_tier"] == "free"
|
||||
|
||||
assert by_id[-1]["is_premium"] is False
|
||||
assert by_id[-1]["billing_tier"] == "free"
|
||||
|
||||
assert by_id[-2]["is_premium"] is True
|
||||
assert by_id[-2]["billing_tier"] == "premium"
|
||||
|
||||
assert by_id[-30_001]["is_premium"] is True
|
||||
assert by_id[-30_001]["billing_tier"] == "premium"
|
||||
|
||||
for cfg in payload:
|
||||
assert "is_premium" in cfg, f"is_premium missing from {cfg.get('id')}"
|
||||
assert isinstance(cfg["is_premium"], bool)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_global_vision_llm_configs_no_globals_no_auto_stub(monkeypatch):
|
||||
from app.config import config
|
||||
from app.routes import vision_llm_routes
|
||||
|
||||
monkeypatch.setattr(config, "GLOBAL_VISION_LLM_CONFIGS", [], raising=False)
|
||||
payload = await vision_llm_routes.get_global_vision_llm_configs(user=None)
|
||||
assert payload == []
|
||||
|
|
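
# Editor's sketch: the server-side derivation these tests pin is a single
# uniform rule across the chat, image and vision list endpoints (field
# names taken from the fixtures above; the function name is illustrative).
def with_is_premium(cfg: dict) -> dict:
    return {**cfg, "is_premium": cfg.get("billing_tier") == "premium"}


# The Auto stub (id=0) is emitted only when the globals list is non-empty
# and always declares itself free:
AUTO_STUB = {"id": 0, "billing_tier": "free", "is_premium": False}
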
@@ -0,0 +1,106 @@
"""Unit tests for ``supports_image_input`` derivation on the chat global
|
||||
config endpoint (``GET /global-new-llm-configs``).
|
||||
|
||||
Resolution order (matches ``new_llm_config_routes.get_global_new_llm_configs``):
|
||||
|
||||
1. Explicit ``supports_image_input`` on the cfg dict (set by the YAML
|
||||
loader for operator overrides, or by the OpenRouter integration from
|
||||
``architecture.input_modalities``) — wins.
|
||||
2. ``derive_supports_image_input`` helper — default-allow on unknown
|
||||
models, only False when LiteLLM / OR modalities are definitive.
|
||||
|
||||
The flag is purely informational at the API boundary. The streaming
|
||||
task safety net (``is_known_text_only_chat_model``) is the actual block,
|
||||
and it requires LiteLLM to *explicitly* mark the model as text-only.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
_FIXTURE: list[dict] = [
|
||||
{
|
||||
"id": -1,
|
||||
"name": "GPT-4o (explicit true)",
|
||||
"description": "vision-capable, explicit YAML override",
|
||||
"provider": "OPENAI",
|
||||
"model_name": "gpt-4o",
|
||||
"api_key": "sk-test",
|
||||
"billing_tier": "free",
|
||||
"supports_image_input": True,
|
||||
},
|
||||
{
|
||||
"id": -2,
|
||||
"name": "DeepSeek V3 (explicit false)",
|
||||
"description": "OpenRouter dynamic — modality-derived false",
|
||||
"provider": "OPENROUTER",
|
||||
"model_name": "deepseek/deepseek-v3.2-exp",
|
||||
"api_key": "sk-or-test",
|
||||
"api_base": "https://openrouter.ai/api/v1",
|
||||
"billing_tier": "free",
|
||||
"supports_image_input": False,
|
||||
},
|
||||
{
|
||||
"id": -10_010,
|
||||
"name": "Unannotated GPT-4o",
|
||||
"description": "no flag set — resolver should derive True via LiteLLM",
|
||||
"provider": "OPENAI",
|
||||
"model_name": "gpt-4o",
|
||||
"api_key": "sk-test",
|
||||
"billing_tier": "free",
|
||||
# supports_image_input intentionally absent
|
||||
},
|
||||
{
|
||||
"id": -10_011,
|
||||
"name": "Unannotated unknown model",
|
||||
"description": "unmapped — default-allow True",
|
||||
"provider": "CUSTOM",
|
||||
"custom_provider": "brand_new_proxy",
|
||||
"model_name": "brand-new-model-x9",
|
||||
"api_key": "sk-test",
|
||||
"billing_tier": "free",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_global_new_llm_configs_emit_supports_image_input(monkeypatch):
|
||||
"""Each emitted chat config carries ``supports_image_input`` as a
|
||||
bool. Explicit values win; unannotated entries are resolved via the
|
||||
helper (default-allow True)."""
|
||||
from app.config import config
|
||||
from app.routes import new_llm_config_routes
|
||||
|
||||
monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", _FIXTURE, raising=False)
|
||||
|
||||
payload = await new_llm_config_routes.get_global_new_llm_configs(user=None)
|
||||
by_id = {c["id"]: c for c in payload}
|
||||
|
||||
# Auto stub: optimistic True so the user can keep Auto selected with
|
||||
# vision-capable deployments somewhere in the pool.
|
||||
assert 0 in by_id, "Auto stub should be emitted when configs exist"
|
||||
assert by_id[0]["supports_image_input"] is True
|
||||
assert by_id[0]["is_auto_mode"] is True
|
||||
|
||||
# Explicit True is preserved.
|
||||
assert by_id[-1]["supports_image_input"] is True
|
||||
|
||||
# Explicit False is preserved (the exact failure mode the safety net
|
||||
# guards against — DeepSeek V3 over OpenRouter would 404 with "No
|
||||
# endpoints found that support image input").
|
||||
assert by_id[-2]["supports_image_input"] is False
|
||||
|
||||
# Unannotated GPT-4o: resolver consults LiteLLM, which says vision.
|
||||
assert by_id[-10_010]["supports_image_input"] is True
|
||||
|
||||
# Unknown / unmapped model: default-allow rather than pre-judge.
|
||||
assert by_id[-10_011]["supports_image_input"] is True
|
||||
|
||||
for cfg in payload:
|
||||
assert "supports_image_input" in cfg, (
|
||||
f"supports_image_input missing from {cfg.get('id')}"
|
||||
)
|
||||
assert isinstance(cfg["supports_image_input"], bool)
|
||||
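
# Editor's sketch of the two-step resolution order documented above. The
# derive helper is injected so the sketch stays self-contained; its default
# mirrors the default-allow policy.
def resolve_supports_image_input(cfg: dict, derive=lambda model, base: True) -> bool:
    explicit = cfg.get("supports_image_input")
    if isinstance(explicit, bool):
        # Step 1: YAML operator override / OpenRouter modalities win.
        return explicit
    # Step 2: capability helper — default-allow on unknown models.
    base = (cfg.get("litellm_params") or {}).get("base_model")
    return derive(cfg["model_name"], base)
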
138
surfsense_backend/tests/unit/routes/test_image_gen_quota.py
Normal file
@@ -0,0 +1,138 @@
"""Unit tests for the image-generation route's billing-resolution helper.
|
||||
|
||||
End-to-end "POST /image-generations returns 402" coverage requires the
|
||||
integration harness (real DB, real auth) and lives in
|
||||
``tests/integration/document_upload/`` alongside the other quota tests.
|
||||
This unit test focuses on the new ``_resolve_billing_for_image_gen``
|
||||
helper which:
|
||||
|
||||
* Returns ``free`` for Auto mode, even when premium configs exist
|
||||
(Auto-mode billing-tier surfacing is a follow-up).
|
||||
* Returns ``free`` for user-owned BYOK configs (positive IDs).
|
||||
* Returns the global config's ``billing_tier`` for negative IDs.
|
||||
* Honours the per-config ``quota_reserve_micros`` override when present.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_billing_for_auto_mode(monkeypatch):
|
||||
from app.routes import image_generation_routes
|
||||
from app.services.billable_calls import DEFAULT_IMAGE_RESERVE_MICROS
|
||||
|
||||
search_space = SimpleNamespace(image_generation_config_id=None)
|
||||
tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen(
|
||||
session=None, # Not consumed on this code path.
|
||||
config_id=0, # IMAGE_GEN_AUTO_MODE_ID
|
||||
search_space=search_space,
|
||||
)
|
||||
assert tier == "free"
|
||||
assert model == "auto"
|
||||
assert reserve == DEFAULT_IMAGE_RESERVE_MICROS
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_billing_for_premium_global_config(monkeypatch):
|
||||
from app.config import config
|
||||
from app.routes import image_generation_routes
|
||||
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_IMAGE_GEN_CONFIGS",
|
||||
[
|
||||
{
|
||||
"id": -1,
|
||||
"provider": "OPENAI",
|
||||
"model_name": "gpt-image-1",
|
||||
"billing_tier": "premium",
|
||||
"quota_reserve_micros": 75_000,
|
||||
},
|
||||
{
|
||||
"id": -2,
|
||||
"provider": "OPENROUTER",
|
||||
"model_name": "google/gemini-2.5-flash-image",
|
||||
"billing_tier": "free",
|
||||
},
|
||||
],
|
||||
raising=False,
|
||||
)
|
||||
|
||||
search_space = SimpleNamespace(image_generation_config_id=None)
|
||||
|
||||
# Premium with override.
|
||||
tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen(
|
||||
session=None, config_id=-1, search_space=search_space
|
||||
)
|
||||
assert tier == "premium"
|
||||
assert model == "openai/gpt-image-1"
|
||||
assert reserve == 75_000
|
||||
|
||||
# Free, no override → falls back to default.
|
||||
from app.services.billable_calls import DEFAULT_IMAGE_RESERVE_MICROS
|
||||
|
||||
tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen(
|
||||
session=None, config_id=-2, search_space=search_space
|
||||
)
|
||||
assert tier == "free"
|
||||
# Provider-prefixed model string for OpenRouter.
|
||||
assert "google/gemini-2.5-flash-image" in model
|
||||
assert reserve == DEFAULT_IMAGE_RESERVE_MICROS
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_billing_for_user_owned_byok_is_free():
|
||||
"""User-owned BYOK configs (positive IDs) cost the user nothing on
|
||||
our side — they pay the provider directly. Always free.
|
||||
"""
|
||||
from app.routes import image_generation_routes
|
||||
from app.services.billable_calls import DEFAULT_IMAGE_RESERVE_MICROS
|
||||
|
||||
search_space = SimpleNamespace(image_generation_config_id=None)
|
||||
tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen(
|
||||
session=None, config_id=42, search_space=search_space
|
||||
)
|
||||
assert tier == "free"
|
||||
assert model == "user_byok"
|
||||
assert reserve == DEFAULT_IMAGE_RESERVE_MICROS
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_billing_falls_back_to_search_space_default(monkeypatch):
|
||||
"""When the request omits ``image_generation_config_id``, the helper
|
||||
must consult the search space's default — so a search space pinned
|
||||
to a premium global config still gates new requests by quota.
|
||||
"""
|
||||
from app.config import config
|
||||
from app.routes import image_generation_routes
|
||||
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_IMAGE_GEN_CONFIGS",
|
||||
[
|
||||
{
|
||||
"id": -7,
|
||||
"provider": "OPENAI",
|
||||
"model_name": "gpt-image-1",
|
||||
"billing_tier": "premium",
|
||||
}
|
||||
],
|
||||
raising=False,
|
||||
)
|
||||
|
||||
search_space = SimpleNamespace(image_generation_config_id=-7)
|
||||
(
|
||||
tier,
|
||||
model,
|
||||
_reserve,
|
||||
) = await image_generation_routes._resolve_billing_for_image_gen(
|
||||
session=None, config_id=None, search_space=search_space
|
||||
)
|
||||
assert tier == "premium"
|
||||
assert model == "openai/gpt-image-1"
|
||||
|
|
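
# Editor's sketch of the decision table covered above. The constant value
# and the provider-prefix rule are assumptions inferred from the assertions;
# the real helper lives in image_generation_routes.
DEFAULT_IMAGE_RESERVE_MICROS = 50_000  # assumed default; see billable_calls


def resolve_image_billing_sketch(
    config_id: int | None,
    search_space_default: int | None,
    globals_by_id: dict[int, dict],
) -> tuple[str, str, int]:
    cfg_id = config_id if config_id is not None else search_space_default
    if cfg_id == 0:
        # Auto mode → free (Auto-mode billing-tier surfacing is a follow-up).
        return "free", "auto", DEFAULT_IMAGE_RESERVE_MICROS
    if cfg_id is not None and cfg_id > 0:
        # User-owned BYOK: they pay the provider directly.
        return "free", "user_byok", DEFAULT_IMAGE_RESERVE_MICROS
    cfg = globals_by_id[cfg_id]  # negative id → global YAML/OpenRouter config
    model = f"{cfg['provider'].lower()}/{cfg['model_name']}"
    reserve = cfg.get("quota_reserve_micros", DEFAULT_IMAGE_RESERVE_MICROS)
    return cfg["billing_tier"], model, reserve
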
@@ -0,0 +1,436 @@
"""Unit tests for ``_resolve_agent_billing_for_search_space``.
|
||||
|
||||
Validates the resolver used by Celery podcast/video tasks to compute
|
||||
``(owner_user_id, billing_tier, base_model)`` from a search space and its
|
||||
agent LLM config. The resolver mirrors chat's billing-resolution pattern at
|
||||
``stream_new_chat.py:2294-2351`` and is the single integration point that
|
||||
prevents Auto-mode podcast/video from leaking premium credit.
|
||||
|
||||
Coverage:
|
||||
|
||||
* Auto mode + ``thread_id`` set, pin resolves to a negative-id premium
|
||||
global → returns ``("premium", <base_model>)``.
|
||||
* Auto mode + ``thread_id`` set, pin resolves to a negative-id free
|
||||
global → returns ``("free", <base_model>)``.
|
||||
* Auto mode + ``thread_id`` set, pin resolves to a positive-id BYOK config
|
||||
→ always ``"free"``.
|
||||
* Auto mode + ``thread_id=None`` → fallback to ``("free", "auto")`` without
|
||||
hitting the pin service.
|
||||
* Negative id (no Auto) → uses ``get_global_llm_config``'s
|
||||
``billing_tier``.
|
||||
* Positive id (user BYOK) → always ``"free"``.
|
||||
* Search space not found → raises ``ValueError``.
|
||||
* ``agent_llm_id`` is None → raises ``ValueError``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from types import SimpleNamespace
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fakes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _FakeExecResult:
|
||||
def __init__(self, obj):
|
||||
self._obj = obj
|
||||
|
||||
def scalars(self):
|
||||
return self
|
||||
|
||||
def first(self):
|
||||
return self._obj
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
"""Tiny AsyncSession stub.
|
||||
|
||||
``responses`` is a list of objects to return from successive
|
||||
``execute()`` calls (in order). The resolver makes at most two
|
||||
``execute()`` calls (search-space lookup, then optionally NewLLMConfig
|
||||
lookup), so two queued responses cover the matrix.
|
||||
"""
|
||||
|
||||
def __init__(self, responses: list):
|
||||
self._responses = list(responses)
|
||||
|
||||
async def execute(self, _stmt):
|
||||
if not self._responses:
|
||||
return _FakeExecResult(None)
|
||||
return _FakeExecResult(self._responses.pop(0))
|
||||
|
||||
async def commit(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class _FakePinResolution:
|
||||
resolved_llm_config_id: int
|
||||
resolved_tier: str = "premium"
|
||||
from_existing_pin: bool = False
|
||||
|
||||
|
||||
def _make_search_space(*, agent_llm_id: int | None, user_id: UUID) -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
id=42,
|
||||
agent_llm_id=agent_llm_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
|
||||
def _make_byok_config(
|
||||
*, id_: int, base_model: str | None = None, model_name: str = "gpt-byok"
|
||||
) -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
id=id_,
|
||||
model_name=model_name,
|
||||
litellm_params={"base_model": base_model} if base_model else {},
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_mode_with_thread_id_resolves_to_premium_global(monkeypatch):
|
||||
"""Auto + thread → pin service resolves to negative-id premium config →
|
||||
resolver returns ``("premium", <base_model>)``."""
|
||||
from app.services.billable_calls import _resolve_agent_billing_for_search_space
|
||||
|
||||
user_id = uuid4()
|
||||
session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)])
|
||||
|
||||
# Mock the pin service to return a concrete premium config id.
|
||||
async def _fake_resolve_pin(
|
||||
sess,
|
||||
*,
|
||||
thread_id,
|
||||
search_space_id,
|
||||
user_id,
|
||||
selected_llm_config_id,
|
||||
force_repin_free=False,
|
||||
):
|
||||
assert selected_llm_config_id == 0
|
||||
assert thread_id == 99
|
||||
return _FakePinResolution(resolved_llm_config_id=-1, resolved_tier="premium")
|
||||
|
||||
# Mock global config lookup to return a premium entry.
|
||||
def _fake_get_global(cfg_id):
|
||||
if cfg_id == -1:
|
||||
return {
|
||||
"id": -1,
|
||||
"model_name": "gpt-5.4",
|
||||
"billing_tier": "premium",
|
||||
"litellm_params": {"base_model": "gpt-5.4"},
|
||||
}
|
||||
return None
|
||||
|
||||
# Lazy imports inside the resolver — patch the *target* modules so the
|
||||
# imported names resolve to our fakes.
|
||||
import app.services.auto_model_pin_service as pin_module
|
||||
import app.services.llm_service as llm_module
|
||||
|
||||
monkeypatch.setattr(
|
||||
pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin
|
||||
)
|
||||
monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global)
|
||||
|
||||
owner, tier, base_model = await _resolve_agent_billing_for_search_space(
|
||||
session, search_space_id=42, thread_id=99
|
||||
)
|
||||
|
||||
assert owner == user_id
|
||||
assert tier == "premium"
|
||||
assert base_model == "gpt-5.4"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_mode_with_thread_id_resolves_to_free_global(monkeypatch):
|
||||
"""Auto + thread → pin returns negative-id free config → resolver
|
||||
returns ``("free", <base_model>)``. Same path the pin service takes for
|
||||
out-of-credit users (graceful degradation)."""
|
||||
from app.services.billable_calls import _resolve_agent_billing_for_search_space
|
||||
|
||||
user_id = uuid4()
|
||||
session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)])
|
||||
|
||||
async def _fake_resolve_pin(
|
||||
sess,
|
||||
*,
|
||||
thread_id,
|
||||
search_space_id,
|
||||
user_id,
|
||||
selected_llm_config_id,
|
||||
force_repin_free=False,
|
||||
):
|
||||
return _FakePinResolution(resolved_llm_config_id=-3, resolved_tier="free")
|
||||
|
||||
def _fake_get_global(cfg_id):
|
||||
if cfg_id == -3:
|
||||
return {
|
||||
"id": -3,
|
||||
"model_name": "openrouter/free-model",
|
||||
"billing_tier": "free",
|
||||
"litellm_params": {"base_model": "openrouter/free-model"},
|
||||
}
|
||||
return None
|
||||
|
||||
import app.services.auto_model_pin_service as pin_module
|
||||
import app.services.llm_service as llm_module
|
||||
|
||||
monkeypatch.setattr(
|
||||
pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin
|
||||
)
|
||||
monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global)
|
||||
|
||||
owner, tier, base_model = await _resolve_agent_billing_for_search_space(
|
||||
session, search_space_id=42, thread_id=99
|
||||
)
|
||||
|
||||
assert owner == user_id
|
||||
assert tier == "free"
|
||||
assert base_model == "openrouter/free-model"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_mode_with_thread_id_resolves_to_byok_is_free(monkeypatch):
|
||||
"""Auto + thread → pin returns positive-id BYOK config → resolver
|
||||
returns ``("free", ...)`` (BYOK is always free per
|
||||
``AgentConfig.from_new_llm_config``)."""
|
||||
from app.services.billable_calls import _resolve_agent_billing_for_search_space
|
||||
|
||||
user_id = uuid4()
|
||||
search_space = _make_search_space(agent_llm_id=0, user_id=user_id)
|
||||
byok_cfg = _make_byok_config(
|
||||
id_=17, base_model="anthropic/claude-3-haiku", model_name="my-claude"
|
||||
)
|
||||
session = _FakeSession([search_space, byok_cfg])
|
||||
|
||||
async def _fake_resolve_pin(
|
||||
sess,
|
||||
*,
|
||||
thread_id,
|
||||
search_space_id,
|
||||
user_id,
|
||||
selected_llm_config_id,
|
||||
force_repin_free=False,
|
||||
):
|
||||
return _FakePinResolution(resolved_llm_config_id=17, resolved_tier="free")
|
||||
|
||||
import app.services.auto_model_pin_service as pin_module
|
||||
|
||||
monkeypatch.setattr(
|
||||
pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin
|
||||
)
|
||||
|
||||
owner, tier, base_model = await _resolve_agent_billing_for_search_space(
|
||||
session, search_space_id=42, thread_id=99
|
||||
)
|
||||
|
||||
assert owner == user_id
|
||||
assert tier == "free"
|
||||
assert base_model == "anthropic/claude-3-haiku"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_mode_without_thread_id_falls_back_to_free():
|
||||
"""Auto + ``thread_id=None`` → ``("free", "auto")`` without invoking
|
||||
the pin service. Forward-compat fallback for any future direct-API
|
||||
entrypoint that doesn't have a chat thread."""
|
||||
from app.services.billable_calls import _resolve_agent_billing_for_search_space
|
||||
|
||||
user_id = uuid4()
|
||||
session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)])
|
||||
|
||||
owner, tier, base_model = await _resolve_agent_billing_for_search_space(
|
||||
session, search_space_id=42, thread_id=None
|
||||
)
|
||||
|
||||
assert owner == user_id
|
||||
assert tier == "free"
|
||||
assert base_model == "auto"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_mode_pin_failure_falls_back_to_free(monkeypatch):
|
||||
"""If the pin service raises ``ValueError`` (thread missing /
|
||||
mismatched search space), the resolver should log and return free
|
||||
rather than killing the whole task."""
|
||||
from app.services.billable_calls import _resolve_agent_billing_for_search_space
|
||||
|
||||
user_id = uuid4()
|
||||
session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)])
|
||||
|
||||
async def _fake_resolve_pin(*args, **kwargs):
|
||||
raise ValueError("thread missing")
|
||||
|
||||
import app.services.auto_model_pin_service as pin_module
|
||||
|
||||
monkeypatch.setattr(
|
||||
pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin
|
||||
)
|
||||
|
||||
owner, tier, base_model = await _resolve_agent_billing_for_search_space(
|
||||
session, search_space_id=42, thread_id=99
|
||||
)
|
||||
|
||||
assert owner == user_id
|
||||
assert tier == "free"
|
||||
assert base_model == "auto"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_negative_id_premium_global_returns_premium(monkeypatch):
|
||||
"""Explicit negative agent_llm_id → ``get_global_llm_config`` →
|
||||
return its ``billing_tier``."""
|
||||
from app.services.billable_calls import _resolve_agent_billing_for_search_space
|
||||
|
||||
user_id = uuid4()
|
||||
session = _FakeSession([_make_search_space(agent_llm_id=-1, user_id=user_id)])
|
||||
|
||||
def _fake_get_global(cfg_id):
|
||||
return {
|
||||
"id": cfg_id,
|
||||
"model_name": "gpt-5.4",
|
||||
"billing_tier": "premium",
|
||||
"litellm_params": {"base_model": "gpt-5.4"},
|
||||
}
|
||||
|
||||
import app.services.llm_service as llm_module
|
||||
|
||||
monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global)
|
||||
|
||||
owner, tier, base_model = await _resolve_agent_billing_for_search_space(
|
||||
session, search_space_id=42, thread_id=99
|
||||
)
|
||||
|
||||
assert owner == user_id
|
||||
assert tier == "premium"
|
||||
assert base_model == "gpt-5.4"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_negative_id_free_global_returns_free(monkeypatch):
|
||||
from app.services.billable_calls import _resolve_agent_billing_for_search_space
|
||||
|
||||
user_id = uuid4()
|
||||
session = _FakeSession([_make_search_space(agent_llm_id=-2, user_id=user_id)])
|
||||
|
||||
def _fake_get_global(cfg_id):
|
||||
return {
|
||||
"id": cfg_id,
|
||||
"model_name": "openrouter/some-free",
|
||||
"billing_tier": "free",
|
||||
"litellm_params": {"base_model": "openrouter/some-free"},
|
||||
}
|
||||
|
||||
import app.services.llm_service as llm_module
|
||||
|
||||
monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global)
|
||||
|
||||
owner, tier, base_model = await _resolve_agent_billing_for_search_space(
|
||||
session, search_space_id=42, thread_id=None
|
||||
)
|
||||
|
||||
assert owner == user_id
|
||||
assert tier == "free"
|
||||
assert base_model == "openrouter/some-free"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_negative_id_missing_base_model_falls_back_to_model_name(monkeypatch):
|
||||
"""When the global config has no ``litellm_params.base_model``, the
|
||||
resolver falls back to ``model_name`` — matching chat's behavior."""
|
||||
from app.services.billable_calls import _resolve_agent_billing_for_search_space
|
||||
|
||||
user_id = uuid4()
|
||||
session = _FakeSession([_make_search_space(agent_llm_id=-5, user_id=user_id)])
|
||||
|
||||
def _fake_get_global(cfg_id):
|
||||
return {
|
||||
"id": cfg_id,
|
||||
"model_name": "fallback-model",
|
||||
"billing_tier": "premium",
|
||||
# No litellm_params.
|
||||
}
|
||||
|
||||
import app.services.llm_service as llm_module
|
||||
|
||||
monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global)
|
||||
|
||||
_, tier, base_model = await _resolve_agent_billing_for_search_space(
|
||||
session, search_space_id=42
|
||||
)
|
||||
|
||||
assert tier == "premium"
|
||||
assert base_model == "fallback-model"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_positive_id_byok_is_always_free():
|
||||
"""Positive agent_llm_id → user-owned BYOK NewLLMConfig → always free,
|
||||
regardless of underlying provider tier."""
|
||||
from app.services.billable_calls import _resolve_agent_billing_for_search_space
|
||||
|
||||
user_id = uuid4()
|
||||
search_space = _make_search_space(agent_llm_id=23, user_id=user_id)
|
||||
byok_cfg = _make_byok_config(id_=23, base_model="anthropic/claude-3.5-sonnet")
|
||||
session = _FakeSession([search_space, byok_cfg])
|
||||
|
||||
owner, tier, base_model = await _resolve_agent_billing_for_search_space(
|
||||
session, search_space_id=42
|
||||
)
|
||||
|
||||
assert owner == user_id
|
||||
assert tier == "free"
|
||||
assert base_model == "anthropic/claude-3.5-sonnet"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_positive_id_byok_missing_returns_free_with_empty_base_model():
|
||||
"""If the BYOK config row is missing/deleted but the search space still
|
||||
points at it, the resolver still returns free (no debit) with an empty
|
||||
base_model — billable_call's premium path is skipped, no harm done."""
|
||||
from app.services.billable_calls import _resolve_agent_billing_for_search_space
|
||||
|
||||
user_id = uuid4()
|
||||
session = _FakeSession([_make_search_space(agent_llm_id=99, user_id=user_id)])
|
||||
|
||||
owner, tier, base_model = await _resolve_agent_billing_for_search_space(
|
||||
session, search_space_id=42
|
||||
)
|
||||
|
||||
assert owner == user_id
|
||||
assert tier == "free"
|
||||
assert base_model == ""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_space_not_found_raises_value_error():
|
||||
from app.services.billable_calls import _resolve_agent_billing_for_search_space
|
||||
|
||||
session = _FakeSession([None])
|
||||
|
||||
with pytest.raises(ValueError, match="Search space"):
|
||||
await _resolve_agent_billing_for_search_space(session, search_space_id=999)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_llm_id_none_raises_value_error():
|
||||
from app.services.billable_calls import _resolve_agent_billing_for_search_space
|
||||
|
||||
user_id = uuid4()
|
||||
session = _FakeSession([_make_search_space(agent_llm_id=None, user_id=user_id)])
|
||||
|
||||
with pytest.raises(ValueError, match="agent_llm_id"):
|
||||
await _resolve_agent_billing_for_search_space(session, search_space_id=42)
|
||||
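
# Editor's sketch of the resolver's control flow. The three lookups are
# injected so the sketch stays self-contained; in the real module they are
# lazy imports from the pin service, llm_service, and the ORM, and the
# function body here is illustrative, not the actual implementation.
async def resolve_agent_billing_sketch(
    search_space, thread_id, *, resolve_pin, get_global, get_byok
):
    owner = search_space.user_id
    cfg_id = search_space.agent_llm_id
    if cfg_id is None:
        raise ValueError("search space has no agent_llm_id configured")
    if cfg_id == 0:  # Auto mode
        if thread_id is None:
            return owner, "free", "auto"  # no thread → nothing to pin against
        try:
            cfg_id = (await resolve_pin(thread_id)).resolved_llm_config_id
        except ValueError:
            return owner, "free", "auto"  # degrade instead of killing the task
    if cfg_id > 0:  # user BYOK → always free on our side
        cfg = await get_byok(cfg_id)
        if cfg is None:
            return owner, "free", ""  # stale pointer: free, premium path skipped
        base = (cfg.litellm_params or {}).get("base_model") or cfg.model_name
        return owner, "free", base
    g = get_global(cfg_id)  # negative id → the global config carries the tier
    base = (g.get("litellm_params") or {}).get("base_model") or g["model_name"]
    return owner, g["billing_tier"], base
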
1026
surfsense_backend/tests/unit/services/test_auto_model_pin_service.py
Normal file
File diff suppressed because it is too large
@@ -0,0 +1,286 @@
"""Image-aware extension of the Auto-pin resolver.
|
||||
|
||||
When the current chat turn carries an ``image_url`` block, the pin
|
||||
resolver must:
|
||||
|
||||
1. Filter the candidate pool to vision-capable cfgs so a freshly
|
||||
selected pin can never be text-only.
|
||||
2. Treat any existing pin whose capability is False as invalid (force
|
||||
re-pin), even when it would otherwise be reused as the thread's
|
||||
stable model.
|
||||
3. Raise ``ValueError`` (mapped to the friendly
|
||||
``MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT`` SSE error in the streaming
|
||||
task) when no vision-capable cfg is available — instead of silently
|
||||
pinning text-only and 404-ing at the provider.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.auto_model_pin_service import (
|
||||
clear_healthy,
|
||||
clear_runtime_cooldown,
|
||||
resolve_or_get_pinned_llm_config_id,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_caches():
|
||||
clear_runtime_cooldown()
|
||||
clear_healthy()
|
||||
yield
|
||||
clear_runtime_cooldown()
|
||||
clear_healthy()
|
||||
|
||||
|
||||
@dataclass
|
||||
class _FakeQuotaResult:
|
||||
allowed: bool
|
||||
|
||||
|
||||
class _FakeExecResult:
|
||||
def __init__(self, thread):
|
||||
self._thread = thread
|
||||
|
||||
def unique(self):
|
||||
return self
|
||||
|
||||
def scalar_one_or_none(self):
|
||||
return self._thread
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
def __init__(self, thread):
|
||||
self.thread = thread
|
||||
self.commit_count = 0
|
||||
|
||||
async def execute(self, _stmt):
|
||||
return _FakeExecResult(self.thread)
|
||||
|
||||
async def commit(self):
|
||||
self.commit_count += 1
|
||||
|
||||
|
||||
def _thread(*, pinned: int | None = None):
|
||||
return SimpleNamespace(id=1, search_space_id=10, pinned_llm_config_id=pinned)
|
||||
|
||||
|
||||
def _vision_cfg(id_: int, *, tier: str = "free", quality: int = 80) -> dict:
|
||||
return {
|
||||
"id": id_,
|
||||
"provider": "OPENAI",
|
||||
"model_name": f"vision-{id_}",
|
||||
"api_key": "k",
|
||||
"billing_tier": tier,
|
||||
"supports_image_input": True,
|
||||
"auto_pin_tier": "A",
|
||||
"quality_score": quality,
|
||||
}
|
||||
|
||||
|
||||
def _text_only_cfg(id_: int, *, tier: str = "free", quality: int = 90) -> dict:
|
||||
return {
|
||||
"id": id_,
|
||||
"provider": "OPENAI",
|
||||
"model_name": f"text-{id_}",
|
||||
"api_key": "k",
|
||||
"billing_tier": tier,
|
||||
# Higher quality than the vision cfgs — so a bug that ignores
|
||||
# the image flag would surface as the resolver picking this one.
|
||||
"supports_image_input": False,
|
||||
"auto_pin_tier": "A",
|
||||
"quality_score": quality,
|
||||
}
|
||||
|
||||
|
||||
async def _premium_allowed(*_args, **_kwargs):
|
||||
return _FakeQuotaResult(allowed=True)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_image_turn_filters_out_text_only_candidates(monkeypatch):
|
||||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread())
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[_text_only_cfg(-1), _vision_cfg(-2)],
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
|
||||
_premium_allowed,
|
||||
)
|
||||
|
||||
result = await resolve_or_get_pinned_llm_config_id(
|
||||
session,
|
||||
thread_id=1,
|
||||
search_space_id=10,
|
||||
user_id=None,
|
||||
selected_llm_config_id=0,
|
||||
requires_image_input=True,
|
||||
)
|
||||
|
||||
assert result.resolved_llm_config_id == -2
|
||||
# The thread should be pinned to the vision cfg even though the
|
||||
# text-only cfg has a higher quality score.
|
||||
assert session.thread.pinned_llm_config_id == -2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_image_turn_force_repins_stale_text_only_pin(monkeypatch):
|
||||
"""An existing text-only pin must be invalidated when the next turn
|
||||
requires image input. The non-image path would happily reuse it."""
|
||||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread(pinned=-1))
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[_text_only_cfg(-1), _vision_cfg(-2)],
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
|
||||
_premium_allowed,
|
||||
)
|
||||
|
||||
result = await resolve_or_get_pinned_llm_config_id(
|
||||
session,
|
||||
thread_id=1,
|
||||
search_space_id=10,
|
||||
user_id=None,
|
||||
selected_llm_config_id=0,
|
||||
requires_image_input=True,
|
||||
)
|
||||
|
||||
assert result.resolved_llm_config_id == -2
|
||||
assert result.from_existing_pin is False
|
||||
assert session.thread.pinned_llm_config_id == -2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_image_turn_reuses_existing_vision_pin(monkeypatch):
|
||||
"""If the thread is already pinned to a vision-capable cfg, reuse it
|
||||
— same as the non-image path. Image-aware filtering must not force
|
||||
spurious re-pins."""
|
||||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread(pinned=-2))
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[_text_only_cfg(-1), _vision_cfg(-2), _vision_cfg(-3, quality=70)],
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
|
||||
_premium_allowed,
|
||||
)
|
||||
|
||||
result = await resolve_or_get_pinned_llm_config_id(
|
||||
session,
|
||||
thread_id=1,
|
||||
search_space_id=10,
|
||||
user_id=None,
|
||||
selected_llm_config_id=0,
|
||||
requires_image_input=True,
|
||||
)
|
||||
|
||||
assert result.resolved_llm_config_id == -2
|
||||
assert result.from_existing_pin is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_image_turn_with_no_vision_candidates_raises(monkeypatch):
|
||||
"""The friendly-error path: no vision-capable cfg in the pool -> raise
|
||||
``ValueError`` whose message contains ``vision-capable`` so the
|
||||
streaming task can map it to ``MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT``."""
|
||||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread())
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[_text_only_cfg(-1), _text_only_cfg(-2)],
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
|
||||
_premium_allowed,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="vision-capable"):
|
||||
await resolve_or_get_pinned_llm_config_id(
|
||||
session,
|
||||
thread_id=1,
|
||||
search_space_id=10,
|
||||
user_id=None,
|
||||
selected_llm_config_id=0,
|
||||
requires_image_input=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_image_turn_keeps_text_only_in_pool(monkeypatch):
|
||||
"""Regression guard: the image flag must default False and not affect
|
||||
a normal text-only turn — text-only cfgs remain selectable."""
|
||||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread())
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[_text_only_cfg(-1)],
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
|
||||
_premium_allowed,
|
||||
)
|
||||
|
||||
result = await resolve_or_get_pinned_llm_config_id(
|
||||
session,
|
||||
thread_id=1,
|
||||
search_space_id=10,
|
||||
user_id=None,
|
||||
selected_llm_config_id=0,
|
||||
)
|
||||
assert result.resolved_llm_config_id == -1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_image_turn_unannotated_cfg_resolves_via_helper(monkeypatch):
|
||||
"""A YAML cfg that omits ``supports_image_input`` falls through to
|
||||
``derive_supports_image_input`` (LiteLLM-driven). For ``gpt-4o``
|
||||
that returns True, so the cfg should be a valid candidate."""
|
||||
from app.config import config
|
||||
|
||||
session = _FakeSession(_thread())
|
||||
cfg_unannotated_vision = {
|
||||
"id": -2,
|
||||
"provider": "OPENAI",
|
||||
"model_name": "gpt-4o", # known vision model in LiteLLM map
|
||||
"api_key": "k",
|
||||
"billing_tier": "free",
|
||||
"auto_pin_tier": "A",
|
||||
"quality_score": 80,
|
||||
# NOTE: no supports_image_input key
|
||||
}
|
||||
monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [cfg_unannotated_vision])
|
||||
monkeypatch.setattr(
|
||||
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
|
||||
_premium_allowed,
|
||||
)
|
||||
|
||||
result = await resolve_or_get_pinned_llm_config_id(
|
||||
session,
|
||||
thread_id=1,
|
||||
search_space_id=10,
|
||||
user_id=None,
|
||||
selected_llm_config_id=0,
|
||||
requires_image_input=True,
|
||||
)
|
||||
assert result.resolved_llm_config_id == -2
|
||||
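
# Editor's sketch of the image-aware filtering step these tests cover.
# ``supports_image`` stands in for the explicit-flag /
# derive_supports_image_input fallthrough; names and structure are
# illustrative, not the real resolver.
def vision_filter_sketch(
    candidates: list[dict],
    pinned_id: int | None,
    requires_image_input: bool,
    supports_image=lambda cfg: cfg.get("supports_image_input", True),
) -> tuple[list[dict], int | None]:
    if not requires_image_input:
        return candidates, pinned_id  # text turn: pool and pin untouched
    pool = [c for c in candidates if supports_image(c)]
    if not pool:
        # Mapped to MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT by the streaming task.
        raise ValueError("no vision-capable LLM config available")
    if pinned_id is not None and all(c["id"] != pinned_id for c in pool):
        pinned_id = None  # stale text-only pin → force re-pin
    return pool, pinned_id
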
559
surfsense_backend/tests/unit/services/test_billable_call.py
Normal file
@@ -0,0 +1,559 @@
"""Unit tests for the ``billable_call`` async context manager.
|
||||
|
||||
Covers the per-call premium-credit lifecycle for image generation and
|
||||
vision LLM extraction:
|
||||
|
||||
* Free configs bypass reserve/finalize but still write an audit row.
|
||||
* Premium reserve denial raises ``QuotaInsufficientError`` (HTTP 402 in the
|
||||
route layer).
|
||||
* Successful premium calls reserve, yield the accumulator, then finalize
|
||||
with the LiteLLM-reported actual cost — and write an audit row.
|
||||
* Failed premium calls release the reservation so credit isn't leaked.
|
||||
* All quota DB ops happen inside their OWN ``shielded_async_session``,
|
||||
isolating them from the caller's transaction (issue A).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fakes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _FakeQuotaResult:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
allowed: bool,
|
||||
used: int = 0,
|
||||
limit: int = 5_000_000,
|
||||
remaining: int = 5_000_000,
|
||||
) -> None:
|
||||
self.allowed = allowed
|
||||
self.used = used
|
||||
self.limit = limit
|
||||
self.remaining = remaining
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
"""Minimal AsyncSession stub — record commits for assertion."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.committed = False
|
||||
self.added: list[Any] = []
|
||||
|
||||
def add(self, obj: Any) -> None:
|
||||
self.added.append(obj)
|
||||
|
||||
async def commit(self) -> None:
|
||||
self.committed = True
|
||||
|
||||
async def rollback(self) -> None:
|
||||
pass
|
||||
|
||||
async def close(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def _fake_shielded_session():
|
||||
s = _FakeSession()
|
||||
_SESSIONS_USED.append(s)
|
||||
yield s
|
||||
|
||||
|
||||
_SESSIONS_USED: list[_FakeSession] = []
|
||||
|
||||
|
||||
def _patch_isolation_layer(
|
||||
monkeypatch, *, reserve_result, finalize_result=None, finalize_exc=None
|
||||
):
|
||||
"""Wire fake reserve/finalize/release/session helpers."""
|
||||
_SESSIONS_USED.clear()
|
||||
reserve_calls: list[dict[str, Any]] = []
|
||||
finalize_calls: list[dict[str, Any]] = []
|
||||
release_calls: list[dict[str, Any]] = []
|
||||
|
||||
async def _fake_reserve(*, db_session, user_id, request_id, reserve_micros):
|
||||
reserve_calls.append(
|
||||
{
|
||||
"user_id": user_id,
|
||||
"reserve_micros": reserve_micros,
|
||||
"request_id": request_id,
|
||||
}
|
||||
)
|
||||
return reserve_result
|
||||
|
||||
async def _fake_finalize(
|
||||
*, db_session, user_id, request_id, actual_micros, reserved_micros
|
||||
):
|
||||
if finalize_exc is not None:
|
||||
raise finalize_exc
|
||||
finalize_calls.append(
|
||||
{
|
||||
"user_id": user_id,
|
||||
"actual_micros": actual_micros,
|
||||
"reserved_micros": reserved_micros,
|
||||
}
|
||||
)
|
||||
return finalize_result or _FakeQuotaResult(allowed=True)
|
||||
|
||||
async def _fake_release(*, db_session, user_id, reserved_micros):
|
||||
release_calls.append({"user_id": user_id, "reserved_micros": reserved_micros})
|
||||
|
||||
record_calls: list[dict[str, Any]] = []
|
||||
|
||||
async def _fake_record(session, **kwargs):
|
||||
record_calls.append(kwargs)
|
||||
return object()
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.services.billable_calls.TokenQuotaService.premium_reserve",
|
||||
        _fake_reserve,
        raising=False,
    )
    monkeypatch.setattr(
        "app.services.billable_calls.TokenQuotaService.premium_finalize",
        _fake_finalize,
        raising=False,
    )
    monkeypatch.setattr(
        "app.services.billable_calls.TokenQuotaService.premium_release",
        _fake_release,
        raising=False,
    )
    monkeypatch.setattr(
        "app.services.billable_calls.shielded_async_session",
        _fake_shielded_session,
        raising=False,
    )
    monkeypatch.setattr(
        "app.services.billable_calls.record_token_usage",
        _fake_record,
        raising=False,
    )

    return {
        "reserve": reserve_calls,
        "finalize": finalize_calls,
        "release": release_calls,
        "record": record_calls,
    }


# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_free_path_skips_reserve_but_writes_audit_row(monkeypatch):
    from app.services.billable_calls import billable_call

    spies = _patch_isolation_layer(
        monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
    )
    user_id = uuid4()

    async with billable_call(
        user_id=user_id,
        search_space_id=42,
        billing_tier="free",
        base_model="openai/gpt-image-1",
        usage_type="image_generation",
    ) as acc:
        # Simulate a captured cost — the accumulator is fed by the LiteLLM
        # callback in real life; here we add it manually.
        acc.add(
            model="openai/gpt-image-1",
            prompt_tokens=0,
            completion_tokens=0,
            total_tokens=0,
            cost_micros=37_000,
            call_kind="image_generation",
        )

    assert spies["reserve"] == []
    assert spies["finalize"] == []
    assert spies["release"] == []
    # Free still audits.
    assert len(spies["record"]) == 1
    assert spies["record"][0]["usage_type"] == "image_generation"
    assert spies["record"][0]["cost_micros"] == 37_000


@pytest.mark.asyncio
async def test_premium_reserve_denied_raises_quota_insufficient(monkeypatch):
    from app.services.billable_calls import (
        QuotaInsufficientError,
        billable_call,
    )

    spies = _patch_isolation_layer(
        monkeypatch,
        reserve_result=_FakeQuotaResult(
            allowed=False, used=5_000_000, limit=5_000_000, remaining=0
        ),
    )
    user_id = uuid4()

    with pytest.raises(QuotaInsufficientError) as exc_info:
        async with billable_call(
            user_id=user_id,
            search_space_id=42,
            billing_tier="premium",
            base_model="openai/gpt-image-1",
            quota_reserve_micros_override=50_000,
            usage_type="image_generation",
        ):
            pytest.fail("body should not run when reserve is denied")

    err = exc_info.value
    assert err.usage_type == "image_generation"
    assert err.used_micros == 5_000_000
    assert err.limit_micros == 5_000_000
    assert err.remaining_micros == 0
    # Reserve was attempted, but no finalize/release on a denied reserve
    # — the reservation never actually held credit.
    assert len(spies["reserve"]) == 1
    assert spies["finalize"] == []
    assert spies["release"] == []
    # Denied premium calls do NOT create an audit row (no work happened).
    assert spies["record"] == []


@pytest.mark.asyncio
async def test_premium_success_finalizes_with_actual_cost(monkeypatch):
    from app.services.billable_calls import billable_call

    spies = _patch_isolation_layer(
        monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
    )
    user_id = uuid4()

    async with billable_call(
        user_id=user_id,
        search_space_id=42,
        billing_tier="premium",
        base_model="openai/gpt-image-1",
        quota_reserve_micros_override=50_000,
        usage_type="image_generation",
    ) as acc:
        # The LiteLLM callback would normally fill this — simulate a $0.04 image.
        acc.add(
            model="openai/gpt-image-1",
            prompt_tokens=0,
            completion_tokens=0,
            total_tokens=0,
            cost_micros=40_000,
            call_kind="image_generation",
        )

    assert len(spies["reserve"]) == 1
    assert spies["reserve"][0]["reserve_micros"] == 50_000
    assert len(spies["finalize"]) == 1
    assert spies["finalize"][0]["actual_micros"] == 40_000
    assert spies["finalize"][0]["reserved_micros"] == 50_000
    assert spies["release"] == []
    # And an audit row is written with the actual debited cost.
    assert spies["record"][0]["cost_micros"] == 40_000
    # Each quota op opened its OWN session — proves session isolation.
    assert len(_SESSIONS_USED) >= 3
    # finalize/reserve go through the stubbed ``TokenQuotaService.*``, which
    # never commits on our fake sessions; only the audit session commits. It
    # is therefore enough to assert that at least one session committed.
    assert any(s.committed for s in _SESSIONS_USED)

@pytest.mark.asyncio
async def test_premium_failure_releases_reservation(monkeypatch):
    from app.services.billable_calls import billable_call

    spies = _patch_isolation_layer(
        monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
    )
    user_id = uuid4()

    class _ProviderError(Exception):
        pass

    with pytest.raises(_ProviderError):
        async with billable_call(
            user_id=user_id,
            search_space_id=42,
            billing_tier="premium",
            base_model="openai/gpt-image-1",
            quota_reserve_micros_override=50_000,
            usage_type="image_generation",
        ):
            raise _ProviderError("OpenRouter 503")

    assert len(spies["reserve"]) == 1
    assert spies["finalize"] == []
    # Failure path: release the held reservation.
    assert len(spies["release"]) == 1
    assert spies["release"][0]["reserved_micros"] == 50_000
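
# Illustrative sketch (not part of the suite): the reserve → settle
# lifecycle the premium tests above pin down. Names are simplified and
# hypothetical; the real ``billable_call`` adds tier gating, session
# isolation, and the audit write on top of this skeleton.
async def _reserve_settle_sketch(reserve, work, finalize, release):
    await reserve()  # may raise QuotaInsufficientError before any work runs
    try:
        result = await work()
    except Exception:
        await release()  # failure path: hand the held credit back
        raise
    await finalize()  # success path: debit the actual cost
    return result
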
@pytest.mark.asyncio
async def test_premium_uses_estimator_when_no_micros_override(monkeypatch):
    """When ``quota_reserve_micros_override`` is None we fall back to
    ``estimate_call_reserve_micros(base_model, quota_reserve_tokens)``.
    Vision LLM calls take this path (token-priced models).
    """
    from app.services.billable_calls import billable_call

    spies = _patch_isolation_layer(
        monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
    )

    captured_estimator_calls: list[dict[str, Any]] = []

    def _fake_estimate(*, base_model, quota_reserve_tokens):
        captured_estimator_calls.append(
            {"base_model": base_model, "quota_reserve_tokens": quota_reserve_tokens}
        )
        return 12_345

    monkeypatch.setattr(
        "app.services.billable_calls.estimate_call_reserve_micros",
        _fake_estimate,
        raising=False,
    )

    user_id = uuid4()
    async with billable_call(
        user_id=user_id,
        search_space_id=1,
        billing_tier="premium",
        base_model="openai/gpt-4o",
        quota_reserve_tokens=4000,
        usage_type="vision_extraction",
    ):
        pass

    assert captured_estimator_calls == [
        {"base_model": "openai/gpt-4o", "quota_reserve_tokens": 4000}
    ]
    assert spies["reserve"][0]["reserve_micros"] == 12_345
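
# Sketch of the reserve-amount rule the test above pins (simplified,
# hypothetical names; the real resolution happens inside ``billable_call``):
def _reserve_micros_sketch(override_micros, base_model, reserve_tokens, estimator):
    # An explicit micros override wins; token-priced models fall back to
    # the token-budget estimator.
    if override_micros is not None:
        return override_micros
    return estimator(base_model=base_model, quota_reserve_tokens=reserve_tokens)
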
@pytest.mark.asyncio
async def test_premium_finalize_failure_propagates_and_releases(monkeypatch):
    from app.services.billable_calls import BillingSettlementError, billable_call

    class _FinalizeError(RuntimeError):
        pass

    spies = _patch_isolation_layer(
        monkeypatch,
        reserve_result=_FakeQuotaResult(allowed=True),
        finalize_exc=_FinalizeError("db finalize failed"),
    )
    user_id = uuid4()

    with pytest.raises(BillingSettlementError):
        async with billable_call(
            user_id=user_id,
            search_space_id=42,
            billing_tier="premium",
            base_model="openai/gpt-image-1",
            quota_reserve_micros_override=50_000,
            usage_type="image_generation",
        ) as acc:
            acc.add(
                model="openai/gpt-image-1",
                prompt_tokens=0,
                completion_tokens=0,
                total_tokens=0,
                cost_micros=40_000,
                call_kind="image_generation",
            )

    assert len(spies["reserve"]) == 1
    assert len(spies["release"]) == 1
    assert spies["record"] == []


@pytest.mark.asyncio
async def test_premium_audit_commit_hang_times_out_after_finalize(monkeypatch):
    from app.services.billable_calls import billable_call

    spies = _patch_isolation_layer(
        monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
    )
    user_id = uuid4()

    class _HangingCommitSession(_FakeSession):
        async def commit(self) -> None:
            await asyncio.sleep(60)

    @contextlib.asynccontextmanager
    async def _hanging_session_factory():
        s = _HangingCommitSession()
        _SESSIONS_USED.append(s)
        yield s

    async with billable_call(
        user_id=user_id,
        search_space_id=42,
        billing_tier="premium",
        base_model="openai/gpt-image-1",
        quota_reserve_micros_override=50_000,
        usage_type="image_generation",
        billable_session_factory=_hanging_session_factory,
        audit_timeout_seconds=0.01,
    ) as acc:
        acc.add(
            model="openai/gpt-image-1",
            prompt_tokens=0,
            completion_tokens=0,
            total_tokens=0,
            cost_micros=40_000,
            call_kind="image_generation",
        )

    assert len(spies["reserve"]) == 1
    assert len(spies["finalize"]) == 1
    assert len(spies["record"]) == 1
    assert spies["release"] == []
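
# Illustrative sketch of the bounded, best-effort audit write the two tests
# above rely on (an assumption about shape, not the real implementation:
# finalize is authoritative, the audit row is optional and time-boxed).
async def _bounded_audit_sketch(write_audit, timeout_seconds: float) -> None:
    try:
        await asyncio.wait_for(write_audit(), timeout=timeout_seconds)
    except Exception:
        pass  # a hung or failing audit write never propagates to the caller
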
@pytest.mark.asyncio
async def test_free_audit_failure_is_best_effort(monkeypatch):
    from app.services.billable_calls import billable_call

    spies = _patch_isolation_layer(
        monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
    )

    async def _failing_record(_session, **_kwargs):
        raise RuntimeError("audit insert failed")

    monkeypatch.setattr(
        "app.services.billable_calls.record_token_usage",
        _failing_record,
        raising=False,
    )

    async with billable_call(
        user_id=uuid4(),
        search_space_id=42,
        billing_tier="free",
        base_model="openai/gpt-image-1",
        usage_type="image_generation",
        audit_timeout_seconds=0.01,
    ) as acc:
        acc.add(
            model="openai/gpt-image-1",
            prompt_tokens=0,
            completion_tokens=0,
            total_tokens=0,
            cost_micros=37_000,
            call_kind="image_generation",
        )

    assert spies["reserve"] == []
    assert spies["finalize"] == []


# ---------------------------------------------------------------------------
# Podcast / video-presentation usage_type coverage
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_free_podcast_path_audits_with_podcast_usage_type(monkeypatch):
    """Free podcast configs must skip reserve/finalize but still emit a
    ``TokenUsage`` row tagged ``usage_type='podcast_generation'`` so we
    have full audit coverage of free-tier agent runs."""
    from app.services.billable_calls import billable_call

    spies = _patch_isolation_layer(
        monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
    )
    user_id = uuid4()

    async with billable_call(
        user_id=user_id,
        search_space_id=42,
        billing_tier="free",
        base_model="openrouter/some-free-model",
        quota_reserve_micros_override=200_000,
        usage_type="podcast_generation",
        thread_id=99,
        call_details={"podcast_id": 7, "title": "Test Podcast"},
    ) as acc:
        # Two transcript LLM calls aggregated into one accumulator.
        acc.add(
            model="openrouter/some-free-model",
            prompt_tokens=1500,
            completion_tokens=8000,
            total_tokens=9500,
            cost_micros=0,
            call_kind="chat",
        )

    assert spies["reserve"] == []
    assert spies["finalize"] == []
    assert spies["release"] == []

    assert len(spies["record"]) == 1
    row = spies["record"][0]
    assert row["usage_type"] == "podcast_generation"
    assert row["thread_id"] is None
    assert row["search_space_id"] == 42
    assert row["call_details"] == {"podcast_id": 7, "title": "Test Podcast"}


@pytest.mark.asyncio
async def test_premium_video_denial_raises_quota_insufficient(monkeypatch):
    """Premium video-presentation runs that hit a denied reservation must
    raise ``QuotaInsufficientError`` *before* the graph runs and must not
    emit an audit row (no work happened)."""
    from app.services.billable_calls import (
        QuotaInsufficientError,
        billable_call,
    )

    spies = _patch_isolation_layer(
        monkeypatch,
        reserve_result=_FakeQuotaResult(
            allowed=False, used=4_500_000, limit=5_000_000, remaining=500_000
        ),
    )
    user_id = uuid4()

    with pytest.raises(QuotaInsufficientError) as exc_info:
        async with billable_call(
            user_id=user_id,
            search_space_id=42,
            billing_tier="premium",
            base_model="gpt-5.4",
            quota_reserve_micros_override=1_000_000,
            usage_type="video_presentation_generation",
            thread_id=99,
            call_details={"video_presentation_id": 12, "title": "Test Video"},
        ):
            pytest.fail("body should not run when reserve is denied")

    err = exc_info.value
    assert err.usage_type == "video_presentation_generation"
    assert err.remaining_micros == 500_000
    assert spies["reserve"][0]["reserve_micros"] == 1_000_000
    assert spies["finalize"] == []
    assert spies["release"] == []
    assert spies["record"] == []
@@ -0,0 +1,177 @@
"""Defense-in-depth: image-gen call sites must not let an empty
``api_base`` fall through to LiteLLM's module-global ``litellm.api_base``.

The bug repro: an OpenRouter image-gen config ships
``api_base=""``. The pre-fix call site in
``image_generation_routes._execute_image_generation`` did
``if cfg.get("api_base"): kwargs["api_base"] = cfg["api_base"]``, which
silently dropped the empty string. LiteLLM then fell back to
``litellm.api_base`` (commonly inherited from ``AZURE_OPENAI_ENDPOINT``)
and OpenRouter's ``image_generation/transformation`` appended
``/chat/completions`` to it → 404 ``Resource not found``.

These tests pin the post-fix behaviour: with an empty ``api_base`` in
the config, the call site MUST set ``api_base`` to OpenRouter's public
URL instead of leaving it unset.
"""

from __future__ import annotations

from unittest.mock import AsyncMock, MagicMock, patch

import pytest

pytestmark = pytest.mark.unit
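
# Illustrative sketch of the guard under test (a hypothetical stand-in for
# the call sites' resolver, not the real helper). The key point: an empty
# string is falsy, which is exactly how the original
# ``if cfg.get("api_base")`` check dropped it.
_OPENROUTER_PUBLIC_API_BASE = "https://openrouter.ai/api/v1"


def _resolved_api_base_sketch(cfg: dict) -> str:
    # Normalise empty/missing api_base explicitly instead of leaving the
    # kwarg unset and letting LiteLLM's module-global ``litellm.api_base``
    # leak in.
    return cfg.get("api_base") or _OPENROUTER_PUBLIC_API_BASE
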
@pytest.mark.asyncio
async def test_global_openrouter_image_gen_sets_api_base_when_config_empty():
    """The global-config branch (``config_id < 0``) of
    ``_execute_image_generation`` must apply the resolver and pin
    ``api_base`` to OpenRouter when the config ships an empty string.
    """
    from app.routes import image_generation_routes

    cfg = {
        "id": -20_001,
        "name": "GPT Image 1 (OpenRouter)",
        "provider": "OPENROUTER",
        "model_name": "openai/gpt-image-1",
        "api_key": "sk-or-test",
        "api_base": "",  # the original bug shape
        "api_version": None,
        "litellm_params": {},
    }

    captured: dict = {}

    async def fake_aimage_generation(**kwargs):
        captured.update(kwargs)
        return MagicMock(model_dump=lambda: {"data": []}, _hidden_params={})

    image_gen = MagicMock()
    image_gen.image_generation_config_id = cfg["id"]
    image_gen.prompt = "test"
    image_gen.n = 1
    image_gen.quality = None
    image_gen.size = None
    image_gen.style = None
    image_gen.response_format = None
    image_gen.model = None

    search_space = MagicMock()
    search_space.image_generation_config_id = cfg["id"]
    session = MagicMock()

    with (
        patch.object(
            image_generation_routes,
            "_get_global_image_gen_config",
            return_value=cfg,
        ),
        patch.object(
            image_generation_routes,
            "aimage_generation",
            side_effect=fake_aimage_generation,
        ),
    ):
        await image_generation_routes._execute_image_generation(
            session=session, image_gen=image_gen, search_space=search_space
        )

    # The whole point of the fix: even with an empty ``api_base`` in the
    # config, we forward OpenRouter's public URL so the call doesn't
    # inherit an Azure endpoint.
    assert captured.get("api_base") == "https://openrouter.ai/api/v1"
    assert captured["model"] == "openrouter/openai/gpt-image-1"


@pytest.mark.asyncio
async def test_generate_image_tool_global_sets_api_base_when_config_empty():
    """Same defense at the agent tool entry point — both surfaces share
    the same OpenRouter config payloads."""
    from app.agents.new_chat.tools import generate_image as gi_module

    cfg = {
        "id": -20_001,
        "name": "GPT Image 1 (OpenRouter)",
        "provider": "OPENROUTER",
        "model_name": "openai/gpt-image-1",
        "api_key": "sk-or-test",
        "api_base": "",
        "api_version": None,
        "litellm_params": {},
    }

    captured: dict = {}

    async def fake_aimage_generation(**kwargs):
        captured.update(kwargs)
        response = MagicMock()
        response.model_dump.return_value = {
            "data": [{"url": "https://example.com/x.png"}]
        }
        response._hidden_params = {"model": "openrouter/openai/gpt-image-1"}
        return response

    search_space = MagicMock()
    search_space.id = 1
    search_space.image_generation_config_id = cfg["id"]

    session_cm = AsyncMock()
    session = AsyncMock()
    session_cm.__aenter__.return_value = session

    scalars = MagicMock()
    scalars.first.return_value = search_space
    exec_result = MagicMock()
    exec_result.scalars.return_value = scalars
    session.execute.return_value = exec_result
    session.add = MagicMock()
    session.commit = AsyncMock()
    session.refresh = AsyncMock()

    # ``refresh(db_image_gen)`` needs to populate ``id`` for the token URL fallback.
    async def _refresh(obj):
        obj.id = 1

    session.refresh.side_effect = _refresh

    with (
        patch.object(gi_module, "shielded_async_session", return_value=session_cm),
        patch.object(gi_module, "_get_global_image_gen_config", return_value=cfg),
        patch.object(
            gi_module, "aimage_generation", side_effect=fake_aimage_generation
        ),
        patch.object(
            gi_module, "is_image_gen_auto_mode", side_effect=lambda cid: cid == 0
        ),
    ):
        tool = gi_module.create_generate_image_tool(
            search_space_id=1, db_session=MagicMock()
        )
        await tool.ainvoke({"prompt": "a cat", "n": 1})

    assert captured.get("api_base") == "https://openrouter.ai/api/v1"
    assert captured["model"] == "openrouter/openai/gpt-image-1"


def test_image_gen_router_deployment_sets_api_base_when_config_empty():
    """The Auto-mode router pool must also resolve ``api_base`` when an
    OpenRouter config ships an empty string. The deployment dict is fed
    straight to ``litellm.Router``, so a missing ``api_base`` would
    leak the same way as the direct call sites.
    """
    from app.services.image_gen_router_service import ImageGenRouterService

    deployment = ImageGenRouterService._config_to_deployment(
        {
            "model_name": "openai/gpt-image-1",
            "provider": "OPENROUTER",
            "api_key": "sk-or-test",
            "api_base": "",
        }
    )
    assert deployment is not None
    assert deployment["litellm_params"]["api_base"] == "https://openrouter.ai/api/v1"
    assert deployment["litellm_params"]["model"] == "openrouter/openai/gpt-image-1"
@@ -0,0 +1,226 @@
"""LLMRouterService pool-filter / rebuild tests.

These tests focus on the *config plumbing*: which configs enter the router
pool, and whether ``rebuild`` resets state correctly. They stub out the
underlying ``litellm.Router`` so we don't need real API keys or network
access.
"""

from __future__ import annotations

from unittest.mock import patch

import pytest

from app.services.llm_router_service import LLMRouterService

pytestmark = pytest.mark.unit
def _fake_yaml_config(
    *,
    id: int,
    model_name: str,
    billing_tier: str = "free",
) -> dict:
    return {
        "id": id,
        "name": f"yaml-{id}",
        "provider": "OPENAI",
        "model_name": model_name,
        "api_key": "sk-test",
        "api_base": "",
        "billing_tier": billing_tier,
        "rpm": 100,
        "tpm": 100_000,
        "litellm_params": {},
    }


def _fake_openrouter_config(
    *,
    id: int,
    model_name: str,
    billing_tier: str,
    router_pool_eligible: bool | None = None,
) -> dict:
    """Build a synthetic dynamic-OR config dict for router-pool tests.

    Defaults mirror Strategy 3: premium OR enters the pool, free OR stays
    out. Callers can override ``router_pool_eligible`` to simulate legacy
    configs or to regression-test the filter mechanics directly.
    """
    if router_pool_eligible is None:
        router_pool_eligible = billing_tier == "premium"
    return {
        "id": id,
        "name": f"or-{id}",
        "provider": "OPENROUTER",
        "model_name": model_name,
        "api_key": "sk-or-test",
        "api_base": "",
        "billing_tier": billing_tier,
        "rpm": 20 if billing_tier == "free" else 200,
        "tpm": 100_000 if billing_tier == "free" else 1_000_000,
        "litellm_params": {},
        "router_pool_eligible": router_pool_eligible,
    }


def _reset_router_singleton() -> None:
    instance = LLMRouterService.get_instance()
    instance._initialized = False
    instance._router = None
    instance._model_list = []
    instance._premium_model_strings = set()
def test_router_pool_includes_or_premium_excludes_or_free():
    """Strategy 3: premium OR joins the pool, free OR stays out.

    Dynamic OpenRouter premium entries opt into load balancing alongside
    curated YAML configs. Dynamic OR free entries are intentionally kept
    out because OpenRouter's free tier enforces a single account-global
    quota bucket that per-deployment router accounting can't represent.
    """
    _reset_router_singleton()
    configs = [
        _fake_yaml_config(id=-1, model_name="gpt-4o", billing_tier="premium"),
        _fake_yaml_config(id=-2, model_name="gpt-4o-mini", billing_tier="free"),
        _fake_openrouter_config(
            id=-10_001, model_name="openai/gpt-4o", billing_tier="premium"
        ),
        _fake_openrouter_config(
            id=-10_002,
            model_name="meta-llama/llama-3.3-70b:free",
            billing_tier="free",
        ),
    ]

    with (
        patch("app.services.llm_router_service.Router") as mock_router,
        patch(
            "app.services.llm_router_service.LLMRouterService._build_context_fallback_groups"
        ) as mock_ctx_fb,
    ):
        mock_ctx_fb.side_effect = lambda ml: (ml, None)
        mock_router.return_value = object()
        LLMRouterService.initialize(configs)

    pool_models = {
        dep["litellm_params"]["model"]
        for dep in LLMRouterService.get_instance()._model_list
    }
    # YAML premium + YAML free + dynamic OR premium are all in the pool.
    # Dynamic OR free is NOT (shared-bucket rate limits can't be load-balanced).
    assert pool_models == {
        "openai/gpt-4o",
        "openai/gpt-4o-mini",
        "openrouter/openai/gpt-4o",
    }

    prem = LLMRouterService.get_instance()._premium_model_strings
    # YAML premium is fingerprinted under both its model_string and its
    # ``base_model`` form (existing behavior we don't want to regress).
    assert "openai/gpt-4o" in prem
    # Dynamic OR premium is now fingerprinted as premium so pool-level
    # calls through the router are billed against premium quota.
    assert "openrouter/openai/gpt-4o" in prem
    assert LLMRouterService.is_premium_model("openrouter/openai/gpt-4o") is True
    # Dynamic OR free never enters the pool, so it's never counted as premium.
    assert (
        LLMRouterService.is_premium_model("openrouter/meta-llama/llama-3.3-70b:free")
        is False
    )


def test_router_pool_filter_mechanics_respect_override():
    """The ``router_pool_eligible`` filter itself works independently of tier.

    Regression guard: if a future refactor ever sets the flag False on a
    premium config (e.g. for maintenance), that config MUST be skipped by
    ``initialize`` even though its tier is premium.
    """
    _reset_router_singleton()
    configs = [
        _fake_yaml_config(id=-1, model_name="gpt-4o", billing_tier="premium"),
        _fake_openrouter_config(
            id=-10_001,
            model_name="openai/gpt-4o",
            billing_tier="premium",
            router_pool_eligible=False,  # opt out despite being premium
        ),
    ]

    with (
        patch("app.services.llm_router_service.Router") as mock_router,
        patch(
            "app.services.llm_router_service.LLMRouterService._build_context_fallback_groups"
        ) as mock_ctx_fb,
    ):
        mock_ctx_fb.side_effect = lambda ml: (ml, None)
        mock_router.return_value = object()
        LLMRouterService.initialize(configs)

    pool_models = {
        dep["litellm_params"]["model"]
        for dep in LLMRouterService.get_instance()._model_list
    }
    assert pool_models == {"openai/gpt-4o"}
    assert LLMRouterService.is_premium_model("openrouter/openai/gpt-4o") is False
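
# Illustrative sketch of the pool filter the two tests above pin down (an
# assumption about shape: YAML configs carry no flag and default to
# eligible; the real check lives inside ``LLMRouterService.initialize``).
def _pool_eligible_sketch(cfg: dict) -> bool:
    return bool(cfg.get("router_pool_eligible", True))
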
def test_rebuild_refreshes_pool_after_configs_change():
    _reset_router_singleton()
    configs_v1 = [
        _fake_yaml_config(id=-1, model_name="gpt-4o", billing_tier="premium"),
    ]
    configs_v2 = [
        *configs_v1,
        _fake_yaml_config(id=-2, model_name="gpt-4o-mini", billing_tier="free"),
    ]

    with (
        patch("app.services.llm_router_service.Router") as mock_router,
        patch(
            "app.services.llm_router_service.LLMRouterService._build_context_fallback_groups"
        ) as mock_ctx_fb,
    ):
        mock_ctx_fb.side_effect = lambda ml: (ml, None)
        mock_router.return_value = object()

        LLMRouterService.initialize(configs_v1)
        assert len(LLMRouterService.get_instance()._model_list) == 1

        # ``initialize`` should be a no-op here (already initialized).
        LLMRouterService.initialize(configs_v2)
        assert len(LLMRouterService.get_instance()._model_list) == 1

        # ``rebuild`` must clear the guard and re-run with the new configs.
        LLMRouterService.rebuild(configs_v2)
        assert len(LLMRouterService.get_instance()._model_list) == 2


def test_auto_model_pin_candidates_include_dynamic_openrouter():
    """Dynamic OR configs must remain Auto-mode thread-pin candidates.

    Guards against a future regression where someone adds the
    ``router_pool_eligible`` filter to ``auto_model_pin_service._global_candidates``.
    """
    from app.config import config
    from app.services.auto_model_pin_service import _global_candidates

    or_premium = _fake_openrouter_config(
        id=-10_001, model_name="openai/gpt-4o", billing_tier="premium"
    )
    or_free = _fake_openrouter_config(
        id=-10_002,
        model_name="meta-llama/llama-3.3-70b:free",
        billing_tier="free",
    )
    original = config.GLOBAL_LLM_CONFIGS
    try:
        config.GLOBAL_LLM_CONFIGS = [or_premium, or_free]
        candidate_ids = {c["id"] for c in _global_candidates()}
        assert candidate_ids == {-10_001, -10_002}
    finally:
        config.GLOBAL_LLM_CONFIGS = original
@@ -0,0 +1,380 @@
"""Unit tests for the dynamic OpenRouter integration."""

from __future__ import annotations

import pytest

from app.services.openrouter_integration_service import (
    _OPENROUTER_DYNAMIC_MARKER,
    _generate_configs,
    _openrouter_tier,
    _stable_config_id,
)

pytestmark = pytest.mark.unit
def _minimal_openrouter_model(
    *,
    model_id: str,
    pricing: dict | None = None,
    name: str | None = None,
) -> dict:
    """Return a synthetic OpenRouter /api/v1/models entry.

    The real API payload includes a lot of fields; we only populate what
    ``_generate_configs`` actually inspects (architecture, tool support,
    context, pricing, id).
    """
    return {
        "id": model_id,
        "name": name or model_id,
        "architecture": {"output_modalities": ["text"]},
        "supported_parameters": ["tools"],
        "context_length": 200_000,
        "pricing": pricing or {"prompt": "0.000003", "completion": "0.000015"},
    }


# ---------------------------------------------------------------------------
# _openrouter_tier
# ---------------------------------------------------------------------------


def test_openrouter_tier_free_suffix():
    assert _openrouter_tier({"id": "foo/bar:free"}) == "free"


def test_openrouter_tier_zero_pricing():
    model = {
        "id": "foo/bar",
        "pricing": {"prompt": "0", "completion": "0"},
    }
    assert _openrouter_tier(model) == "free"


def test_openrouter_tier_paid():
    model = {
        "id": "foo/bar",
        "pricing": {"prompt": "0.000003", "completion": "0.000015"},
    }
    assert _openrouter_tier(model) == "premium"


def test_openrouter_tier_missing_pricing_is_premium():
    assert _openrouter_tier({"id": "foo/bar"}) == "premium"
    assert _openrouter_tier({"id": "foo/bar", "pricing": {}}) == "premium"
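
# Illustrative sketch of the tier rule the tests above exercise (an
# assumption mirroring ``_openrouter_tier``'s observable behaviour:
# missing or unparsable pricing counts as paid).
def _tier_sketch(model: dict) -> str:
    if str(model.get("id", "")).endswith(":free"):
        return "free"
    pricing = model.get("pricing") or {}
    try:
        prompt = float(pricing.get("prompt", ""))
        completion = float(pricing.get("completion", ""))
    except (TypeError, ValueError):
        return "premium"
    return "free" if prompt == 0.0 and completion == 0.0 else "premium"
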
# ---------------------------------------------------------------------------
# _stable_config_id
# ---------------------------------------------------------------------------


def test_stable_config_id_deterministic():
    taken1: set[int] = set()
    taken2: set[int] = set()
    a = _stable_config_id("openai/gpt-4o", -10_000, taken1)
    b = _stable_config_id("openai/gpt-4o", -10_000, taken2)
    assert a == b
    assert a < 0


def test_stable_config_id_collision_decrements():
    """When two model_ids hash to the same slot, the second should decrement."""
    taken: set[int] = set()
    a = _stable_config_id("openai/gpt-4o", -10_000, taken)
    # Force a collision by pre-populating ``taken`` with a slot we know
    # will be picked.
    taken_forced = {a}
    b = _stable_config_id("openai/gpt-4o", -10_000, taken_forced)
    assert b != a
    assert b == a - 1
    assert b in taken_forced


def test_stable_config_id_different_models_different_ids():
    taken: set[int] = set()
    ids = {
        _stable_config_id("openai/gpt-4o", -10_000, taken),
        _stable_config_id("anthropic/claude-3.5-sonnet", -10_000, taken),
        _stable_config_id("google/gemini-2.0-flash", -10_000, taken),
    }
    assert len(ids) == 3


def test_stable_config_id_survives_catalogue_churn():
    """Removing a model should not shift other models' IDs (the bug we fix)."""
    taken1: set[int] = set()
    id_a1 = _stable_config_id("openai/gpt-4o", -10_000, taken1)
    _ = _stable_config_id("anthropic/claude-3-haiku", -10_000, taken1)
    id_c1 = _stable_config_id("google/gemini-2.0-flash", -10_000, taken1)

    taken2: set[int] = set()
    id_a2 = _stable_config_id("openai/gpt-4o", -10_000, taken2)
    id_c2 = _stable_config_id("google/gemini-2.0-flash", -10_000, taken2)

    assert id_a1 == id_a2
    assert id_c1 == id_c2
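
# Illustrative sketch of a stable negative-ID scheme consistent with the
# behaviour pinned above (deterministic hash slot, decrement-on-collision).
# The real ``_stable_config_id`` may differ in hash choice and slot range.
import hashlib


def _stable_id_sketch(model_id: str, offset: int, taken: set[int]) -> int:
    digest = int.from_bytes(hashlib.sha256(model_id.encode()).digest()[:4], "big")
    candidate = offset - (digest % 1_000_000)
    while candidate in taken:
        candidate -= 1  # collisions walk downward, leaving earlier IDs stable
    taken.add(candidate)
    return candidate
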
# ---------------------------------------------------------------------------
# _generate_configs
# ---------------------------------------------------------------------------


_SETTINGS_BASE: dict = {
    "api_key": "sk-or-test",
    "id_offset": -10_000,
    "rpm": 200,
    "tpm": 1_000_000,
    "free_rpm": 20,
    "free_tpm": 100_000,
    "anonymous_enabled_paid": False,
    "anonymous_enabled_free": True,
    "quota_reserve_tokens": 4000,
}


def test_generate_configs_respects_tier():
    """Premium OR models opt into the router pool; free OR models stay out.

    Strategy-3 split: premium participates in LiteLLM Router load balancing,
    free stays excluded because OpenRouter enforces a shared global free-tier
    bucket that per-deployment router accounting can't represent.
    """
    raw = [
        _minimal_openrouter_model(model_id="openai/gpt-4o"),
        _minimal_openrouter_model(
            model_id="meta-llama/llama-3.3-70b-instruct:free",
            pricing={"prompt": "0", "completion": "0"},
        ),
    ]
    cfgs = _generate_configs(raw, dict(_SETTINGS_BASE))
    by_model = {c["model_name"]: c for c in cfgs}

    paid = by_model["openai/gpt-4o"]
    assert paid["billing_tier"] == "premium"
    assert paid["rpm"] == 200
    assert paid["tpm"] == 1_000_000
    assert paid["anonymous_enabled"] is False
    assert paid["router_pool_eligible"] is True
    assert paid[_OPENROUTER_DYNAMIC_MARKER] is True

    free = by_model["meta-llama/llama-3.3-70b-instruct:free"]
    assert free["billing_tier"] == "free"
    assert free["rpm"] == 20
    assert free["tpm"] == 100_000
    assert free["anonymous_enabled"] is True
    assert free["router_pool_eligible"] is False


def test_generate_configs_excludes_upstream_openrouter_free_router():
    """OpenRouter's own ``openrouter/free`` meta-router must never become a card.

    The upstream API returns this as a first-class zero-priced model, so
    without an explicit blocklist entry it would slip through every other
    filter (text output, tool calling, 200k context, non-Amazon) and land
    in the selector as a duplicate of the concrete ``:free`` cards. The
    exclusion in ``_EXCLUDED_MODEL_IDS`` prevents that.
    """
    raw = [
        _minimal_openrouter_model(model_id="openai/gpt-4o"),
        _minimal_openrouter_model(
            model_id="openrouter/free",
            pricing={"prompt": "0", "completion": "0"},
        ),
    ]
    cfgs = _generate_configs(raw, dict(_SETTINGS_BASE))
    model_names = {c["model_name"] for c in cfgs}
    assert "openrouter/free" not in model_names
    assert "openai/gpt-4o" in model_names


def test_generate_configs_drops_non_text_and_non_tool_models():
    raw = [
        _minimal_openrouter_model(model_id="openai/gpt-4o"),
        {  # image-output model
            "id": "openai/dall-e",
            "architecture": {"output_modalities": ["image"]},
            "supported_parameters": ["tools"],
            "context_length": 200_000,
            "pricing": {"prompt": "0.01", "completion": "0.01"},
        },
        {  # text but no tool calling
            "id": "openai/completion-only",
            "architecture": {"output_modalities": ["text"]},
            "supported_parameters": [],
            "context_length": 200_000,
            "pricing": {"prompt": "0.01", "completion": "0.01"},
        },
    ]
    cfgs = _generate_configs(raw, dict(_SETTINGS_BASE))
    model_names = [c["model_name"] for c in cfgs]
    assert "openai/gpt-4o" in model_names
    assert "openai/dall-e" not in model_names
    assert "openai/completion-only" not in model_names
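
# Illustrative sketch of the chat-emission gate the three tests above cover
# (an assumption simplifying ``_generate_configs``; the real filter also
# enforces a minimum context length and an explicit blocklist):
def _chat_eligible_sketch(model: dict) -> bool:
    arch = model.get("architecture") or {}
    return (
        model.get("id") != "openrouter/free"  # upstream meta-router, blocklisted
        and "text" in (arch.get("output_modalities") or [])
        and "tools" in (model.get("supported_parameters") or [])
    )
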
# ---------------------------------------------------------------------------
# _generate_image_gen_configs / _generate_vision_llm_configs
# ---------------------------------------------------------------------------


def test_generate_image_gen_configs_filters_by_image_output():
    """Only models with ``output_modalities`` containing ``image`` are emitted.

    Tool-calling and context filters are intentionally NOT applied — image
    generation has nothing to do with tool calls and context windows.
    """
    from app.services.openrouter_integration_service import (
        _generate_image_gen_configs,
    )

    raw = [
        # Pure image-gen model (small context, no tools — should still emit).
        {
            "id": "openai/gpt-image-1",
            "architecture": {"output_modalities": ["image"]},
            "context_length": 4_000,
            "pricing": {"prompt": "0", "completion": "0"},
        },
        # Multi-modal: text+image output (should still emit).
        {
            "id": "google/gemini-2.5-flash-image",
            "architecture": {"output_modalities": ["text", "image"]},
            "context_length": 1_000_000,
            "pricing": {"prompt": "0.000001", "completion": "0.000004"},
        },
        # Pure text model — must NOT emit.
        {
            "id": "openai/gpt-4o",
            "architecture": {"output_modalities": ["text"]},
            "context_length": 128_000,
            "pricing": {"prompt": "0.000005", "completion": "0.000015"},
        },
    ]

    cfgs = _generate_image_gen_configs(raw, dict(_SETTINGS_BASE))
    model_names = {c["model_name"] for c in cfgs}
    assert "openai/gpt-image-1" in model_names
    assert "google/gemini-2.5-flash-image" in model_names
    assert "openai/gpt-4o" not in model_names

    # Each config must carry ``billing_tier`` for routing in image_generation_routes.
    for c in cfgs:
        assert c["billing_tier"] in {"free", "premium"}
        assert c["provider"] == "OPENROUTER"
        assert c[_OPENROUTER_DYNAMIC_MARKER] is True
        # Defense-in-depth: emit the OpenRouter base URL at source so a
        # downstream call site that forgets ``resolve_api_base`` still
        # doesn't 404 against an inherited Azure endpoint.
        assert c["api_base"] == "https://openrouter.ai/api/v1"


def test_generate_image_gen_configs_assigns_image_id_offset():
    """Image configs use a different id_offset (-20000) so their negative
    IDs don't collide with chat configs (-10000) or vision configs (-30000).
    """
    from app.services.openrouter_integration_service import (
        _generate_image_gen_configs,
    )

    raw = [
        {
            "id": "openai/gpt-image-1",
            "architecture": {"output_modalities": ["image"]},
            "context_length": 4_000,
            "pricing": {"prompt": "0", "completion": "0"},
        }
    ]
    # Don't pass image_id_offset → use the module default (-20000).
    cfgs = _generate_image_gen_configs(raw, dict(_SETTINGS_BASE))
    assert all(c["id"] <= -20_000 for c in cfgs)
    assert all(c["id"] > -29_000_000 for c in cfgs)
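
# Illustrative sketch of the image-gen gate the two tests above pin (an
# assumption: emission keys purely on output modalities, with no tool or
# context requirements):
def _image_gen_eligible_sketch(model: dict) -> bool:
    arch = model.get("architecture") or {}
    return "image" in (arch.get("output_modalities") or [])
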
def test_generate_vision_llm_configs_filters_by_image_input_text_output():
    """Vision LLMs must accept image input AND emit text — pure image-gen
    (no text out) and text-only (no image in) models are excluded.
    """
    from app.services.openrouter_integration_service import (
        _generate_vision_llm_configs,
    )

    raw = [
        # GPT-4o: vision LLM (image in, text out) — must emit.
        {
            "id": "openai/gpt-4o",
            "architecture": {
                "input_modalities": ["text", "image"],
                "output_modalities": ["text"],
            },
            "context_length": 128_000,
            "pricing": {"prompt": "0.000005", "completion": "0.000015"},
        },
        # Pure image generator — image *output*, no text out. Must NOT emit.
        {
            "id": "openai/gpt-image-1",
            "architecture": {
                "input_modalities": ["text"],
                "output_modalities": ["image"],
            },
            "context_length": 4_000,
            "pricing": {"prompt": "0", "completion": "0"},
        },
        # Pure text model (no image in). Must NOT emit.
        {
            "id": "anthropic/claude-3-haiku",
            "architecture": {
                "input_modalities": ["text"],
                "output_modalities": ["text"],
            },
            "context_length": 200_000,
            "pricing": {"prompt": "0.000001", "completion": "0.000005"},
        },
    ]

    cfgs = _generate_vision_llm_configs(raw, dict(_SETTINGS_BASE))
    names = {c["model_name"] for c in cfgs}
    assert names == {"openai/gpt-4o"}

    cfg = cfgs[0]
    assert cfg["billing_tier"] == "premium"
    # Pricing carried inline so pricing_registration can register vision
    # under ``openrouter/openai/gpt-4o`` even if the chat catalogue cache
    # is cleared.
    assert cfg["input_cost_per_token"] == pytest.approx(5e-6)
    assert cfg["output_cost_per_token"] == pytest.approx(15e-6)
    assert cfg[_OPENROUTER_DYNAMIC_MARKER] is True
    # Defense-in-depth: emit the OpenRouter base URL at source so a
    # downstream call site that forgets ``resolve_api_base`` still
    # doesn't inherit an Azure endpoint.
    assert cfg["api_base"] == "https://openrouter.ai/api/v1"


def test_generate_vision_llm_configs_drops_chat_only_filters():
    """A small-context vision model that doesn't advertise tool calling is
    still a valid vision LLM for "describe this image" prompts. The chat
    filters (``supports_tool_calling``, ``has_sufficient_context``) must
    NOT be applied to vision emission.
    """
    from app.services.openrouter_integration_service import (
        _generate_vision_llm_configs,
    )

    raw = [
        {
            "id": "tiny/vision-mini",
            "architecture": {
                "input_modalities": ["text", "image"],
                "output_modalities": ["text"],
            },
            "supported_parameters": [],  # no tools
            "context_length": 4_000,  # well below MIN_CONTEXT_LENGTH
            "pricing": {"prompt": "0.0000001", "completion": "0.0000005"},
        }
    ]

    cfgs = _generate_vision_llm_configs(raw, dict(_SETTINGS_BASE))
    assert len(cfgs) == 1
    assert cfgs[0]["model_name"] == "tiny/vision-mini"
@@ -0,0 +1,108 @@
"""Tests for deprecated-key warnings and back-compat in
``load_openrouter_integration_settings``.
"""

from __future__ import annotations

from pathlib import Path

import pytest

pytestmark = pytest.mark.unit
def _write_yaml(tmp_path: Path, body: str) -> Path:
    cfg_dir = tmp_path / "app" / "config"
    cfg_dir.mkdir(parents=True)
    cfg_path = cfg_dir / "global_llm_config.yaml"
    cfg_path.write_text(body, encoding="utf-8")
    return cfg_path


def _patch_base_dir(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
    from app import config as config_module

    monkeypatch.setattr(config_module, "BASE_DIR", tmp_path)


def test_legacy_billing_tier_emits_warning(monkeypatch, tmp_path, capsys):
    _write_yaml(
        tmp_path,
        """
openrouter_integration:
  enabled: true
  api_key: "sk-or-test"
  billing_tier: "premium"
""".lstrip(),
    )
    _patch_base_dir(monkeypatch, tmp_path)

    from app.config import load_openrouter_integration_settings

    settings = load_openrouter_integration_settings()
    captured = capsys.readouterr().out
    assert settings is not None
    assert "billing_tier is deprecated" in captured


def test_legacy_anonymous_enabled_back_compat(monkeypatch, tmp_path, capsys):
    _write_yaml(
        tmp_path,
        """
openrouter_integration:
  enabled: true
  api_key: "sk-or-test"
  anonymous_enabled: true
""".lstrip(),
    )
    _patch_base_dir(monkeypatch, tmp_path)

    from app.config import load_openrouter_integration_settings

    settings = load_openrouter_integration_settings()
    captured = capsys.readouterr().out
    assert settings is not None
    assert settings["anonymous_enabled_paid"] is True
    assert settings["anonymous_enabled_free"] is True
    assert "anonymous_enabled is" in captured
    assert "deprecated" in captured


def test_new_keys_take_priority_over_legacy_back_compat(monkeypatch, tmp_path, capsys):
    """If both legacy and new keys are present, new keys win (setdefault)."""
    _write_yaml(
        tmp_path,
        """
openrouter_integration:
  enabled: true
  api_key: "sk-or-test"
  anonymous_enabled: true
  anonymous_enabled_paid: false
  anonymous_enabled_free: false
""".lstrip(),
    )
    _patch_base_dir(monkeypatch, tmp_path)

    from app.config import load_openrouter_integration_settings

    settings = load_openrouter_integration_settings()
    capsys.readouterr()
    assert settings is not None
    assert settings["anonymous_enabled_paid"] is False
    assert settings["anonymous_enabled_free"] is False
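
# Illustrative sketch of the setdefault back-compat the test above describes
# (an assumption about shape: the legacy key only fills new keys that are
# absent, so explicit new keys always win):
def _apply_legacy_anonymous_enabled_sketch(settings: dict) -> dict:
    if "anonymous_enabled" in settings:
        legacy = bool(settings.pop("anonymous_enabled"))
        settings.setdefault("anonymous_enabled_paid", legacy)
        settings.setdefault("anonymous_enabled_free", legacy)
    return settings
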
def test_disabled_integration_returns_none(monkeypatch, tmp_path):
    _write_yaml(
        tmp_path,
        """
openrouter_integration:
  enabled: false
  api_key: "sk-or-test"
""".lstrip(),
    )
    _patch_base_dir(monkeypatch, tmp_path)

    from app.config import load_openrouter_integration_settings

    assert load_openrouter_integration_settings() is None
@@ -0,0 +1,331 @@
"""Unit tests for the OpenRouter ``_enrich_health`` background task."""

from __future__ import annotations

from typing import Any

import pytest

from app.services.openrouter_integration_service import (
    OpenRouterIntegrationService,
)
from app.services.quality_score import (
    _HEALTH_FAIL_RATIO_FALLBACK,
)

# The async tests below carry no per-test ``@pytest.mark.asyncio`` decorator,
# so mark the whole module (harmless if asyncio_mode is already "auto").
pytestmark = [pytest.mark.unit, pytest.mark.asyncio]
def _or_cfg(
    *,
    cid: int,
    model_name: str,
    tier: str = "premium",
    static_score: int = 50,
) -> dict:
    return {
        "id": cid,
        "provider": "OPENROUTER",
        "model_name": model_name,
        "billing_tier": tier,
        "auto_pin_tier": "B" if tier == "premium" else "C",
        "quality_score_static": static_score,
        "quality_score_health": None,
        "quality_score": static_score,
        "health_gated": False,
    }


class _StubResponse:
    def __init__(self, *, payload: dict, status_code: int = 200):
        self._payload = payload
        self.status_code = status_code

    def raise_for_status(self) -> None:
        if self.status_code >= 400:
            raise RuntimeError(f"HTTP {self.status_code}")

    def json(self) -> dict:
        return self._payload


class _StubAsyncClient:
    """Minimal drop-in for ``httpx.AsyncClient`` used by ``_fetch_endpoints``."""

    def __init__(self, responder):
        self._responder = responder
        self.requests: list[str] = []

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        return False

    async def get(self, url: str, headers: dict | None = None) -> _StubResponse:
        self.requests.append(url)
        return self._responder(url)


def _patch_async_client(monkeypatch, responder) -> _StubAsyncClient:
    """Replace ``httpx.AsyncClient`` for the duration of the test."""
    client = _StubAsyncClient(responder)
    monkeypatch.setattr(
        "app.services.openrouter_integration_service.httpx.AsyncClient",
        lambda *_args, **_kwargs: client,
    )
    return client


def _healthy_payload() -> dict:
    return {
        "data": {
            "endpoints": [
                {
                    "status": 0,
                    "uptime_last_30m": 0.99,
                    "uptime_last_1d": 0.995,
                    "uptime_last_5m": 0.99,
                }
            ]
        }
    }


def _unhealthy_payload() -> dict:
    return {
        "data": {
            "endpoints": [
                {
                    "status": 0,
                    "uptime_last_30m": 0.55,
                    "uptime_last_1d": 0.62,
                    "uptime_last_5m": 0.50,
                }
            ]
        }
    }
# ---------------------------------------------------------------------------
# Bounded fan-out + happy path
# ---------------------------------------------------------------------------


async def test_enrich_health_marks_healthy_and_gates_unhealthy(monkeypatch):
    cfgs = [
        _or_cfg(cid=-1, model_name="anthropic/claude-haiku", static_score=70),
        _or_cfg(cid=-2, model_name="venice/dead-model", static_score=60),
    ]

    def responder(url: str) -> _StubResponse:
        if "anthropic" in url:
            return _StubResponse(payload=_healthy_payload())
        return _StubResponse(payload=_unhealthy_payload())

    _patch_async_client(monkeypatch, responder)

    service = OpenRouterIntegrationService()
    service._settings = {"api_key": ""}
    await service._enrich_health(cfgs)

    healthy = next(c for c in cfgs if c["id"] == -1)
    gated = next(c for c in cfgs if c["id"] == -2)

    assert healthy["health_gated"] is False
    assert healthy["quality_score_health"] is not None
    assert healthy["quality_score"] >= healthy["quality_score_static"]

    assert gated["health_gated"] is True
    assert gated["quality_score"] == gated["quality_score_static"]


async def test_enrich_health_only_touches_or_provider(monkeypatch):
    """YAML cfgs that aren't OPENROUTER must be skipped entirely."""
    yaml_cfg = {
        "id": -1,
        "provider": "AZURE_OPENAI",
        "model_name": "gpt-5",
        "billing_tier": "premium",
        "auto_pin_tier": "A",
        "quality_score_static": 80,
        "quality_score": 80,
        "health_gated": False,
    }
    or_cfg = _or_cfg(cid=-2, model_name="anthropic/claude-haiku")

    requests: list[str] = []

    def responder(url: str) -> _StubResponse:
        requests.append(url)
        return _StubResponse(payload=_healthy_payload())

    _patch_async_client(monkeypatch, responder)

    service = OpenRouterIntegrationService()
    service._settings = {}
    await service._enrich_health([yaml_cfg, or_cfg])

    assert all("anthropic/claude-haiku" in r for r in requests)
    # The YAML cfg is untouched.
    assert yaml_cfg["quality_score"] == 80
    assert yaml_cfg["health_gated"] is False
# ---------------------------------------------------------------------------
# Failure ratio fallback
# ---------------------------------------------------------------------------


async def test_enrich_health_falls_back_to_last_good_when_failure_ratio_high(
    monkeypatch,
):
    """If >= 25% of fetches fail, keep the last-good cache instead of
    writing partial data."""
    cfgs = [
        _or_cfg(cid=-1, model_name="anthropic/claude-haiku", static_score=70),
        _or_cfg(cid=-2, model_name="openai/gpt-5", static_score=80),
        _or_cfg(cid=-3, model_name="google/gemini-flash", static_score=65),
        _or_cfg(cid=-4, model_name="venice/something", static_score=50),
    ]

    service = OpenRouterIntegrationService()
    service._settings = {}
    # Pre-seed the last-good cache with a known-healthy snapshot.
    service._health_cache = {
        "anthropic/claude-haiku": {"gated": False, "score": 95.0},
    }

    def all_fail(_url: str) -> _StubResponse:
        return _StubResponse(payload={}, status_code=500)

    _patch_async_client(monkeypatch, all_fail)
    await service._enrich_health(cfgs)

    # Above threshold ⇒ degraded; the last-good cache wins for the cached cfg.
    cached_hit = next(c for c in cfgs if c["model_name"] == "anthropic/claude-haiku")
    assert cached_hit["quality_score_health"] == 95.0
    assert cached_hit["health_gated"] is False
    # Confirm the threshold constant we're testing against is real.
    assert _HEALTH_FAIL_RATIO_FALLBACK <= 1.0


async def test_enrich_health_keeps_static_only_with_no_cache_and_failures(
    monkeypatch,
):
    """If a fetch fails and there's no last-good cache, the cfg keeps its
    static-only ``quality_score`` and is *not* gated by default."""
    cfgs = [
        _or_cfg(cid=-1, model_name="anthropic/claude-haiku", static_score=70),
    ]

    def fail(_url: str) -> _StubResponse:
        return _StubResponse(payload={}, status_code=500)

    _patch_async_client(monkeypatch, fail)

    service = OpenRouterIntegrationService()
    service._settings = {}
    await service._enrich_health(cfgs)

    cfg = cfgs[0]
    assert cfg["health_gated"] is False
    assert cfg["quality_score"] == cfg["quality_score_static"]
    assert cfg["quality_score_health"] is None
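
# Illustrative sketch of the degraded-cycle decision the two tests above pin
# (an assumption: the real task compares its failure ratio against
# ``_HEALTH_FAIL_RATIO_FALLBACK`` and reuses the last-good cache when over it):
def _cycle_degraded_sketch(failures: int, attempts: int, threshold: float) -> bool:
    return attempts > 0 and failures / attempts >= threshold
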
# ---------------------------------------------------------------------------
# Last-good cache: success populates, next failure reuses
# ---------------------------------------------------------------------------


async def test_enrich_health_populates_cache_on_success_then_reuses_on_failure(
    monkeypatch,
):
    cfg = _or_cfg(cid=-1, model_name="anthropic/claude-haiku", static_score=70)

    service = OpenRouterIntegrationService()
    service._settings = {}

    def healthy(_url: str) -> _StubResponse:
        return _StubResponse(payload=_healthy_payload())

    _patch_async_client(monkeypatch, healthy)
    await service._enrich_health([cfg])

    assert "anthropic/claude-haiku" in service._health_cache
    cached_score = service._health_cache["anthropic/claude-haiku"]["score"]
    assert cached_score is not None

    # Next cycle: enough other healthy cfgs so the failure ratio stays below
    # the 25% threshold even when this one fails individually.
    other_cfgs = [
        _or_cfg(cid=-2 - i, model_name=f"healthy/m-{i}", static_score=60)
        for i in range(10)
    ]
    cfg["quality_score_health"] = None
    cfg["quality_score"] = cfg["quality_score_static"]

    def mixed(url: str) -> _StubResponse:
        if "anthropic" in url:
            return _StubResponse(payload={}, status_code=500)
        return _StubResponse(payload=_healthy_payload())

    _patch_async_client(monkeypatch, mixed)
    await service._enrich_health([cfg, *other_cfgs])

    assert cfg["quality_score_health"] == cached_score
    assert cfg["health_gated"] is False


# ---------------------------------------------------------------------------
# Bounded fan-out: respects top-N caps
# ---------------------------------------------------------------------------


async def test_enrich_health_bounds_premium_fanout(monkeypatch):
    """The top-N premium cap is honoured even when many cfgs are present."""
    from app.services.quality_score import _HEALTH_ENRICH_TOP_N_PREMIUM

    cfgs = [
        _or_cfg(
            cid=-i, model_name=f"openai/m-{i}", tier="premium", static_score=100 - i
        )
        for i in range(1, _HEALTH_ENRICH_TOP_N_PREMIUM + 20)
    ]

    seen: list[str] = []

    def responder(url: str) -> _StubResponse:
        seen.append(url)
        return _StubResponse(payload=_healthy_payload())

    _patch_async_client(monkeypatch, responder)

    service = OpenRouterIntegrationService()
    service._settings = {}
    await service._enrich_health(cfgs)

    assert len(seen) == _HEALTH_ENRICH_TOP_N_PREMIUM


async def test_enrich_health_no_or_cfgs_is_noop(monkeypatch):
    """When the catalogue has no OR cfgs at all, no HTTP calls fire."""
    yaml_cfg: dict[str, Any] = {
        "id": -1,
        "provider": "AZURE_OPENAI",
        "model_name": "gpt-5",
        "billing_tier": "premium",
    }
    requests: list[str] = []

    def responder(url: str) -> _StubResponse:
        requests.append(url)
        return _StubResponse(payload=_healthy_payload())

    _patch_async_client(monkeypatch, responder)

    service = OpenRouterIntegrationService()
    service._settings = {}
    await service._enrich_health([yaml_cfg])
    assert requests == []
@@ -0,0 +1,447 @@
"""Pricing registration unit tests.

The pricing-registration module is what makes ``response_cost`` populate
correctly for OpenRouter dynamic models and operator-defined Azure
deployments — both of which LiteLLM doesn't natively know about. The tests
exercise:

* The alias generators emit every shape that LiteLLM's cost-callback might
  use (``openrouter/X`` and bare ``X``; YAML-defined ``base_model``,
  ``provider/base_model``, ``provider/model_name``, plus the special
  ``azure_openai`` → ``azure`` normalisation).
* ``register_pricing_from_global_configs`` calls ``litellm.register_model``
  with the right alias set and pricing values per provider.
* Configs without a resolvable pair of cost values are skipped — never
  registered as zero, since that would override pricing LiteLLM might
  already know natively.
"""

from __future__ import annotations

from typing import Any

import pytest

pytestmark = pytest.mark.unit
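
# Illustrative sketch of the OpenRouter alias rule the tests below pin (an
# assumption simplifying ``_alias_set_for_openrouter``):
def _openrouter_aliases_sketch(model_id: str) -> list[str]:
    prefixed = (
        model_id if model_id.startswith("openrouter/") else f"openrouter/{model_id}"
    )
    # Keep order, drop duplicates (prefixed and bare coincide when the id
    # already carries the ``openrouter/`` prefix).
    return list(dict.fromkeys([prefixed, model_id]))
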
# ---------------------------------------------------------------------------
# Alias generators
# ---------------------------------------------------------------------------


def test_openrouter_alias_set_includes_prefixed_and_bare():
    from app.services.pricing_registration import _alias_set_for_openrouter

    aliases = _alias_set_for_openrouter("anthropic/claude-3-5-sonnet")
    assert aliases == [
        "openrouter/anthropic/claude-3-5-sonnet",
        "anthropic/claude-3-5-sonnet",
    ]


def test_openrouter_alias_set_dedupes():
    """If the model id is already prefixed with ``openrouter/``, the alias
    set must not contain duplicates that would re-register the same key
    twice.
    """
    from app.services.pricing_registration import _alias_set_for_openrouter

    aliases = _alias_set_for_openrouter("openrouter/foo")
    # The bare and prefixed variants compute to the same string here, so we
    # at minimum require uniqueness.
    assert len(aliases) == len(set(aliases))


def test_yaml_alias_set_for_azure_openai_normalises_to_azure():
    """``azure_openai`` (our YAML provider slug) must register under
    ``azure/<name>`` so the LiteLLM Router's deployment-resolution path
    (which uses provider ``azure``) finds the pricing too.
    """
    from app.services.pricing_registration import _alias_set_for_yaml

    aliases = _alias_set_for_yaml(
        provider="AZURE_OPENAI",
        model_name="gpt-5.4",
        base_model="gpt-5.4",
    )
    assert "gpt-5.4" in aliases
    assert "azure_openai/gpt-5.4" in aliases
    assert "azure/gpt-5.4" in aliases


def test_yaml_alias_set_distinguishes_model_name_and_base_model():
    """When ``model_name`` differs from ``base_model`` (operator labelled a
    deployment), both must appear in the alias set since either may surface
    in callbacks depending on the call path.
    """
    from app.services.pricing_registration import _alias_set_for_yaml

    aliases = _alias_set_for_yaml(
        provider="OPENAI",
        model_name="my-deployment-label",
        base_model="gpt-4o",
    )
    assert "gpt-4o" in aliases
    assert "openai/gpt-4o" in aliases
    assert "my-deployment-label" in aliases
    assert "openai/my-deployment-label" in aliases


def test_yaml_alias_set_omits_provider_prefix_when_provider_blank():
    from app.services.pricing_registration import _alias_set_for_yaml

    aliases = _alias_set_for_yaml(
        provider="",
        model_name="foo",
        base_model="bar",
    )
    assert "bar" in aliases
    assert "foo" in aliases
    assert all("/" not in a for a in aliases)


# ---------------------------------------------------------------------------
# register_pricing_from_global_configs
# ---------------------------------------------------------------------------


class _RegistrationSpy:
    """Captures the dicts passed to ``litellm.register_model``.

    Many calls may go through; we just record them all and let tests assert
    against the union.
    """

    def __init__(self) -> None:
        self.calls: list[dict[str, Any]] = []

    def __call__(self, payload: dict[str, Any]) -> None:
        self.calls.append(payload)

    @property
    def all_keys(self) -> set[str]:
        keys: set[str] = set()
        for payload in self.calls:
            keys.update(payload.keys())
        return keys


def _patch_register(monkeypatch: pytest.MonkeyPatch) -> _RegistrationSpy:
    spy = _RegistrationSpy()
    monkeypatch.setattr(
        "app.services.pricing_registration.litellm.register_model",
        spy,
        raising=False,
    )
    return spy


def _patch_openrouter_pricing(
    monkeypatch: pytest.MonkeyPatch, mapping: dict[str, dict[str, str]]
) -> None:
    """Pretend the OpenRouter integration is initialised with ``mapping``."""
|
||||
|
||||
class _Stub:
|
||||
def get_raw_pricing(self) -> dict[str, dict[str, str]]:
|
||||
return mapping
|
||||
|
||||
class _StubService:
|
||||
@classmethod
|
||||
def is_initialized(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> _Stub:
|
||||
return _Stub()
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.services.openrouter_integration_service.OpenRouterIntegrationService",
|
||||
_StubService,
|
||||
raising=False,
|
||||
)
|
||||
|
||||
|
||||
def test_openrouter_models_register_under_aliases(monkeypatch):
|
||||
"""An OpenRouter config whose ``model_name`` is in the cached raw
|
||||
pricing map is registered under both ``openrouter/X`` and bare ``X``.
|
||||
"""
|
||||
from app.config import config
|
||||
from app.services.pricing_registration import register_pricing_from_global_configs
|
||||
|
||||
spy = _patch_register(monkeypatch)
|
||||
_patch_openrouter_pricing(
|
||||
monkeypatch,
|
||||
{
|
||||
"anthropic/claude-3-5-sonnet": {
|
||||
"prompt": "0.000003",
|
||||
"completion": "0.000015",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
config,
|
||||
"GLOBAL_LLM_CONFIGS",
|
||||
[
|
||||
{
|
||||
"id": 1,
|
||||
"provider": "OPENROUTER",
|
||||
"model_name": "anthropic/claude-3-5-sonnet",
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
register_pricing_from_global_configs()
|
||||
|
||||
assert "openrouter/anthropic/claude-3-5-sonnet" in spy.all_keys
|
||||
assert "anthropic/claude-3-5-sonnet" in spy.all_keys
|
||||
# Costs are float-converted from the raw OpenRouter strings.
|
||||
payload = spy.calls[0]
|
||||
assert payload["openrouter/anthropic/claude-3-5-sonnet"][
|
||||
"input_cost_per_token"
|
||||
] == pytest.approx(3e-6)
|
||||
assert payload["openrouter/anthropic/claude-3-5-sonnet"][
|
||||
"output_cost_per_token"
|
||||
] == pytest.approx(15e-6)
|
||||
assert (
|
||||
payload["openrouter/anthropic/claude-3-5-sonnet"]["litellm_provider"]
|
||||
== "openrouter"
|
||||
)
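

# For orientation: ``litellm.register_model`` consumes a mapping of model
# key -> pricing metadata. A sketch of the per-alias entry the module under
# test presumably builds (values here are examples, not authoritative):
_EXAMPLE_REGISTER_PAYLOAD: dict[str, dict[str, Any]] = {
    "openrouter/anthropic/claude-3-5-sonnet": {
        "input_cost_per_token": 3e-6,
        "output_cost_per_token": 15e-6,
        "litellm_provider": "openrouter",
        "mode": "chat",
    }
}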


def test_yaml_override_registers_under_alias_set(monkeypatch):
    """Operator-declared ``input_cost_per_token`` /
    ``output_cost_per_token`` on a YAML config registers under every
    alias the YAML alias generator produces — including the ``azure/``
    normalisation for ``azure_openai`` providers.
    """
    from app.config import config
    from app.services.pricing_registration import register_pricing_from_global_configs

    spy = _patch_register(monkeypatch)
    _patch_openrouter_pricing(monkeypatch, {})

    monkeypatch.setattr(
        config,
        "GLOBAL_LLM_CONFIGS",
        [
            {
                "id": 1,
                "provider": "AZURE_OPENAI",
                "model_name": "gpt-5.4",
                "litellm_params": {
                    "base_model": "gpt-5.4",
                    "input_cost_per_token": 2e-6,
                    "output_cost_per_token": 8e-6,
                },
            }
        ],
    )

    register_pricing_from_global_configs()

    keys = spy.all_keys
    assert "gpt-5.4" in keys
    assert "azure_openai/gpt-5.4" in keys
    assert "azure/gpt-5.4" in keys

    payload = spy.calls[0]
    entry = payload["gpt-5.4"]
    assert entry["input_cost_per_token"] == pytest.approx(2e-6)
    assert entry["output_cost_per_token"] == pytest.approx(8e-6)
    assert entry["litellm_provider"] == "azure"


def test_no_override_means_no_registration(monkeypatch):
    """A YAML config that *omits* both pricing fields must NOT be registered
    — registering as zero would override LiteLLM's native pricing for the
    ``base_model`` key (e.g. ``gpt-4o``) and silently make every user's
    bill drop to $0. Fail-safe is "skip and warn", not "register zero".
    """
    from app.config import config
    from app.services.pricing_registration import register_pricing_from_global_configs

    spy = _patch_register(monkeypatch)
    _patch_openrouter_pricing(monkeypatch, {})

    monkeypatch.setattr(
        config,
        "GLOBAL_LLM_CONFIGS",
        [
            {
                "id": 1,
                "provider": "OPENAI",
                "model_name": "gpt-4o",
                "litellm_params": {"base_model": "gpt-4o"},
            }
        ],
    )

    register_pricing_from_global_configs()

    assert spy.calls == []


def test_openrouter_skipped_when_pricing_missing(monkeypatch):
    """If the OpenRouter raw-pricing cache doesn't carry an entry for a
    configured model (network blip during refresh, model added later, etc.),
    we skip it rather than registering zero pricing.
    """
    from app.config import config
    from app.services.pricing_registration import register_pricing_from_global_configs

    spy = _patch_register(monkeypatch)
    _patch_openrouter_pricing(
        monkeypatch, {"some/other-model": {"prompt": "1", "completion": "1"}}
    )

    monkeypatch.setattr(
        config,
        "GLOBAL_LLM_CONFIGS",
        [
            {
                "id": 1,
                "provider": "OPENROUTER",
                "model_name": "anthropic/claude-3-5-sonnet",
            }
        ],
    )

    register_pricing_from_global_configs()

    assert spy.calls == []


def test_register_continues_after_individual_failure(monkeypatch, caplog):
    """A single bad ``register_model`` call (e.g. raising a LiteLLM error)
    must not abort registration of the remaining configs.
    """
    from app.config import config
    from app.services.pricing_registration import register_pricing_from_global_configs

    failing_keys: set[str] = {"anthropic/claude-3-5-sonnet"}
    successful_calls: list[dict[str, Any]] = []

    def _maybe_fail(payload: dict[str, Any]) -> None:
        if any(k in failing_keys for k in payload):
            raise RuntimeError("boom")
        successful_calls.append(payload)

    monkeypatch.setattr(
        "app.services.pricing_registration.litellm.register_model",
        _maybe_fail,
        raising=False,
    )
    _patch_openrouter_pricing(
        monkeypatch,
        {
            "anthropic/claude-3-5-sonnet": {
                "prompt": "0.000003",
                "completion": "0.000015",
            }
        },
    )

    monkeypatch.setattr(
        config,
        "GLOBAL_LLM_CONFIGS",
        [
            {
                "id": 1,
                "provider": "OPENROUTER",
                "model_name": "anthropic/claude-3-5-sonnet",
            },
            {
                "id": 2,
                "provider": "OPENAI",
                "model_name": "custom-deployment",
                "litellm_params": {
                    "base_model": "custom-deployment",
                    "input_cost_per_token": 1e-6,
                    "output_cost_per_token": 2e-6,
                },
            },
        ],
    )

    register_pricing_from_global_configs()

    # The good config still registered.
    assert any("custom-deployment" in payload for payload in successful_calls)


def test_vision_configs_registered_with_chat_shape(monkeypatch):
    """``register_pricing_from_global_configs`` walks
    ``GLOBAL_VISION_LLM_CONFIGS`` in addition to the chat configs so vision
    calls (during indexing) bill correctly. Vision configs use the same
    chat-shape token prices, but image-gen pricing is intentionally NOT
    registered here (handled via ``response_cost`` in LiteLLM).
    """
    from app.config import config
    from app.services.pricing_registration import register_pricing_from_global_configs

    spy = _patch_register(monkeypatch)
    _patch_openrouter_pricing(
        monkeypatch,
        {"openai/gpt-4o": {"prompt": "0.000005", "completion": "0.000015"}},
    )

    # No chat configs — only vision. Proves the vision walk is a separate
    # iteration, not piggy-backed on the chat list.
    monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [])
    monkeypatch.setattr(
        config,
        "GLOBAL_VISION_LLM_CONFIGS",
        [
            {
                "id": -1,
                "provider": "OPENROUTER",
                "model_name": "openai/gpt-4o",
                "billing_tier": "premium",
                "input_cost_per_token": 5e-6,
                "output_cost_per_token": 15e-6,
            }
        ],
    )

    register_pricing_from_global_configs()

    assert "openrouter/openai/gpt-4o" in spy.all_keys
    payload_value = spy.calls[0]["openrouter/openai/gpt-4o"]
    assert payload_value["mode"] == "chat"
    assert payload_value["litellm_provider"] == "openrouter"
    assert payload_value["input_cost_per_token"] == pytest.approx(5e-6)
    assert payload_value["output_cost_per_token"] == pytest.approx(15e-6)


def test_vision_with_inline_pricing_when_or_cache_missing(monkeypatch):
    """If the OpenRouter pricing cache misses a vision model (different
    catalogue surface), the vision walk falls back to inline
    ``input_cost_per_token``/``output_cost_per_token`` on the cfg itself.
    """
    from app.config import config
    from app.services.pricing_registration import register_pricing_from_global_configs

    spy = _patch_register(monkeypatch)
    _patch_openrouter_pricing(monkeypatch, {})

    monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [])
    monkeypatch.setattr(
        config,
        "GLOBAL_VISION_LLM_CONFIGS",
        [
            {
                "id": -1,
                "provider": "OPENROUTER",
                "model_name": "google/gemini-2.5-flash",
                "billing_tier": "premium",
                "input_cost_per_token": 1e-6,
                "output_cost_per_token": 4e-6,
            }
        ],
    )

    register_pricing_from_global_configs()

    assert "openrouter/google/gemini-2.5-flash" in spy.all_keys

107 surfsense_backend/tests/unit/services/test_provider_api_base.py Normal file

@ -0,0 +1,107 @@
"""Unit tests for the shared ``api_base`` resolver.

The cascade exists so vision and image-gen call sites can't silently
inherit ``litellm.api_base`` (commonly set by ``AZURE_OPENAI_ENDPOINT``)
when an OpenRouter / Groq / etc. config ships an empty string. See the
``provider_api_base`` module docstring for the original repro
(OpenRouter image-gen 404-ing against an Azure endpoint).
"""

from __future__ import annotations

import pytest

from app.services.provider_api_base import (
    PROVIDER_DEFAULT_API_BASE,
    PROVIDER_KEY_DEFAULT_API_BASE,
    resolve_api_base,
)

pytestmark = pytest.mark.unit
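

# --- Illustrative sketch (not part of the change set) ----------------------
# A minimal sketch of the cascade these tests pin down. The real
# resolve_api_base is in app/services/provider_api_base.py; this body is an
# assumption reconstructed from the assertions below.
def _sketch_resolve_api_base(
    provider: str | None,
    provider_prefix: str | None,
    config_api_base: str | None,
) -> str | None:
    if config_api_base and config_api_base.strip():
        return config_api_base  # non-blank operator value wins verbatim
    key = (provider or "").upper()
    if key in PROVIDER_KEY_DEFAULT_API_BASE:
        return PROVIDER_KEY_DEFAULT_API_BASE[key]  # e.g. DEEPSEEK before "openai"
    # Fall back to the prefix map; None lets LiteLLM apply its own default.
    return PROVIDER_DEFAULT_API_BASE.get((provider_prefix or "").lower())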


def test_config_value_wins_over_defaults():
    """A non-empty config value is always returned verbatim, even when the
    provider has a default — the operator gets the last word."""
    result = resolve_api_base(
        provider="OPENROUTER",
        provider_prefix="openrouter",
        config_api_base="https://my-openrouter-mirror.example.com/v1",
    )
    assert result == "https://my-openrouter-mirror.example.com/v1"


def test_provider_key_default_when_config_missing():
    """``DEEPSEEK`` shares the ``openai`` LiteLLM prefix but has its own
    base URL — the provider-key map must take precedence over the prefix
    map so DeepSeek requests don't go to OpenAI."""
    result = resolve_api_base(
        provider="DEEPSEEK",
        provider_prefix="openai",
        config_api_base=None,
    )
    assert result == PROVIDER_KEY_DEFAULT_API_BASE["DEEPSEEK"]


def test_provider_prefix_default_when_no_key_default():
    result = resolve_api_base(
        provider="OPENROUTER",
        provider_prefix="openrouter",
        config_api_base=None,
    )
    assert result == PROVIDER_DEFAULT_API_BASE["openrouter"]


def test_unknown_provider_returns_none():
    """When neither map matches we return ``None`` so the caller can let
    LiteLLM apply its own provider-integration default (Azure deployment
    URL, custom-provider URL, etc.)."""
    result = resolve_api_base(
        provider="SOMETHING_NEW",
        provider_prefix="something_new",
        config_api_base=None,
    )
    assert result is None


def test_empty_string_config_treated_as_missing():
    """The original bug: OpenRouter dynamic configs ship ``api_base=""``
    and downstream call sites use ``if cfg.get("api_base"):`` — empty
    strings are falsy in Python but the cascade has to step in anyway."""
    result = resolve_api_base(
        provider="OPENROUTER",
        provider_prefix="openrouter",
        config_api_base="",
    )
    assert result == PROVIDER_DEFAULT_API_BASE["openrouter"]


def test_whitespace_only_config_treated_as_missing():
    """A config value of ``" "`` is a configuration mistake — treat it
    as missing instead of forwarding whitespace to LiteLLM (which would
    almost certainly 404)."""
    result = resolve_api_base(
        provider="OPENROUTER",
        provider_prefix="openrouter",
        config_api_base=" ",
    )
    assert result == PROVIDER_DEFAULT_API_BASE["openrouter"]


def test_provider_case_insensitive():
    """Some call sites pass the provider lowercase (DB enum value), others
    uppercase (YAML key). Both must resolve."""
    upper = resolve_api_base(
        provider="DEEPSEEK", provider_prefix="openai", config_api_base=None
    )
    lower = resolve_api_base(
        provider="deepseek", provider_prefix="openai", config_api_base=None
    )
    assert upper == lower == PROVIDER_KEY_DEFAULT_API_BASE["DEEPSEEK"]


def test_all_inputs_none_returns_none():
    assert (
        resolve_api_base(provider=None, provider_prefix=None, config_api_base=None)
        is None
    )

@ -0,0 +1,244 @@
"""Unit tests for the shared chat-image capability resolver.
|
||||
|
||||
Two resolvers, two intents:
|
||||
|
||||
- ``derive_supports_image_input`` — best-effort True for the catalog and
|
||||
selector. Default-allow on unknown / unmapped models. The streaming
|
||||
task safety net never sees this value directly.
|
||||
|
||||
- ``is_known_text_only_chat_model`` — strict opt-out for the safety net.
|
||||
Returns True only when LiteLLM's model map *explicitly* sets
|
||||
``supports_vision=False``. Anything else (missing key, exception,
|
||||
True) returns False so the request flows through to the provider.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.provider_capabilities import (
|
||||
derive_supports_image_input,
|
||||
is_known_text_only_chat_model,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
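

# --- Illustrative sketch (not part of the change set) ----------------------
# The strict opt-out under test, sketched as an assumption (the real helper
# is in app/services/provider_capabilities.py):
def _sketch_is_known_text_only(model: str) -> bool:
    import litellm  # local import keeps the sketch self-contained

    try:
        info = litellm.get_model_info(model=model)
    except Exception:
        return False  # a failed lookup must never block the request
    # Only an explicit False opts the model out; a missing key means unknown.
    return info.get("supports_vision") is False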


# ---------------------------------------------------------------------------
# derive_supports_image_input — OpenRouter modalities path (authoritative)
# ---------------------------------------------------------------------------


def test_or_modalities_with_image_returns_true():
    assert (
        derive_supports_image_input(
            provider="OPENROUTER",
            model_name="openai/gpt-4o",
            openrouter_input_modalities=["text", "image"],
        )
        is True
    )


def test_or_modalities_text_only_returns_false():
    assert (
        derive_supports_image_input(
            provider="OPENROUTER",
            model_name="deepseek/deepseek-v3.2-exp",
            openrouter_input_modalities=["text"],
        )
        is False
    )


def test_or_modalities_empty_list_returns_false():
    """OR explicitly publishing an empty modality list is a definitive
    'no inputs at all' signal — treat as False rather than falling back
    to LiteLLM."""
    assert (
        derive_supports_image_input(
            provider="OPENROUTER",
            model_name="weird/empty-modalities",
            openrouter_input_modalities=[],
        )
        is False
    )


def test_or_modalities_none_falls_through_to_litellm():
    """``None`` (missing key) is *not* a definitive signal — fall through
    to LiteLLM. Using ``gpt-4o``, which is in LiteLLM's map."""
    assert (
        derive_supports_image_input(
            provider="OPENAI",
            model_name="gpt-4o",
            openrouter_input_modalities=None,
        )
        is True
    )


# ---------------------------------------------------------------------------
# derive_supports_image_input — LiteLLM model-map path
# ---------------------------------------------------------------------------


def test_litellm_known_vision_model_returns_true():
    assert (
        derive_supports_image_input(
            provider="OPENAI",
            model_name="gpt-4o",
        )
        is True
    )


def test_litellm_base_model_wins_over_model_name():
    """Azure-style entries pass model_name=deployment_id and put the
    canonical sku in litellm_params.base_model. The resolver must
    consult base_model first or the deployment id (which LiteLLM
    doesn't know) would shadow the real capability."""
    assert (
        derive_supports_image_input(
            provider="AZURE_OPENAI",
            model_name="my-azure-deployment-id",
            base_model="gpt-4o",
        )
        is True
    )


def test_litellm_unknown_model_default_allows():
    """Default-allow on unknown — the safety net is the actual block."""
    assert (
        derive_supports_image_input(
            provider="CUSTOM",
            model_name="brand-new-model-x9-unmapped",
            custom_provider="brand_new_proxy",
        )
        is True
    )


def test_litellm_known_text_only_returns_false():
    """A model that LiteLLM explicitly knows is text-only resolves to
    False even via the catalog resolver. ``deepseek-chat`` (the
    DeepSeek-V3 chat sku) is in the map without supports_vision and
    LiteLLM's ``supports_vision`` returns False."""
    # Sanity: confirm the helper's negative path. We use a small model
    # known not to support vision per the map.
    result = derive_supports_image_input(
        provider="DEEPSEEK",
        model_name="deepseek-chat",
    )
    # We accept either False (LiteLLM said explicit no) or True
    # (default-allow if the entry isn't mapped on this version) — the
    # invariant is that the resolver never *raises* on a known-text-only
    # provider/model. The behaviour-binding assertion lives in
    # ``test_is_known_text_only_returns_true_on_explicit_false`` below.
    assert isinstance(result, bool)


# ---------------------------------------------------------------------------
# is_known_text_only_chat_model — strict opt-out semantics
# ---------------------------------------------------------------------------


def test_is_known_text_only_returns_false_for_vision_model():
    assert (
        is_known_text_only_chat_model(
            provider="OPENAI",
            model_name="gpt-4o",
        )
        is False
    )


def test_is_known_text_only_returns_false_for_unknown_model():
    """Strict opt-out: missing from the map ≠ text-only. The safety net
    must NOT fire for an unmapped model — that's the regression we're
    fixing."""
    assert (
        is_known_text_only_chat_model(
            provider="CUSTOM",
            model_name="brand-new-model-x9-unmapped",
            custom_provider="brand_new_proxy",
        )
        is False
    )


def test_is_known_text_only_returns_false_when_lookup_raises(monkeypatch):
    """LiteLLM's ``get_model_info`` raises freely on parse errors. The
    helper swallows the exception and returns False so the safety net
    doesn't fire on a transient lookup failure."""
    import app.services.provider_capabilities as pc

    def _raise(**_kwargs):
        raise ValueError("intentional test failure")

    monkeypatch.setattr(pc.litellm, "get_model_info", _raise)

    assert (
        is_known_text_only_chat_model(
            provider="OPENAI",
            model_name="gpt-4o",
        )
        is False
    )


def test_is_known_text_only_returns_true_on_explicit_false(monkeypatch):
    """Stub LiteLLM's ``get_model_info`` to return an explicit False so
    we exercise the opt-out path deterministically. Using a stub keeps
    the test stable across LiteLLM map updates."""
    import app.services.provider_capabilities as pc

    def _info(**_kwargs):
        return {"supports_vision": False, "max_input_tokens": 8192}

    monkeypatch.setattr(pc.litellm, "get_model_info", _info)

    assert (
        is_known_text_only_chat_model(
            provider="OPENAI",
            model_name="any-model",
        )
        is True
    )


def test_is_known_text_only_returns_false_on_supports_vision_true(monkeypatch):
    import app.services.provider_capabilities as pc

    def _info(**_kwargs):
        return {"supports_vision": True}

    monkeypatch.setattr(pc.litellm, "get_model_info", _info)

    assert (
        is_known_text_only_chat_model(
            provider="OPENAI",
            model_name="any-model",
        )
        is False
    )


def test_is_known_text_only_returns_false_on_missing_key(monkeypatch):
    """A model entry without ``supports_vision`` at all is treated as
    'unknown' — strict opt-out means False."""
    import app.services.provider_capabilities as pc

    def _info(**_kwargs):
        return {"max_input_tokens": 8192}  # no supports_vision

    monkeypatch.setattr(pc.litellm, "get_model_info", _info)

    assert (
        is_known_text_only_chat_model(
            provider="OPENAI",
            model_name="any-model",
        )
        is False
    )

345 surfsense_backend/tests/unit/services/test_quality_score.py Normal file

@ -0,0 +1,345 @@
"""Unit tests for the Auto (Fastest) quality scoring module."""

from __future__ import annotations

import time

import pytest

from app.services.quality_score import (
    _HEALTH_GATE_UPTIME_PCT,
    _OPERATOR_TRUST_BONUS,
    aggregate_health,
    capabilities_signal,
    context_signal,
    created_recency_signal,
    pricing_band,
    slug_penalty,
    static_score_or,
    static_score_yaml,
)

pytestmark = pytest.mark.unit


# ---------------------------------------------------------------------------
# created_recency_signal
# ---------------------------------------------------------------------------


def test_created_recency_signal_recent_model_scores_high():
    now = 1_750_000_000  # ~mid-2025
    one_month_ago = now - (30 * 86_400)
    assert created_recency_signal(one_month_ago, now) == 20


def test_created_recency_signal_old_model_scores_zero():
    now = 1_750_000_000
    five_years_ago = now - (5 * 365 * 86_400)
    assert created_recency_signal(five_years_ago, now) == 0


def test_created_recency_signal_missing_timestamp_is_neutral():
    now = 1_750_000_000
    assert created_recency_signal(None, now) == 0
    assert created_recency_signal(0, now) == 0


def test_created_recency_signal_monotonic_decay():
    now = 1_750_000_000
    scores = [
        created_recency_signal(now - days * 86_400, now)
        for days in (30, 120, 300, 500, 700, 1000, 1500)
    ]
    assert scores == sorted(scores, reverse=True)
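

# --- Illustrative sketch (not part of the change set) ----------------------
# The decay itself isn't shown in this diff; a banded implementation like
# the sketch below (thresholds assumed, not authoritative) would satisfy the
# monotonicity and endpoint assertions above.
def _sketch_created_recency_signal(created_ts: int | None, now_ts: int) -> int:
    if not created_ts:
        return 0  # missing timestamp is neutral, never a penalty
    age_days = max(0, (now_ts - created_ts) / 86_400)
    for threshold_days, points in ((90, 20), (180, 16), (365, 12), (540, 8), (730, 4)):
        if age_days <= threshold_days:
            return points
    return 0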


# ---------------------------------------------------------------------------
# pricing_band
# ---------------------------------------------------------------------------


def test_pricing_band_free_returns_zero():
    assert pricing_band("0", "0") == 0
    assert pricing_band(0.0, 0.0) == 0
    assert pricing_band(None, None) == 0


def test_pricing_band_handles_unparseable():
    assert pricing_band("not-a-number", "0") == 0
    assert pricing_band({}, []) == 0  # type: ignore[arg-type]


def test_pricing_band_premium_tiers_increase_with_price():
    cheap = pricing_band("0.0000003", "0.0000005")
    mid = pricing_band("0.000003", "0.000015")
    flagship = pricing_band("0.00001", "0.00005")
    assert 0 < cheap < mid < flagship


# ---------------------------------------------------------------------------
# context_signal
# ---------------------------------------------------------------------------


@pytest.mark.parametrize(
    "ctx,expected",
    [
        (1_500_000, 10),
        (1_000_000, 10),
        (500_000, 8),
        (200_000, 6),
        (128_000, 4),
        (100_000, 2),
        (50_000, 0),
        (0, 0),
        (None, 0),
    ],
)
def test_context_signal_bands(ctx, expected):
    assert context_signal(ctx) == expected


# ---------------------------------------------------------------------------
# capabilities_signal
# ---------------------------------------------------------------------------


def test_capabilities_signal_caps_at_five():
    assert (
        capabilities_signal(
            ["tools", "structured_outputs", "reasoning", "include_reasoning"]
        )
        <= 5
    )


def test_capabilities_signal_tools_only():
    assert capabilities_signal(["tools"]) == 2


def test_capabilities_signal_empty():
    assert capabilities_signal(None) == 0
    assert capabilities_signal([]) == 0


# ---------------------------------------------------------------------------
# slug_penalty
# ---------------------------------------------------------------------------


def test_slug_penalty_demotes_tiny_models():
    assert slug_penalty("meta-llama/llama-3.2-1b-instruct") < 0
    assert slug_penalty("liquid/lfm-7b") < 0
    assert slug_penalty("google/gemma-3n-e4b-it") < 0


def test_slug_penalty_skips_capable_mini_nano_lite_models():
    """Critical Option C+ regression: don't penalise modern frontier
    models named ``-nano`` / ``-mini`` / ``-lite`` (gpt-5-mini, etc.)."""
    assert slug_penalty("openai/gpt-5-mini") == 0
    assert slug_penalty("openai/gpt-5-nano") == 0
    assert slug_penalty("google/gemini-2.5-flash-lite") == 0
    assert slug_penalty("anthropic/claude-haiku-4.5") == 0


def test_slug_penalty_demotes_legacy_variants():
    assert slug_penalty("openai/o1-preview") < 0
    assert slug_penalty("foo/bar-base") < 0
    assert slug_penalty("foo/bar-distill") < 0


def test_slug_penalty_empty_input():
    assert slug_penalty("") == 0


# ---------------------------------------------------------------------------
# static_score_or
# ---------------------------------------------------------------------------


def _or_model(
    *,
    model_id: str,
    created: int | None = None,
    prompt: str = "0.000003",
    completion: str = "0.000015",
    context: int = 200_000,
    params: list[str] | None = None,
) -> dict:
    return {
        "id": model_id,
        "created": created,
        "pricing": {"prompt": prompt, "completion": completion},
        "context_length": context,
        "supported_parameters": params if params is not None else ["tools"],
    }


def test_static_score_or_frontier_premium_beats_free_tiny():
    now = 1_750_000_000
    frontier = _or_model(
        model_id="openai/gpt-5",
        created=now - (60 * 86_400),
        prompt="0.000005",
        completion="0.000020",
        context=400_000,
        params=["tools", "structured_outputs", "reasoning"],
    )
    tiny_free = _or_model(
        model_id="meta-llama/llama-3.2-1b-instruct:free",
        created=now - (5 * 365 * 86_400),
        prompt="0",
        completion="0",
        context=128_000,
        params=["tools"],
    )
    assert static_score_or(frontier, now_ts=now) > static_score_or(
        tiny_free, now_ts=now
    )


def test_static_score_or_score_is_clamped_0_to_100():
    now = int(time.time())
    score = static_score_or(_or_model(model_id="openai/gpt-4o"), now_ts=now)
    assert 0 <= score <= 100


def test_static_score_or_unknown_provider_is_neutral_not_zero():
    now = int(time.time())
    score = static_score_or(
        _or_model(model_id="some-new-lab/some-model"),
        now_ts=now,
    )
    assert score > 0


def test_static_score_or_recent_release_beats_year_old_same_provider():
    now = 1_750_000_000
    fresh = _or_model(model_id="openai/gpt-5", created=now - (60 * 86_400))
    old = _or_model(model_id="openai/gpt-4-turbo", created=now - (700 * 86_400))
    assert static_score_or(fresh, now_ts=now) > static_score_or(old, now_ts=now)


# ---------------------------------------------------------------------------
# static_score_yaml
# ---------------------------------------------------------------------------


def test_static_score_yaml_includes_operator_bonus():
    cfg = {
        "provider": "AZURE_OPENAI",
        "model_name": "gpt-5",
        "litellm_params": {"base_model": "azure/gpt-5"},
    }
    score = static_score_yaml(cfg)
    assert score >= _OPERATOR_TRUST_BONUS


def test_static_score_yaml_unknown_provider_still_carries_bonus():
    cfg = {
        "provider": "SOME_NEW_PROVIDER",
        "model_name": "weird-model",
    }
    score = static_score_yaml(cfg)
    assert score >= _OPERATOR_TRUST_BONUS


def test_static_score_yaml_clamped_0_to_100():
    cfg = {
        "provider": "AZURE_OPENAI",
        "model_name": "gpt-5",
        "litellm_params": {"base_model": "azure/gpt-5"},
    }
    assert 0 <= static_score_yaml(cfg) <= 100


# ---------------------------------------------------------------------------
# aggregate_health
# ---------------------------------------------------------------------------


def test_aggregate_health_gates_when_uptime_below_threshold():
    """Live data showed Venice-routed cfgs at 53-68%; this guards that the
    90% gate excludes them."""
    venice_endpoints = [
        {
            "status": 0,
            "uptime_last_30m": 0.55,
            "uptime_last_1d": 0.60,
            "uptime_last_5m": 0.50,
        },
        {
            "status": 0,
            "uptime_last_30m": 0.65,
            "uptime_last_1d": 0.68,
            "uptime_last_5m": 0.62,
        },
    ]
    gated, score = aggregate_health(venice_endpoints)
    assert gated is True
    assert score is None


def test_aggregate_health_passes_for_healthy_provider():
    healthy = [
        {
            "status": 0,
            "uptime_last_30m": 0.99,
            "uptime_last_1d": 0.995,
            "uptime_last_5m": 0.99,
        },
    ]
    gated, score = aggregate_health(healthy)
    assert gated is False
    assert score is not None
    assert score >= _HEALTH_GATE_UPTIME_PCT


def test_aggregate_health_picks_best_endpoint_across_multiple():
    """Multi-endpoint aggregation should reward the best non-null uptime."""
    mixed = [
        {"status": 0, "uptime_last_30m": 0.55},
        {"status": 0, "uptime_last_30m": 0.97},  # this one passes the gate
    ]
    gated, score = aggregate_health(mixed)
    assert gated is False
    assert score is not None


def test_aggregate_health_empty_endpoints_gated():
    gated, score = aggregate_health([])
    assert gated is True
    assert score is None


def test_aggregate_health_no_status_zero_gated():
    """Even with high uptime, no OK status means the cfg is broken upstream."""
    endpoints = [
        {"status": 1, "uptime_last_30m": 0.99},
        {"status": 2, "uptime_last_30m": 0.98},
    ]
    gated, score = aggregate_health(endpoints)
    assert gated is True
    assert score is None


def test_aggregate_health_all_uptime_null_gated():
    endpoints = [
        {"status": 0, "uptime_last_30m": None, "uptime_last_1d": None},
    ]
    gated, score = aggregate_health(endpoints)
    assert gated is True
    assert score is None


def test_aggregate_health_pct_normalisation():
    """OpenRouter returns 0-1 fractions; some endpoints surface 0-100
    percentages. Both should reach the same gate decision."""
    fraction_form = [{"status": 0, "uptime_last_30m": 0.95}]
    pct_form = [{"status": 0, "uptime_last_30m": 95.0}]
    g1, s1 = aggregate_health(fraction_form)
    g2, s2 = aggregate_health(pct_form)
    assert g1 is False and g2 is False
    assert s1 is not None and s2 is not None
    assert abs(s1 - s2) < 0.5
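

# --- Illustrative sketch (not part of the change set) ----------------------
# Normalisation sketch (assumed): fractions and percentages collapse onto
# the same 0-100 scale before the gate compares against
# _HEALTH_GATE_UPTIME_PCT.
def _sketch_uptime_pct(raw: float | None) -> float | None:
    if raw is None:
        return None
    return raw * 100.0 if raw <= 1.0 else float(raw)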

@ -0,0 +1,157 @@
"""Unit tests for ``QuotaCheckedVisionLLM``.
|
||||
|
||||
Validates that:
|
||||
|
||||
* Calling ``ainvoke`` routes through ``billable_call`` (premium credit
|
||||
enforcement) and forwards the inner LLM's response on success.
|
||||
* The wrapper proxies non-overridden attributes to the inner LLM
|
||||
(``__getattr__``) so ``invoke`` / ``astream`` / ``with_structured_output``
|
||||
still work without quota gating (they're not used in indexing today).
|
||||
* When ``billable_call`` raises ``QuotaInsufficientError`` the wrapper
|
||||
bubbles it up — the ETL pipeline catches that and falls back to OCR.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
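

# --- Illustrative sketch (not part of the change set) ----------------------
# Shape of the wrapper under test, reconstructed from the assertions below.
# The real class lives in app/services/quota_checked_vision_llm.py; the
# import location of billable_call and the exact kwargs plumbing are
# assumptions.
from app.services.billable_calls import billable_call  # assumed location


class _SketchQuotaCheckedVisionLLM:
    def __init__(self, inner: Any, **billing_kwargs: Any) -> None:
        self._inner = inner
        self._billing_kwargs = billing_kwargs

    async def ainvoke(self, input: Any, *args: Any, **kwargs: Any) -> Any:
        # Reserve premium credit first; a denied reservation raises before
        # the inner LLM ever runs.
        async with billable_call(
            usage_type="vision_extraction", **self._billing_kwargs
        ):
            return await self._inner.ainvoke(input, *args, **kwargs)

    def __getattr__(self, name: str) -> Any:
        return getattr(self._inner, name)  # everything else proxies, ungated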


class _FakeInnerLLM:
    """Stand-in for ``langchain_litellm.ChatLiteLLM``."""

    def __init__(self, response: Any = "OCR'd content") -> None:
        self._response = response
        self.ainvoke_calls: list[Any] = []

    async def ainvoke(self, input: Any, *args: Any, **kwargs: Any) -> Any:
        self.ainvoke_calls.append(input)
        return self._response

    def some_other_method(self, x: int) -> int:
        return x * 2


@contextlib.asynccontextmanager
async def _passthrough_billable_call(**_kwargs):
    """Stand-in for ``billable_call`` that always allows the call to run."""

    class _Acc:
        total_cost_micros = 0
        total_prompt_tokens = 0
        total_completion_tokens = 0
        grand_total = 0
        calls: list[Any] = []

        def per_message_summary(self) -> dict[str, dict[str, int]]:
            return {}

    yield _Acc()


@pytest.mark.asyncio
async def test_ainvoke_routes_through_billable_call(monkeypatch):
    from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM

    captured_kwargs: list[dict[str, Any]] = []

    @contextlib.asynccontextmanager
    async def _spy_billable_call(**kwargs):
        captured_kwargs.append(kwargs)
        async with _passthrough_billable_call() as acc:
            yield acc

    monkeypatch.setattr(
        "app.services.quota_checked_vision_llm.billable_call",
        _spy_billable_call,
        raising=False,
    )

    inner = _FakeInnerLLM(response="A red apple on a white table")
    user_id = uuid4()
    wrapper = QuotaCheckedVisionLLM(
        inner,
        user_id=user_id,
        search_space_id=99,
        billing_tier="premium",
        base_model="openai/gpt-4o",
        quota_reserve_tokens=4000,
    )

    result = await wrapper.ainvoke([{"text": "what is this?"}])
    assert result == "A red apple on a white table"
    assert len(inner.ainvoke_calls) == 1
    assert len(captured_kwargs) == 1
    bc_kwargs = captured_kwargs[0]
    assert bc_kwargs["user_id"] == user_id
    assert bc_kwargs["search_space_id"] == 99
    assert bc_kwargs["billing_tier"] == "premium"
    assert bc_kwargs["base_model"] == "openai/gpt-4o"
    assert bc_kwargs["quota_reserve_tokens"] == 4000
    assert bc_kwargs["usage_type"] == "vision_extraction"


@pytest.mark.asyncio
async def test_ainvoke_propagates_quota_insufficient_error(monkeypatch):
    from app.services.billable_calls import QuotaInsufficientError
    from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM

    @contextlib.asynccontextmanager
    async def _denying_billable_call(**_kwargs):
        raise QuotaInsufficientError(
            usage_type="vision_extraction",
            used_micros=5_000_000,
            limit_micros=5_000_000,
            remaining_micros=0,
        )
        yield  # unreachable, but required for asynccontextmanager typing

    monkeypatch.setattr(
        "app.services.quota_checked_vision_llm.billable_call",
        _denying_billable_call,
        raising=False,
    )

    inner = _FakeInnerLLM()
    wrapper = QuotaCheckedVisionLLM(
        inner,
        user_id=uuid4(),
        search_space_id=1,
        billing_tier="premium",
        base_model="openai/gpt-4o",
        quota_reserve_tokens=4000,
    )

    with pytest.raises(QuotaInsufficientError):
        await wrapper.ainvoke([{"text": "x"}])

    # Inner LLM never ran on a denied reservation.
    assert inner.ainvoke_calls == []


@pytest.mark.asyncio
async def test_proxies_non_overridden_attributes_to_inner():
    """``__getattr__`` forwards anything not on the proxy itself, so any
    method we didn't explicitly override (``invoke``, ``astream``,
    ``with_structured_output``, etc.) still works — just without quota
    gating, which is fine because the indexer only ever calls ``ainvoke``.
    """
    from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM

    inner = _FakeInnerLLM()
    wrapper = QuotaCheckedVisionLLM(
        inner,
        user_id=uuid4(),
        search_space_id=1,
        billing_tier="premium",
        base_model="openai/gpt-4o",
        quota_reserve_tokens=4000,
    )

    # ``some_other_method`` is on the inner only.
    assert wrapper.some_other_method(7) == 14

@ -0,0 +1,281 @@
"""Unit tests for the chat-catalog ``supports_image_input`` capability flag.
|
||||
|
||||
Capability is sourced from two places, in order of preference:
|
||||
|
||||
1. ``architecture.input_modalities`` for dynamic OpenRouter chat configs
|
||||
(authoritative — OpenRouter publishes per-model modalities directly).
|
||||
2. LiteLLM's authoritative model map (``litellm.supports_vision``) for
|
||||
YAML / BYOK configs that don't carry an explicit operator override.
|
||||
|
||||
The catalog default is *True* (conservative-allow): an unknown / unmapped
|
||||
model is not pre-judged. The streaming-task safety net
|
||||
(``is_known_text_only_chat_model``) is the only place a False actually
|
||||
blocks a request — and it requires LiteLLM to *explicitly* mark the model
|
||||
as text-only.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.openrouter_integration_service import (
|
||||
_OPENROUTER_DYNAMIC_MARKER,
|
||||
_generate_configs,
|
||||
_supports_image_input,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
_SETTINGS_BASE: dict = {
|
||||
"api_key": "sk-or-test",
|
||||
"id_offset": -10_000,
|
||||
"rpm": 200,
|
||||
"tpm": 1_000_000,
|
||||
"free_rpm": 20,
|
||||
"free_tpm": 100_000,
|
||||
"anonymous_enabled_paid": False,
|
||||
"anonymous_enabled_free": True,
|
||||
"quota_reserve_tokens": 4000,
|
||||
}
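

# --- Illustrative sketch (not part of the change set) ----------------------
# The OpenRouter-modalities helper, sketched as an assumption (the real one
# sits in app/services/openrouter_integration_service.py): missing or null
# modalities read as text-only at this level.
def _sketch_supports_image_input(raw_model: dict) -> bool:
    modalities = (raw_model.get("architecture") or {}).get("input_modalities") or []
    return "image" in modalities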


# ---------------------------------------------------------------------------
# _supports_image_input helper (OpenRouter modalities)
# ---------------------------------------------------------------------------


def test_supports_image_input_true_for_multimodal():
    assert (
        _supports_image_input(
            {
                "id": "openai/gpt-4o",
                "architecture": {
                    "input_modalities": ["text", "image"],
                    "output_modalities": ["text"],
                },
            }
        )
        is True
    )


def test_supports_image_input_false_for_text_only():
    """The exact failure mode the safety net guards against — DeepSeek V3
    is a text-in/text-out model and would 404 if forwarded image_url."""
    assert (
        _supports_image_input(
            {
                "id": "deepseek/deepseek-v3.2-exp",
                "architecture": {
                    "input_modalities": ["text"],
                    "output_modalities": ["text"],
                },
            }
        )
        is False
    )


def test_supports_image_input_false_when_modalities_missing():
    """Defensive: missing architecture is treated as text-only at the
    OpenRouter helper level. The wider catalog resolver
    (``derive_supports_image_input``) only consults modalities when they
    are non-empty, otherwise it falls back to LiteLLM."""
    assert _supports_image_input({"id": "weird/model"}) is False
    assert _supports_image_input({"id": "weird/model", "architecture": {}}) is False
    assert (
        _supports_image_input(
            {"id": "weird/model", "architecture": {"input_modalities": None}}
        )
        is False
    )


# ---------------------------------------------------------------------------
# _generate_configs threads the flag onto every emitted chat config
# ---------------------------------------------------------------------------


def test_generate_configs_emits_supports_image_input():
    raw = [
        {
            "id": "openai/gpt-4o",
            "architecture": {
                "input_modalities": ["text", "image"],
                "output_modalities": ["text"],
            },
            "supported_parameters": ["tools"],
            "context_length": 200_000,
            "pricing": {"prompt": "0.000005", "completion": "0.000015"},
        },
        {
            "id": "deepseek/deepseek-v3.2-exp",
            "architecture": {
                "input_modalities": ["text"],
                "output_modalities": ["text"],
            },
            "supported_parameters": ["tools"],
            "context_length": 200_000,
            "pricing": {"prompt": "0.000003", "completion": "0.000015"},
        },
    ]
    cfgs = _generate_configs(raw, dict(_SETTINGS_BASE))
    by_model = {c["model_name"]: c for c in cfgs}

    gpt = by_model["openai/gpt-4o"]
    assert gpt["supports_image_input"] is True
    assert gpt[_OPENROUTER_DYNAMIC_MARKER] is True

    deepseek = by_model["deepseek/deepseek-v3.2-exp"]
    assert deepseek["supports_image_input"] is False
    assert deepseek[_OPENROUTER_DYNAMIC_MARKER] is True


# ---------------------------------------------------------------------------
# YAML loader: defer to derive_supports_image_input on unannotated entries
# ---------------------------------------------------------------------------


def test_yaml_loader_resolves_unannotated_vision_model_to_true(tmp_path, monkeypatch):
    """The regression case: an Azure YAML entry without a
    ``supports_image_input`` override (GPT-5.x in the original report;
    ``gpt-4o`` here) should resolve to True via LiteLLM's model map (which
    says ``supports_vision: true``). Previously this defaulted to False,
    blocking every image turn for vision-capable YAML configs."""
    yaml_dir = tmp_path / "app" / "config"
    yaml_dir.mkdir(parents=True)
    (yaml_dir / "global_llm_config.yaml").write_text(
        """
global_llm_configs:
  - id: -2
    name: Azure GPT-4o
    provider: AZURE_OPENAI
    model_name: gpt-4o
    api_key: sk-test
""",
        encoding="utf-8",
    )

    from app import config as config_module

    monkeypatch.setattr(config_module, "BASE_DIR", tmp_path)

    configs = config_module.load_global_llm_configs()
    assert len(configs) == 1
    assert configs[0]["supports_image_input"] is True


def test_yaml_loader_respects_explicit_supports_image_input(tmp_path, monkeypatch):
    yaml_dir = tmp_path / "app" / "config"
    yaml_dir.mkdir(parents=True)
    (yaml_dir / "global_llm_config.yaml").write_text(
        """
global_llm_configs:
  - id: -1
    name: GPT-4o
    provider: OPENAI
    model_name: gpt-4o
    api_key: sk-test
    supports_image_input: false
""",
        encoding="utf-8",
    )

    from app import config as config_module

    monkeypatch.setattr(config_module, "BASE_DIR", tmp_path)

    configs = config_module.load_global_llm_configs()
    assert len(configs) == 1
    # Operator override always wins, even against LiteLLM's True.
    assert configs[0]["supports_image_input"] is False


def test_yaml_loader_unknown_model_default_allows(tmp_path, monkeypatch):
    """Unknown / unmapped model in YAML: default-allow. The streaming
    safety net (which requires an explicit False from LiteLLM) is the
    only place a real block happens, so we don't lock the user out of
    a freshly added third-party entry the catalog can't introspect."""
    yaml_dir = tmp_path / "app" / "config"
    yaml_dir.mkdir(parents=True)
    (yaml_dir / "global_llm_config.yaml").write_text(
        """
global_llm_configs:
  - id: -1
    name: Some Brand New Model
    provider: CUSTOM
    custom_provider: brand_new_proxy
    model_name: brand-new-model-x9
    api_key: sk-test
""",
        encoding="utf-8",
    )

    from app import config as config_module

    monkeypatch.setattr(config_module, "BASE_DIR", tmp_path)

    configs = config_module.load_global_llm_configs()
    assert len(configs) == 1
    assert configs[0]["supports_image_input"] is True


# ---------------------------------------------------------------------------
# AgentConfig threads the flag through both YAML and Auto / BYOK
# ---------------------------------------------------------------------------


def test_agent_config_from_yaml_explicit_overrides_resolver():
    from app.agents.new_chat.llm_config import AgentConfig

    cfg_text_only = AgentConfig.from_yaml_config(
        {
            "id": -1,
            "name": "Text Only Override",
            "provider": "openai",
            "model_name": "gpt-4o",  # Capable per LiteLLM, but operator says no.
            "api_key": "sk-test",
            "supports_image_input": False,
        }
    )
    cfg_explicit_vision = AgentConfig.from_yaml_config(
        {
            "id": -2,
            "name": "GPT-4o",
            "provider": "openai",
            "model_name": "gpt-4o",
            "api_key": "sk-test",
            "supports_image_input": True,
        }
    )
    assert cfg_text_only.supports_image_input is False
    assert cfg_explicit_vision.supports_image_input is True


def test_agent_config_from_yaml_unannotated_uses_resolver():
    """Without an explicit YAML key, AgentConfig defers to the catalog
    resolver — for ``gpt-4o`` LiteLLM's map says supports_vision=True."""
    from app.agents.new_chat.llm_config import AgentConfig

    cfg = AgentConfig.from_yaml_config(
        {
            "id": -1,
            "name": "GPT-4o (no override)",
            "provider": "openai",
            "model_name": "gpt-4o",
            "api_key": "sk-test",
        }
    )
    assert cfg.supports_image_input is True


def test_agent_config_auto_mode_supports_image_input():
    """Auto routes across the pool. We optimistically allow image input
    so users can keep their selection on Auto with a vision-capable
    deployment somewhere in the pool. The router's own ``allowed_fails``
    handles non-vision deployments via fallback."""
    from app.agents.new_chat.llm_config import AgentConfig

    auto = AgentConfig.from_auto_mode()
    assert auto.supports_image_input is True

@ -0,0 +1,515 @@
"""Cost-based premium quota unit tests.
|
||||
|
||||
Covers the USD-micro behaviour added in migration 140:
|
||||
|
||||
* ``TurnTokenAccumulator.total_cost_micros`` sums ``cost_micros`` across all
|
||||
calls in a turn — used as the debit amount when ``agent_config.is_premium``
|
||||
is true, regardless of which underlying model produced each call. This
|
||||
preserves the prior "premium turn → all calls in turn count" rule from the
|
||||
token-based system.
|
||||
* ``estimate_call_reserve_micros`` scales linearly with model pricing,
|
||||
clamps to a sane floor when pricing is unknown, and respects the
|
||||
``QUOTA_MAX_RESERVE_MICROS`` ceiling so a misconfigured "$1000/M" entry
|
||||
can't lock the whole balance on one call.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
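

# --- Illustrative sketch (not part of the change set) ----------------------
# The clamp behaviour under test, reconstructed from the assertions below.
# Floor/ceiling values mirror the test module; applying the reserve-token
# budget to both the input and output rate is an inference from the
# "~$0.36 for 4000 tokens at Opus pricing" expectation, not gospel.
def _sketch_estimate_call_reserve_micros(
    base_model: str,
    quota_reserve_tokens: int,
    *,
    min_micros: int = 100,
    max_micros: int = 1_000_000,
) -> int:
    import litellm  # local import keeps the sketch self-contained

    try:
        info = litellm.get_model_info(base_model)
    except Exception:
        return min_micros  # unknown model: tiny but non-zero reservation
    per_token = (info.get("input_cost_per_token") or 0) + (
        info.get("output_cost_per_token") or 0
    )
    micros = int(per_token * quota_reserve_tokens * 1_000_000)
    return max(min_micros, min(micros, max_micros))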
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# TurnTokenAccumulator — premium-turn debit semantics
# ---------------------------------------------------------------------------


def test_total_cost_micros_sums_premium_and_free_calls():
    """A premium turn that also called a free sub-agent debits the union.

    The plan deliberately preserved the existing "premium turn → all calls
    count" behaviour because per-call premium filtering relied on
    ``LLMRouterService._premium_model_strings`` which only covers router-pool
    deployments. ``total_cost_micros`` therefore must include free-model
    calls (whose ``cost_micros`` is typically ``0``) as well as the premium
    call's actual provider cost.
    """
    from app.services.token_tracking_service import TurnTokenAccumulator

    acc = TurnTokenAccumulator()
    # Premium model (e.g. claude-opus): non-zero cost.
    acc.add(
        model="anthropic/claude-3-5-sonnet",
        prompt_tokens=1200,
        completion_tokens=400,
        total_tokens=1600,
        cost_micros=12_345,
    )
    # Free sub-agent (e.g. title-gen on a free model): zero cost.
    acc.add(
        model="gpt-4o-mini",
        prompt_tokens=120,
        completion_tokens=20,
        total_tokens=140,
        cost_micros=0,
    )
    # A second premium-priced call within the same turn.
    acc.add(
        model="anthropic/claude-3-5-sonnet",
        prompt_tokens=800,
        completion_tokens=200,
        total_tokens=1000,
        cost_micros=7_500,
    )

    assert acc.total_cost_micros == 12_345 + 0 + 7_500
    # Token totals stay correct so the FE display path still works.
    assert acc.grand_total == 1600 + 140 + 1000


def test_total_cost_micros_zero_when_no_calls():
    """An empty accumulator must report zero cost (no division-by-zero, no None)."""
    from app.services.token_tracking_service import TurnTokenAccumulator

    acc = TurnTokenAccumulator()
    assert acc.total_cost_micros == 0
    assert acc.grand_total == 0


def test_per_message_summary_groups_cost_by_model():
    """``per_message_summary`` must accumulate ``cost_micros`` per model so the
    SSE ``model_breakdown`` payload reports actual USD spend per provider.
    """
    from app.services.token_tracking_service import TurnTokenAccumulator

    acc = TurnTokenAccumulator()
    acc.add(
        model="claude-3-5-sonnet",
        prompt_tokens=100,
        completion_tokens=50,
        total_tokens=150,
        cost_micros=4_000,
    )
    acc.add(
        model="claude-3-5-sonnet",
        prompt_tokens=200,
        completion_tokens=100,
        total_tokens=300,
        cost_micros=8_000,
    )
    acc.add(
        model="gpt-4o-mini",
        prompt_tokens=50,
        completion_tokens=10,
        total_tokens=60,
        cost_micros=200,
    )

    summary = acc.per_message_summary()
    assert summary["claude-3-5-sonnet"]["cost_micros"] == 12_000
    assert summary["claude-3-5-sonnet"]["total_tokens"] == 450
    assert summary["gpt-4o-mini"]["cost_micros"] == 200


def test_serialized_calls_includes_cost_micros():
    """``serialized_calls`` is what flows into the SSE ``call_details``
    payload; ``cost_micros`` must be present on each entry so the FE
    message-info dropdown can render per-call USD.
    """
    from app.services.token_tracking_service import TurnTokenAccumulator

    acc = TurnTokenAccumulator()
    acc.add(
        model="m",
        prompt_tokens=1,
        completion_tokens=1,
        total_tokens=2,
        cost_micros=42,
    )
    serialized = acc.serialized_calls()
    assert serialized == [
        {
            "model": "m",
            "prompt_tokens": 1,
            "completion_tokens": 1,
            "total_tokens": 2,
            "cost_micros": 42,
            "call_kind": "chat",
        }
    ]


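# Illustrative sketch, not the shipped class: the smallest accumulator that
# satisfies the assertions above. The dataclass layout and field ordering
# are assumptions; only the method and attribute names come from the tests.
from dataclasses import dataclass, field


@dataclass
class _LLMCall:
    model: str
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
    cost_micros: int
    call_kind: str = "chat"


@dataclass
class _SketchAccumulator:
    calls: list[_LLMCall] = field(default_factory=list)

    def add(self, *, model, prompt_tokens, completion_tokens, total_tokens,
            cost_micros, call_kind="chat"):
        self.calls.append(_LLMCall(model, prompt_tokens, completion_tokens,
                                   total_tokens, cost_micros, call_kind))

    @property
    def total_cost_micros(self) -> int:
        # Premium-turn semantics: every call in the turn is debited,
        # including zero-cost free-model calls.
        return sum(c.cost_micros for c in self.calls)

    @property
    def grand_total(self) -> int:
        return sum(c.total_tokens for c in self.calls)

    def per_message_summary(self) -> dict[str, dict[str, int]]:
        out: dict[str, dict[str, int]] = {}
        for c in self.calls:
            row = out.setdefault(c.model, {"total_tokens": 0, "cost_micros": 0})
            row["total_tokens"] += c.total_tokens
            row["cost_micros"] += c.cost_micros
        return out

    def serialized_calls(self) -> list[dict]:
        # Dataclass __dict__ preserves declaration order, matching the
        # expected SSE payload shape in the test above.
        return [vars(c) for c in self.calls]

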
# ---------------------------------------------------------------------------
# estimate_call_reserve_micros — sizing and clamping
# ---------------------------------------------------------------------------


def test_reserve_returns_floor_when_model_unknown(monkeypatch):
    """If LiteLLM doesn't know the model, ``get_model_info`` raises and the
    helper falls back to the 100-micro floor — small enough that a user with
    $0.0001 left can still send a tiny request, but non-zero so we still gate
    against an empty balance.
    """
    import litellm

    from app.services import token_quota_service

    def _raise(_name):
        raise KeyError("unknown")

    monkeypatch.setattr(litellm, "get_model_info", _raise, raising=False)

    micros = token_quota_service.estimate_call_reserve_micros(
        base_model="nonexistent-model",
        quota_reserve_tokens=4000,
    )
    assert micros == token_quota_service._QUOTA_MIN_RESERVE_MICROS
    assert micros == 100


def test_reserve_returns_floor_when_pricing_is_zero(monkeypatch):
    """LiteLLM may *return* a model with both cost-per-token fields at 0
    (pricing not yet registered). The helper must not multiply 0 × tokens
    and end up reserving 0 — it must clamp to the floor.
    """
    import litellm

    from app.services import token_quota_service

    monkeypatch.setattr(
        litellm,
        "get_model_info",
        lambda _name: {"input_cost_per_token": 0, "output_cost_per_token": 0},
        raising=False,
    )

    micros = token_quota_service.estimate_call_reserve_micros(
        base_model="some-pending-model",
        quota_reserve_tokens=4000,
    )
    assert micros == token_quota_service._QUOTA_MIN_RESERVE_MICROS


def test_reserve_scales_with_model_cost(monkeypatch):
    """A Claude-Opus-priced model with 4000 reserve_tokens reserves
    ~$0.36 = 360_000 micros. Critically this must NOT be clamped down to
    some small artificial cap — that was the bug the plan called out.
    """
    import litellm

    from app.config import config
    from app.services import token_quota_service

    monkeypatch.setattr(
        litellm,
        "get_model_info",
        lambda _name: {
            "input_cost_per_token": 15e-6,
            "output_cost_per_token": 75e-6,
        },
        raising=False,
    )
    monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_MICROS", 1_000_000, raising=False)

    micros = token_quota_service.estimate_call_reserve_micros(
        base_model="claude-3-opus",
        quota_reserve_tokens=4000,
    )
    # 4000 * (15e-6 + 75e-6) = 4000 * 90e-6 = 0.36 USD = 360_000 micros.
    assert micros == 360_000


def test_reserve_clamps_to_max_ceiling(monkeypatch):
    """A misconfigured "$1000 / M" model with 4000 reserve_tokens would
    nominally compute to $4 = 4_000_000 micros. The ceiling
    ``QUOTA_MAX_RESERVE_MICROS`` must clamp that so a bad pricing entry
    can't lock the user's whole balance on one call.
    """
    import litellm

    from app.config import config
    from app.services import token_quota_service

    monkeypatch.setattr(
        litellm,
        "get_model_info",
        lambda _name: {
            "input_cost_per_token": 1e-3,
            "output_cost_per_token": 0,
        },
        raising=False,
    )
    monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_MICROS", 1_000_000, raising=False)

    micros = token_quota_service.estimate_call_reserve_micros(
        base_model="oops-misconfigured",
        quota_reserve_tokens=4000,
    )
    assert micros == 1_000_000


def test_reserve_uses_default_when_quota_reserve_tokens_missing(monkeypatch):
    """Per-config ``quota_reserve_tokens`` is optional; when ``None`` or
    zero, the helper must fall back to the global ``QUOTA_MAX_RESERVE_PER_CALL``
    so anonymous-style configs still reserve the operator-tunable default.
    """
    import litellm

    from app.config import config
    from app.services import token_quota_service

    monkeypatch.setattr(
        litellm,
        "get_model_info",
        lambda _name: {
            "input_cost_per_token": 1e-6,
            "output_cost_per_token": 1e-6,
        },
        raising=False,
    )
    monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_PER_CALL", 2000, raising=False)
    monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_MICROS", 1_000_000, raising=False)

    # 2000 * (1e-6 + 1e-6) = 4e-3 USD = 4000 micros
    assert (
        token_quota_service.estimate_call_reserve_micros(
            base_model="cheap", quota_reserve_tokens=None
        )
        == 4000
    )
    assert (
        token_quota_service.estimate_call_reserve_micros(
            base_model="cheap", quota_reserve_tokens=0
        )
        == 4000
    )


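# Assumption-based sketch of the reserve-sizing logic pinned above, not the
# shipped helper. The floor, ceiling, and default-token fallback mirror the
# assertions; the config attribute names come from the monkeypatches.
def _sketch_estimate_call_reserve_micros(
    base_model: str, quota_reserve_tokens: int | None
) -> int:
    import litellm

    from app.config import config

    _FLOOR_MICROS = 100  # mirrors _QUOTA_MIN_RESERVE_MICROS in the real module

    tokens = quota_reserve_tokens or config.QUOTA_MAX_RESERVE_PER_CALL
    try:
        info = litellm.get_model_info(base_model)
        per_token_usd = (info.get("input_cost_per_token") or 0) + (
            info.get("output_cost_per_token") or 0
        )
    except Exception:
        per_token_usd = 0  # unknown model: fall through to the floor

    micros = int(tokens * per_token_usd * 1_000_000)
    micros = max(micros, _FLOOR_MICROS)  # never reserve 0
    return min(micros, config.QUOTA_MAX_RESERVE_MICROS)  # never lock the balance

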
# ---------------------------------------------------------------------------
# TokenTrackingCallback — image vs chat usage shape
# ---------------------------------------------------------------------------


class _FakeImageUsage:
    """Mimics LiteLLM's ``ImageUsage`` (input_tokens / output_tokens shape)."""

    def __init__(
        self,
        input_tokens: int = 0,
        output_tokens: int = 0,
        total_tokens: int | None = None,
    ) -> None:
        self.input_tokens = input_tokens
        self.output_tokens = output_tokens
        if total_tokens is not None:
            self.total_tokens = total_tokens


class _FakeImageResponse:
    """Mimics LiteLLM's ``ImageResponse`` — same name so the callback's
    ``type(...).__name__`` probe routes to the image branch.
    """

    def __init__(self, usage: _FakeImageUsage, response_cost: float | None = None):
        self.usage = usage
        if response_cost is not None:
            self._hidden_params = {"response_cost": response_cost}


# Re-tag the helper class as ``ImageResponse`` for the type-name probe in
# the callback. We can't simply name the class ``ImageResponse`` because
# the test runner sometimes imports test modules in surprising ways and
# we want to be explicit.
_FakeImageResponse.__name__ = "ImageResponse"


class _FakeChatUsage:
    def __init__(self, prompt: int, completion: int):
        self.prompt_tokens = prompt
        self.completion_tokens = completion
        self.total_tokens = prompt + completion


class _FakeChatResponse:
    def __init__(self, usage: _FakeChatUsage):
        self.usage = usage


@pytest.mark.asyncio
async def test_callback_reads_image_usage_input_output_tokens():
    """``TokenTrackingCallback`` must read ``input_tokens``/``output_tokens``
    for ``ImageResponse`` (LiteLLM's ImageUsage shape), NOT
    ``prompt_tokens``/``completion_tokens``, which is the chat shape.
    """
    from app.services.token_tracking_service import (
        TokenTrackingCallback,
        scoped_turn,
    )

    cb = TokenTrackingCallback()
    response = _FakeImageResponse(
        usage=_FakeImageUsage(input_tokens=42, output_tokens=8, total_tokens=50),
        response_cost=0.04,  # $0.04 per image
    )

    async with scoped_turn() as acc:
        await cb.async_log_success_event(
            kwargs={"model": "openai/gpt-image-1", "response_cost": 0.04},
            response_obj=response,
            start_time=None,
            end_time=None,
        )
    assert len(acc.calls) == 1
    call = acc.calls[0]
    assert call.prompt_tokens == 42
    assert call.completion_tokens == 8
    assert call.total_tokens == 50
    # 0.04 USD = 40_000 micros
    assert call.cost_micros == 40_000
    assert call.call_kind == "image_generation"


@pytest.mark.asyncio
async def test_callback_chat_path_unchanged():
    """Chat responses must still read prompt_tokens/completion_tokens."""
    from app.services.token_tracking_service import (
        TokenTrackingCallback,
        scoped_turn,
    )

    cb = TokenTrackingCallback()
    response = _FakeChatResponse(_FakeChatUsage(prompt=120, completion=30))

    async with scoped_turn() as acc:
        await cb.async_log_success_event(
            kwargs={
                "model": "openrouter/anthropic/claude-3-5-sonnet",
                "response_cost": 0.0036,
            },
            response_obj=response,
            start_time=None,
            end_time=None,
        )
    assert len(acc.calls) == 1
    call = acc.calls[0]
    assert call.prompt_tokens == 120
    assert call.completion_tokens == 30
    assert call.total_tokens == 150
    assert call.cost_micros == 3_600
    assert call.call_kind == "chat"


@pytest.mark.asyncio
async def test_callback_image_missing_response_cost_falls_back_to_zero(monkeypatch):
    """When OpenRouter omits ``usage.cost``, LiteLLM's
    ``default_image_cost_calculator`` raises. The defensive image branch in
    ``_extract_cost_usd`` must NOT call ``cost_per_token`` (which is
    chat-shaped and would raise too) — it returns 0 with a WARNING log.
    """
    import litellm

    from app.services.token_tracking_service import (
        TokenTrackingCallback,
        scoped_turn,
    )

    # Force completion_cost to raise the same way OpenRouter image-gen fails.
    def _boom(*_args, **_kwargs):
        raise ValueError("model_cost: missing entry for openrouter image model")

    monkeypatch.setattr(litellm, "completion_cost", _boom, raising=False)

    # And make sure cost_per_token is NEVER called for the image path —
    # if it were, our ``is_image=True`` branch is broken.
    cost_per_token_calls: list = []

    def _record_cost_per_token(**kwargs):
        cost_per_token_calls.append(kwargs)
        return (0.0, 0.0)

    monkeypatch.setattr(
        litellm, "cost_per_token", _record_cost_per_token, raising=False
    )

    cb = TokenTrackingCallback()
    response = _FakeImageResponse(
        usage=_FakeImageUsage(input_tokens=7, output_tokens=0)
    )

    async with scoped_turn() as acc:
        await cb.async_log_success_event(
            kwargs={"model": "openrouter/google/gemini-2.5-flash-image"},
            response_obj=response,
            start_time=None,
            end_time=None,
        )

    assert len(acc.calls) == 1
    assert acc.calls[0].cost_micros == 0
    assert acc.calls[0].call_kind == "image_generation"
    # The image branch must short-circuit before cost_per_token.
    assert cost_per_token_calls == []


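# Sketch of the usage-shape probe the tests above exercise. This is an
# assumed reconstruction; the real callback lives in token_tracking_service.
def _sketch_extract_tokens(response_obj) -> tuple[int, int]:
    """Return (prompt_like, completion_like) for chat AND image responses."""
    usage = getattr(response_obj, "usage", None)
    if usage is None:
        return (0, 0)
    if type(response_obj).__name__ == "ImageResponse":
        # ImageUsage shape: input_tokens / output_tokens.
        return (
            getattr(usage, "input_tokens", 0) or 0,
            getattr(usage, "output_tokens", 0) or 0,
        )
    # Chat shape: prompt_tokens / completion_tokens.
    return (
        getattr(usage, "prompt_tokens", 0) or 0,
        getattr(usage, "completion_tokens", 0) or 0,
    )

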
# ---------------------------------------------------------------------------
# scoped_turn — ContextVar reset semantics (issue B)
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_scoped_turn_restores_outer_accumulator():
    """``scoped_turn`` must restore the previous ContextVar value on exit
    so a per-call wrapper inside an outer chat turn doesn't leak its
    accumulator outward (which would cause a double debit at chat-turn exit).
    """
    from app.services.token_tracking_service import (
        get_current_accumulator,
        scoped_turn,
        start_turn,
    )

    outer = start_turn()
    assert get_current_accumulator() is outer

    async with scoped_turn() as inner:
        assert get_current_accumulator() is inner
        assert inner is not outer
        inner.add(
            model="x",
            prompt_tokens=1,
            completion_tokens=1,
            total_tokens=2,
            cost_micros=5,
        )

    # After exit the outer accumulator is restored unchanged.
    assert get_current_accumulator() is outer
    assert outer.total_cost_micros == 0
    assert len(outer.calls) == 0
    # The inner accumulator captured the call but didn't bleed into outer.
    assert inner.total_cost_micros == 5


@pytest.mark.asyncio
async def test_scoped_turn_resets_to_none_when_no_outer():
    """Running ``scoped_turn`` outside any chat turn (e.g. a background
    indexing job) must leave the ContextVar at ``None`` on exit so the
    next *unrelated* request starts clean.
    """
    from app.services.token_tracking_service import (
        _turn_accumulator,
        get_current_accumulator,
        scoped_turn,
    )

    # The ContextVar default is None in a fresh, isolated test context. We
    # simulate "no outer" explicitly to be robust against test order.
    token = _turn_accumulator.set(None)
    try:
        assert get_current_accumulator() is None
        async with scoped_turn() as acc:
            assert get_current_accumulator() is acc
        assert get_current_accumulator() is None
    finally:
        _turn_accumulator.reset(token)
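

# Hedged sketch of the ContextVar save/restore contract pinned above. The
# names mirror the imports in the tests; everything else is assumed, and
# _SketchAccumulator comes from the sketch after the accumulator tests.
import contextlib
from contextvars import ContextVar

_sketch_turn_accumulator: ContextVar[object | None] = ContextVar(
    "_sketch_turn_accumulator", default=None
)


@contextlib.asynccontextmanager
async def _sketch_scoped_turn():
    acc = _SketchAccumulator()
    token = _sketch_turn_accumulator.set(acc)
    try:
        yield acc
    finally:
        # reset() restores the PREVIOUS value (outer accumulator or None)
        # rather than unconditionally clearing the var.
        _sketch_turn_accumulator.reset(token)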
@ -0,0 +1,89 @@
"""Defense-in-depth: vision-LLM resolution must not leak ``api_base``
defaults from ``litellm.api_base`` either.

Vision shares the same shape as image-gen — global YAML / OpenRouter
dynamic configs ship ``api_base=""`` and the pre-fix ``get_vision_llm``
call sites would silently drop the empty string and inherit
``AZURE_OPENAI_ENDPOINT``. ``ChatLiteLLM(...)`` doesn't 404 on
construction, so we test the kwargs we hand to it instead.
"""

from __future__ import annotations

from unittest.mock import AsyncMock, MagicMock, patch

import pytest

pytestmark = pytest.mark.unit


@pytest.mark.asyncio
async def test_get_vision_llm_global_openrouter_sets_api_base():
    """Global negative-ID branch: an OpenRouter vision config with
    ``api_base=""`` must end up calling ``SanitizedChatLiteLLM`` with
    ``api_base="https://openrouter.ai/api/v1"`` — never an empty string,
    never silently absent."""
    from app.services import llm_service

    cfg = {
        "id": -30_001,
        "name": "GPT-4o Vision (OpenRouter)",
        "provider": "OPENROUTER",
        "model_name": "openai/gpt-4o",
        "api_key": "sk-or-test",
        "api_base": "",
        "api_version": None,
        "litellm_params": {},
        "billing_tier": "free",
    }

    search_space = MagicMock()
    search_space.id = 1
    search_space.user_id = "user-x"
    search_space.vision_llm_config_id = cfg["id"]

    session = AsyncMock()
    scalars = MagicMock()
    scalars.first.return_value = search_space
    result = MagicMock()
    result.scalars.return_value = scalars
    session.execute.return_value = result

    captured: dict = {}

    class FakeSanitized:
        def __init__(self, **kwargs):
            captured.update(kwargs)

    with (
        patch(
            "app.services.vision_llm_router_service.get_global_vision_llm_config",
            return_value=cfg,
        ),
        patch(
            "app.agents.new_chat.llm_config.SanitizedChatLiteLLM",
            new=FakeSanitized,
        ),
    ):
        await llm_service.get_vision_llm(session=session, search_space_id=1)

    assert captured.get("api_base") == "https://openrouter.ai/api/v1"
    assert captured["model"] == "openrouter/openai/gpt-4o"


def test_vision_router_deployment_sets_api_base_when_config_empty():
    """Auto-mode vision router: deployments are fed to ``litellm.Router``,
    so the resolver has to apply at deployment-construction time too."""
    from app.services.vision_llm_router_service import VisionLLMRouterService

    deployment = VisionLLMRouterService._config_to_deployment(
        {
            "model_name": "openai/gpt-4o",
            "provider": "OPENROUTER",
            "api_key": "sk-or-test",
            "api_base": "",
        }
    )
    assert deployment is not None
    assert deployment["litellm_params"]["api_base"] == "https://openrouter.ai/api/v1"
    assert deployment["litellm_params"]["model"] == "openrouter/openai/gpt-4o"
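

# Sketch of the empty-api_base resolution both tests pin down. The
# provider-to-default mapping is an assumption; only the OpenRouter entry
# is confirmed by the assertions above.
_PROVIDER_DEFAULT_API_BASE = {
    "OPENROUTER": "https://openrouter.ai/api/v1",
}


def _sketch_resolve_api_base(provider: str, api_base: str | None) -> str | None:
    # Treat "" the same as None: an empty string must never be forwarded,
    # or litellm would fall back to litellm.api_base / env defaults.
    if api_base:
        return api_base
    return _PROVIDER_DEFAULT_API_BASE.get(provider)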
526
surfsense_backend/tests/unit/tasks/chat/test_content_builder.py
Normal file
@ -0,0 +1,526 @@
"""Unit tests for ``AssistantContentBuilder``.
|
||||
|
||||
Pins the in-memory ``ContentPart[]`` projection so the JSONB the server
|
||||
persists matches what the frontend renders live (see
|
||||
``surfsense_web/lib/chat/streaming-state.ts``). Every test asserts both
|
||||
the structural shape of ``snapshot()`` and that the snapshot is
|
||||
``json.dumps``-safe (the streaming finally block writes it directly to
|
||||
``new_chat_messages.content`` without an explicit serialization round
|
||||
trip).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from app.tasks.chat.content_builder import AssistantContentBuilder
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _assert_jsonb_safe(parts: list[dict]) -> None:
|
||||
"""Sanity check: any snapshot must round-trip through ``json.dumps``."""
|
||||
serialized = json.dumps(parts)
|
||||
assert json.loads(serialized) == parts
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Text turns
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTextOnly:
|
||||
def test_single_text_block_collapses_consecutive_deltas(self):
|
||||
b = AssistantContentBuilder()
|
||||
b.on_text_start("text-1")
|
||||
b.on_text_delta("text-1", "Hello")
|
||||
b.on_text_delta("text-1", " ")
|
||||
b.on_text_delta("text-1", "world")
|
||||
b.on_text_end("text-1")
|
||||
|
||||
snap = b.snapshot()
|
||||
assert snap == [{"type": "text", "text": "Hello world"}]
|
||||
assert not b.is_empty()
|
||||
_assert_jsonb_safe(snap)
|
||||
|
||||
def test_empty_text_start_end_pair_leaves_no_part(self):
|
||||
# Mirrors the FE: a text-start without any deltas should
|
||||
# not materialise an empty ``{"type":"text","text":""}`` part.
|
||||
b = AssistantContentBuilder()
|
||||
b.on_text_start("text-1")
|
||||
b.on_text_end("text-1")
|
||||
|
||||
assert b.snapshot() == []
|
||||
assert b.is_empty()
|
||||
|
||||
def test_text_after_text_end_starts_fresh_part(self):
|
||||
b = AssistantContentBuilder()
|
||||
b.on_text_start("text-1")
|
||||
b.on_text_delta("text-1", "first")
|
||||
b.on_text_end("text-1")
|
||||
|
||||
b.on_text_start("text-2")
|
||||
b.on_text_delta("text-2", "second")
|
||||
b.on_text_end("text-2")
|
||||
|
||||
snap = b.snapshot()
|
||||
assert snap == [
|
||||
{"type": "text", "text": "first"},
|
||||
{"type": "text", "text": "second"},
|
||||
]
|
||||
|
||||
|
||||
class TestReasoningThenText:
|
||||
def test_reasoning_followed_by_text_yields_two_parts_in_order(self):
|
||||
b = AssistantContentBuilder()
|
||||
b.on_reasoning_start("r-1")
|
||||
b.on_reasoning_delta("r-1", "Considering options...")
|
||||
b.on_reasoning_end("r-1")
|
||||
|
||||
b.on_text_start("text-1")
|
||||
b.on_text_delta("text-1", "The answer is 42.")
|
||||
b.on_text_end("text-1")
|
||||
|
||||
snap = b.snapshot()
|
||||
assert snap == [
|
||||
{"type": "reasoning", "text": "Considering options..."},
|
||||
{"type": "text", "text": "The answer is 42."},
|
||||
]
|
||||
_assert_jsonb_safe(snap)
|
||||
|
||||
def test_text_delta_after_reasoning_implicitly_closes_reasoning(self):
|
||||
# Mirrors FE ``appendText``: a text delta arriving while a
|
||||
# reasoning part is "active" still produces a fresh text
|
||||
# part, never appends into the reasoning block.
|
||||
b = AssistantContentBuilder()
|
||||
b.on_reasoning_start("r-1")
|
||||
b.on_reasoning_delta("r-1", "thinking")
|
||||
# No explicit reasoning_end — text delta should close it.
|
||||
b.on_text_delta("text-1", "answer")
|
||||
|
||||
snap = b.snapshot()
|
||||
assert snap == [
|
||||
{"type": "reasoning", "text": "thinking"},
|
||||
{"type": "text", "text": "answer"},
|
||||
]
|
||||
|
||||
|
||||
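# Minimal sketch of the delta-collapsing behaviour pinned above. An assumed
# shape, not the real builder: consecutive deltas for the active text block
# append in place; a delta after on_text_end opens a new part.
class _SketchTextBuilder:
    def __init__(self) -> None:
        self.parts: list[dict] = []
        self._active_text: dict | None = None

    def on_text_start(self, _block_id: str) -> None:
        # Part creation is deferred until the first delta so an empty
        # start/end pair leaves no ``{"type": "text", "text": ""}`` part.
        self._active_text = None

    def on_text_delta(self, _block_id: str, delta: str) -> None:
        if self._active_text is None:
            self._active_text = {"type": "text", "text": ""}
            self.parts.append(self._active_text)
        self._active_text["text"] += delta

    def on_text_end(self, _block_id: str) -> None:
        self._active_text = None

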
# ---------------------------------------------------------------------------
# Tool calls
# ---------------------------------------------------------------------------


class TestToolHeavyTurn:
    def test_full_tool_lifecycle_produces_complete_tool_call_part(self):
        b = AssistantContentBuilder()
        # Some narration before the tool fires.
        b.on_text_start("text-1")
        b.on_text_delta("text-1", "Searching...")
        b.on_text_end("text-1")

        b.on_tool_input_start(
            ui_id="call_run123",
            tool_name="web_search",
            langchain_tool_call_id="lc_tool_abc",
        )
        b.on_tool_input_delta("call_run123", '{"query":')
        b.on_tool_input_delta("call_run123", '"surfsense"}')
        b.on_tool_input_available(
            ui_id="call_run123",
            tool_name="web_search",
            args={"query": "surfsense"},
            langchain_tool_call_id="lc_tool_abc",
        )
        b.on_tool_output_available(
            ui_id="call_run123",
            output={"status": "completed", "citations": {}},
            langchain_tool_call_id="lc_tool_abc",
        )

        snap = b.snapshot()
        assert snap[0] == {"type": "text", "text": "Searching..."}
        tool_part = snap[1]
        assert tool_part["type"] == "tool-call"
        assert tool_part["toolCallId"] == "call_run123"
        assert tool_part["toolName"] == "web_search"
        assert tool_part["args"] == {"query": "surfsense"}
        # ``argsText`` is the pretty-printed final JSON, not the raw
        # streaming buffer (FE ``stream-pipeline.ts:128``).
        assert tool_part["argsText"] == json.dumps(
            {"query": "surfsense"}, indent=2, ensure_ascii=False
        )
        assert tool_part["langchainToolCallId"] == "lc_tool_abc"
        assert tool_part["result"] == {"status": "completed", "citations": {}}
        _assert_jsonb_safe(snap)

    def test_tool_input_available_without_prior_start_creates_card(self):
        # Legacy / parity_v2-OFF path: tool-input-available may be
        # emitted without a prior tool-input-start (no streamed
        # tool_call_chunks). The card should still be created.
        b = AssistantContentBuilder()
        b.on_tool_input_available(
            ui_id="call_run42",
            tool_name="grep",
            args={"pattern": "TODO"},
            langchain_tool_call_id="lc_x",
        )
        b.on_tool_output_available(
            ui_id="call_run42",
            output={"matches": 3},
            langchain_tool_call_id="lc_x",
        )

        snap = b.snapshot()
        assert len(snap) == 1
        part = snap[0]
        assert part["type"] == "tool-call"
        assert part["toolCallId"] == "call_run42"
        assert part["args"] == {"pattern": "TODO"}
        assert part["langchainToolCallId"] == "lc_x"
        assert part["result"] == {"matches": 3}

    def test_tool_input_start_idempotent_for_same_ui_id(self):
        # parity_v2: tool-input-start can fire from BOTH the chunk
        # registration path AND the canonical ``on_tool_start`` path.
        # The second call must not create a duplicate part.
        b = AssistantContentBuilder()
        b.on_tool_input_start("call_x", "ls", "lc_x")
        b.on_tool_input_start("call_x", "ls", "lc_x")
        snap = b.snapshot()
        assert len(snap) == 1

    def test_tool_input_delta_without_prior_start_is_silently_dropped(self):
        b = AssistantContentBuilder()
        b.on_tool_input_delta("call_unknown", '{"orphan": "delta"}')
        assert b.snapshot() == []

    def test_langchain_tool_call_id_backfills_only_when_absent(self):
        b = AssistantContentBuilder()
        b.on_tool_input_start("call_x", "ls", "lc_first")
        # Late event must NOT clobber an already-set lc id.
        b.on_tool_input_start("call_x", "ls", "lc_late")
        snap = b.snapshot()
        assert snap[0]["langchainToolCallId"] == "lc_first"

    def test_args_text_streaming_buffer_reflects_concatenation(self):
        b = AssistantContentBuilder()
        b.on_tool_input_start("call_x", "save_doc", "lc_y")
        b.on_tool_input_delta("call_x", '{"title":')
        b.on_tool_input_delta("call_x", '"Hi"}')
        # Snapshot mid-stream should see the partial buffer (the FE
        # tolerates invalid JSON and renders it as-is).
        mid = b.snapshot()
        assert mid[0]["argsText"] == '{"title":"Hi"}'
        # Then tool-input-available replaces with pretty-printed.
        b.on_tool_input_available(
            "call_x",
            "save_doc",
            {"title": "Hi"},
            "lc_y",
        )
        final = b.snapshot()
        assert final[0]["argsText"] == json.dumps(
            {"title": "Hi"}, indent=2, ensure_ascii=False
        )


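# Assumed sketch of the idempotent tool-card bookkeeping the tests above
# pin down: one card per ui_id, with the LangChain id backfilled only when
# absent and orphan deltas dropped.
class _SketchToolCards:
    def __init__(self) -> None:
        self._by_ui_id: dict[str, dict] = {}
        self.parts: list[dict] = []

    def on_tool_input_start(self, ui_id, tool_name, lc_id=None) -> None:
        card = self._by_ui_id.get(ui_id)
        if card is None:
            card = {"type": "tool-call", "toolCallId": ui_id,
                    "toolName": tool_name, "argsText": ""}
            self._by_ui_id[ui_id] = card
            self.parts.append(card)
        # Backfill, never clobber, the LangChain id.
        if lc_id and "langchainToolCallId" not in card:
            card["langchainToolCallId"] = lc_id

    def on_tool_input_delta(self, ui_id, delta: str) -> None:
        card = self._by_ui_id.get(ui_id)
        if card is None:
            return  # orphan delta: silently dropped
        card["argsText"] += delta

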
# ---------------------------------------------------------------------------
# Thinking steps & separators
# ---------------------------------------------------------------------------


class TestThinkingSteps:
    def test_first_thinking_step_unshifts_singleton_to_index_zero(self):
        b = AssistantContentBuilder()
        b.on_text_start("text-1")
        b.on_text_delta("text-1", "Hello")
        b.on_text_end("text-1")

        b.on_thinking_step("step-1", "Analyzing", "in_progress", ["item-a"])

        snap = b.snapshot()
        # Singleton goes to index 0 (FE ``updateThinkingSteps`` unshift).
        assert snap[0]["type"] == "data-thinking-steps"
        assert snap[0]["data"]["steps"] == [
            {
                "id": "step-1",
                "title": "Analyzing",
                "status": "in_progress",
                "items": ["item-a"],
            }
        ]
        assert snap[1] == {"type": "text", "text": "Hello"}

    def test_subsequent_thinking_steps_mutate_the_singleton_in_place(self):
        b = AssistantContentBuilder()
        b.on_thinking_step("step-1", "Analyzing", "in_progress", [])
        b.on_thinking_step("step-2", "Searching", "in_progress", ["q"])
        b.on_thinking_step("step-1", "Analyzing", "completed", ["done"])

        snap = b.snapshot()
        assert len([p for p in snap if p["type"] == "data-thinking-steps"]) == 1
        steps = snap[0]["data"]["steps"]
        assert len(steps) == 2
        assert steps[0]["id"] == "step-1"
        assert steps[0]["status"] == "completed"
        assert steps[0]["items"] == ["done"]
        assert steps[1]["id"] == "step-2"

    def test_thinking_step_with_text_continues_appending_to_text(self):
        b = AssistantContentBuilder()
        b.on_text_start("text-1")
        b.on_text_delta("text-1", "first")

        # Thinking step inserts at index 0, bumps text idx from 0 to 1.
        b.on_thinking_step("step-1", "Working", "in_progress", [])
        b.on_text_delta("text-1", " second")

        snap = b.snapshot()
        text_parts = [p for p in snap if p["type"] == "text"]
        assert text_parts == [{"type": "text", "text": "first second"}]

    def test_thinking_step_without_id_is_dropped(self):
        b = AssistantContentBuilder()
        b.on_thinking_step("", "noop", "in_progress", None)
        assert b.snapshot() == []
        assert b.is_empty()


class TestStepSeparators:
    def test_separator_no_op_before_any_content(self):
        b = AssistantContentBuilder()
        b.on_step_separator()
        assert b.snapshot() == []

    def test_separator_after_text_appends_with_step_index_zero(self):
        b = AssistantContentBuilder()
        b.on_text_start("text-1")
        b.on_text_delta("text-1", "first")
        b.on_text_end("text-1")

        b.on_step_separator()

        snap = b.snapshot()
        assert snap[-1] == {
            "type": "data-step-separator",
            "data": {"stepIndex": 0},
        }

    def test_consecutive_separators_collapse_to_one(self):
        b = AssistantContentBuilder()
        b.on_text_delta("text-1", "x")
        b.on_step_separator()
        b.on_step_separator()  # No-op: previous part is already a separator.
        snap = b.snapshot()
        assert sum(1 for p in snap if p["type"] == "data-step-separator") == 1

    def test_step_index_increments_across_separators(self):
        b = AssistantContentBuilder()
        b.on_text_delta("text-1", "a")
        b.on_step_separator()
        b.on_text_delta("text-2", "b")
        b.on_step_separator()
        snap = b.snapshot()
        seps = [p for p in snap if p["type"] == "data-step-separator"]
        assert [s["data"]["stepIndex"] for s in seps] == [0, 1]


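# Assumed sketch of the thinking-steps singleton: the first step unshifts
# one "data-thinking-steps" part to index 0; later steps mutate it in place.
def _sketch_on_thinking_step(parts: list[dict], step: dict) -> None:
    if not step.get("id"):
        return  # id-less steps are dropped
    singleton = next(
        (p for p in parts if p["type"] == "data-thinking-steps"), None
    )
    if singleton is None:
        singleton = {"type": "data-thinking-steps", "data": {"steps": []}}
        parts.insert(0, singleton)  # unshift to index 0
    steps = singleton["data"]["steps"]
    for i, existing in enumerate(steps):
        if existing["id"] == step["id"]:
            steps[i] = step  # update in place, keep position
            return
    steps.append(step)

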
# ---------------------------------------------------------------------------
# Interruption handling
# ---------------------------------------------------------------------------


class TestMarkInterrupted:
    def test_running_tool_calls_get_state_aborted(self):
        b = AssistantContentBuilder()
        b.on_tool_input_start("call_a", "ls", "lc_a")
        b.on_tool_input_available("call_a", "ls", {"path": "/"}, "lc_a")
        # No tool-output-available — simulates client disconnect mid-tool.

        b.mark_interrupted()

        snap = b.snapshot()
        assert snap[0]["state"] == "aborted"
        assert "result" not in snap[0]

    def test_completed_tool_calls_are_not_marked_aborted(self):
        b = AssistantContentBuilder()
        b.on_tool_input_start("call_a", "ls", "lc_a")
        b.on_tool_input_available("call_a", "ls", {"path": "/"}, "lc_a")
        b.on_tool_output_available("call_a", {"files": []}, "lc_a")

        b.mark_interrupted()

        snap = b.snapshot()
        assert "state" not in snap[0]
        assert snap[0]["result"] == {"files": []}

    def test_open_text_block_keeps_accumulated_content(self):
        b = AssistantContentBuilder()
        b.on_text_start("text-1")
        b.on_text_delta("text-1", "partial")
        # No on_text_end — disconnect mid-stream.

        b.mark_interrupted()

        snap = b.snapshot()
        assert snap == [{"type": "text", "text": "partial"}]


# ---------------------------------------------------------------------------
# is_empty / snapshot semantics
# ---------------------------------------------------------------------------


class TestIsEmpty:
    def test_fresh_builder_is_empty(self):
        assert AssistantContentBuilder().is_empty()

    def test_text_part_breaks_emptiness(self):
        b = AssistantContentBuilder()
        b.on_text_delta("text-1", "x")
        assert not b.is_empty()

    def test_tool_call_breaks_emptiness(self):
        b = AssistantContentBuilder()
        b.on_tool_input_start("call_x", "ls", None)
        assert not b.is_empty()

    def test_thinking_step_alone_does_not_break_emptiness(self):
        # Mirrors the "status marker fallback" semantic: a turn that
        # only emitted a thinking step before being interrupted should
        # still be treated as empty for finalize_assistant_turn's
        # status-marker substitution.
        b = AssistantContentBuilder()
        b.on_thinking_step("step-1", "Working", "in_progress", [])
        assert b.is_empty()

    def test_step_separator_alone_does_not_break_emptiness(self):
        b = AssistantContentBuilder()
        # Force a separator (it would normally no-op without content,
        # but we simulate the underlying state to verify is_empty is
        # not fooled by a stray separator).
        b.parts.append({"type": "data-step-separator", "data": {"stepIndex": 0}})
        assert b.is_empty()


class TestSnapshotSemantics:
    def test_snapshot_is_deep_copied_so_mutations_do_not_leak(self):
        b = AssistantContentBuilder()
        b.on_tool_input_start("call_x", "ls", "lc_x")
        b.on_tool_input_available("call_x", "ls", {"path": "/"}, "lc_x")
        snap = b.snapshot()
        # Mutate the returned snapshot — original should be untouched.
        snap[0]["args"]["mutated"] = True
        snap[0]["state"] = "tampered"

        again = b.snapshot()
        assert "mutated" not in again[0]["args"]
        assert "state" not in again[0]

    def test_snapshot_round_trips_through_json(self):
        b = AssistantContentBuilder()
        b.on_thinking_step("step-1", "Analyzing", "in_progress", ["item"])
        b.on_text_delta("text-1", "answer")
        b.on_tool_input_start("call_x", "ls", "lc_x")
        b.on_tool_input_available("call_x", "ls", {"path": "/"}, "lc_x")
        b.on_tool_output_available("call_x", {"files": ["a.txt"]}, "lc_x")
        b.on_step_separator()
        snap = b.snapshot()

        encoded = json.dumps(snap)
        assert json.loads(encoded) == snap


class TestStats:
    """``stats()`` is the perf-log handle for [PERF] [stream_*]
    finalize_payload lines. Pin the schema so an ops dashboard can
    rely on these keys being present and meaningful.
    """

    def test_fresh_builder_reports_all_zeros(self):
        b = AssistantContentBuilder()
        s = b.stats()
        assert s == {
            "parts": 0,
            "bytes": 2,  # ``[]`` is two bytes
            "text": 0,
            "reasoning": 0,
            "tool_calls": 0,
            "tool_calls_completed": 0,
            "tool_calls_aborted": 0,
            "thinking_step_parts": 0,
            "step_separators": 0,
        }

    def test_counts_each_part_type_independently(self):
        b = AssistantContentBuilder()
        b.on_text_start("t1")
        b.on_text_delta("t1", "hi")
        b.on_text_end("t1")
        b.on_reasoning_start("r1")
        b.on_reasoning_delta("r1", "thinking")
        b.on_reasoning_end("r1")
        b.on_thinking_step("step-1", "Analyzing", "completed", ["item"])
        b.on_step_separator()
        b.on_tool_input_start("call_done", "ls", "lc_done")
        b.on_tool_input_available("call_done", "ls", {}, "lc_done")
        b.on_tool_output_available("call_done", {"ok": True}, "lc_done")
        b.on_tool_input_start("call_running", "rm", "lc_running")
        b.on_tool_input_available("call_running", "rm", {}, "lc_running")

        s = b.stats()
        assert s["text"] == 1
        assert s["reasoning"] == 1
        assert s["tool_calls"] == 2
        assert s["tool_calls_completed"] == 1
        assert s["tool_calls_aborted"] == 0
        assert s["thinking_step_parts"] == 1
        assert s["step_separators"] == 1
        assert s["parts"] == sum(
            [
                s["text"],
                s["reasoning"],
                s["tool_calls"],
                s["thinking_step_parts"],
                s["step_separators"],
            ]
        )
        assert s["bytes"] > 0

    def test_mark_interrupted_flips_running_calls_to_aborted_in_stats(self):
        b = AssistantContentBuilder()
        b.on_tool_input_start("call_done", "ls", "lc_done")
        b.on_tool_input_available("call_done", "ls", {}, "lc_done")
        b.on_tool_output_available("call_done", {"ok": True}, "lc_done")
        b.on_tool_input_start("call_running", "rm", "lc_running")
        b.on_tool_input_available("call_running", "rm", {}, "lc_running")

        # Pre-interrupt: one completed, one still running (no result).
        pre = b.stats()
        assert pre["tool_calls_completed"] == 1
        assert pre["tool_calls_aborted"] == 0

        b.mark_interrupted()
        post = b.stats()
        assert post["tool_calls_completed"] == 1
        assert post["tool_calls_aborted"] == 1
        assert post["tool_calls"] == 2

    def test_bytes_reflects_jsonb_payload_size(self):
        # Each text-delta adds bytes monotonically — useful for catching
        # an unbounded delta buffer regression in the perf signal.
        b = AssistantContentBuilder()
        b.on_text_start("t1")
        b.on_text_delta("t1", "x" * 10)
        small = b.stats()["bytes"]
        b.on_text_delta("t1", "x" * 1000)
        large = b.stats()["bytes"]
        assert large > small + 900
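

# Assumed sketch of the stats() schema pinned above; "bytes" is the size
# of the JSON-encoded snapshot, which is why an empty builder reports 2.
import json as _json


def _sketch_stats(parts: list[dict]) -> dict[str, int]:
    tool_calls = [p for p in parts if p["type"] == "tool-call"]
    return {
        "parts": len(parts),
        "bytes": len(_json.dumps(parts)),
        "text": sum(p["type"] == "text" for p in parts),
        "reasoning": sum(p["type"] == "reasoning" for p in parts),
        "tool_calls": len(tool_calls),
        "tool_calls_completed": sum("result" in p for p in tool_calls),
        "tool_calls_aborted": sum(p.get("state") == "aborted" for p in tool_calls),
        "thinking_step_parts": sum(p["type"] == "data-thinking-steps" for p in parts),
        "step_separators": sum(p["type"] == "data-step-separator" for p in parts),
    }

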
@ -51,22 +51,34 @@ class _FakeToolMessage:
     tool_call_id: str | None = None


 @dataclass
 class _FakeInterrupt:
     value: dict[str, Any]


 @dataclass
 class _FakeTask:
     interrupts: tuple[_FakeInterrupt, ...] = ()


 class _FakeAgentState:
     """Stand-in for ``StateSnapshot`` returned by ``aget_state``."""

-    def __init__(self) -> None:
+    def __init__(self, tasks: list[Any] | None = None) -> None:
         # Empty values keeps the cloud-fallback safety-net branch a no-op,
-        # and an empty ``tasks`` list keeps the post-stream interrupt
-        # check a no-op too.
+        # and empty ``tasks`` keep the post-stream interrupt check a no-op too.
         self.values: dict[str, Any] = {}
-        self.tasks: list[Any] = []
+        self.tasks: list[Any] = tasks or []


 class _FakeAgent:
     """Replays a list of ``astream_events`` events."""

-    def __init__(self, events: list[dict[str, Any]]) -> None:
+    def __init__(
+        self, events: list[dict[str, Any]], state: _FakeAgentState | None = None
+    ) -> None:
         self._events = events
+        self._state = state or _FakeAgentState()

     async def astream_events(  # type: ignore[no-untyped-def]
         self, _input_data: Any, *, config: dict[str, Any], version: str

@ -79,7 +91,7 @@ class _FakeAgent:
         # Called once after astream_events drains so the cloud-fallback
         # safety net can inspect staged filesystem work. The fake stays
         # empty so the safety net is a no-op.
-        return _FakeAgentState()
+        return self._state


 def _model_stream(

@ -170,11 +182,13 @@ def parity_v2_off(monkeypatch: pytest.MonkeyPatch) -> None:
     )


-async def _drain(events: list[dict[str, Any]]) -> list[dict[str, Any]]:
+async def _drain(
+    events: list[dict[str, Any]], state: _FakeAgentState | None = None
+) -> list[dict[str, Any]]:
     """Run ``_stream_agent_events`` against a fake agent and return the
     SSE payloads (parsed JSON) it yielded.
     """
-    agent = _FakeAgent(events)
+    agent = _FakeAgent(events, state=state)
     service = VercelStreamingService()
     result = StreamResult()
     config = {"configurable": {"thread_id": "test-thread"}}

@ -525,3 +539,31 @@ async def test_unmatched_fallback_still_attaches_lc_id(
     assert len(starts) == 1
     assert starts[0]["toolCallId"].startswith("call_run-1")
     assert starts[0]["langchainToolCallId"] == "lc-orphan"
+
+
+@pytest.mark.asyncio
+async def test_interrupt_request_uses_task_that_contains_interrupt(
+    parity_v2_on: None,
+) -> None:
+    interrupt_payload = {
+        "type": "calendar_event_create",
+        "action": {
+            "tool": "create_calendar_event",
+            "params": {"summary": "mom bday"},
+        },
+        "context": {},
+    }
+    state = _FakeAgentState(
+        tasks=[
+            _FakeTask(interrupts=()),
+            _FakeTask(interrupts=(_FakeInterrupt(value=interrupt_payload),)),
+        ]
+    )
+
+    payloads = await _drain([], state=state)
+
+    interrupts = _of_type(payloads, "data-interrupt-request")
+    assert len(interrupts) == 1
+    assert (
+        interrupts[0]["data"]["action_requests"][0]["name"] == "create_calendar_event"
+    )
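

# Assumed sketch of the post-stream interrupt scan the new test pins down:
# pick the first task that actually carries interrupts, not tasks[0].
def _sketch_first_interrupt(tasks):
    for task in tasks:
        if task.interrupts:
            return task.interrupts[0].value
    return None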
318
surfsense_backend/tests/unit/tasks/test_celery_async_runner.py
Normal file
@ -0,0 +1,318 @@
"""Regression tests for ``run_async_celery_task``.
|
||||
|
||||
These tests pin down the production bug observed on 2026-05-02 where
|
||||
the video-presentation Celery task hung at ``[billable_call] finalize``
|
||||
because the shared ``app.db.engine`` had pooled asyncpg connections
|
||||
bound to a *previous* task's now-closed event loop. Reusing such a
|
||||
connection on a fresh loop crashes inside ``pool_pre_ping`` with::
|
||||
|
||||
AttributeError: 'NoneType' object has no attribute 'send'
|
||||
|
||||
(the proactor is None because the loop is gone) and can hang forever
|
||||
inside the asyncpg ``Connection._cancel`` cleanup coroutine.
|
||||
|
||||
The fix is ``run_async_celery_task``: a small helper that runs every
|
||||
async celery task body inside a fresh event loop and disposes the
|
||||
shared engine pool both before (defends against a previous task that
|
||||
crashed) and after (releases connections we opened on this loop).
|
||||
|
||||
Tests here exercise the helper with a stub engine that records
|
||||
``dispose()`` calls and panics if a coroutine produced by one loop is
|
||||
awaited on another — mirroring the real asyncpg behaviour.
|
||||
"""
|
||||
|
||||
from __future__ import annotations

import asyncio
import gc
import sys
from collections.abc import Iterator
from contextlib import contextmanager
from unittest.mock import patch

import pytest

pytestmark = pytest.mark.unit


# ---------------------------------------------------------------------------
# Stub engine that emulates the asyncpg-on-stale-loop crash
# ---------------------------------------------------------------------------


class _StaleLoopEngine:
    """Tiny stand-in for ``app.db.engine`` that tracks dispose() calls.

    ``dispose()`` is async (matches ``AsyncEngine.dispose``) and records
    the running event loop id so tests can assert it ran on *each*
    fresh loop.
    """

    def __init__(self) -> None:
        self.dispose_loop_ids: list[int] = []

    async def dispose(self) -> None:
        loop = asyncio.get_running_loop()
        self.dispose_loop_ids.append(id(loop))


@contextmanager
def _patch_shared_engine(stub: _StaleLoopEngine) -> Iterator[None]:
    """Patch the ``from app.db import engine as shared_engine`` lookup.

    The helper imports lazily inside the function body, so we have to
    patch the attribute on the already-loaded ``app.db`` module.
    """
    import app.db as app_db

    original = getattr(app_db, "engine", None)
    app_db.engine = stub  # type: ignore[attr-defined]
    try:
        yield
    finally:
        if original is None:
            # The module had no ``engine`` attribute before we patched it,
            # so remove ours instead of restoring anything.
            delattr(app_db, "engine")
        else:
            app_db.engine = original  # type: ignore[attr-defined]


# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------


def test_runner_returns_value_and_disposes_engine_around_call() -> None:
    """Happy path: the coroutine result is returned, and the shared
    engine is disposed both before and after the task body runs.
    """
    from app.tasks.celery_tasks import run_async_celery_task

    stub = _StaleLoopEngine()

    async def _body() -> str:
        # Engine should already have been disposed once before we run.
        assert len(stub.dispose_loop_ids) == 1
        return "ok"

    with _patch_shared_engine(stub):
        result = run_async_celery_task(_body)

    assert result == "ok"
    # Once before the body, once after (in finally).
    assert len(stub.dispose_loop_ids) == 2
    # Both disposes ran on the SAME (fresh) loop the task body used.
    assert stub.dispose_loop_ids[0] == stub.dispose_loop_ids[1]


def test_runner_creates_fresh_loop_per_invocation() -> None:
    """Each call must spin up its own loop. Without this guarantee a
    previous task's loop would be reused and the asyncpg-stale-loop
    crash would never be avoided.
    """
    import app.tasks.celery_tasks as celery_tasks_pkg

    stub = _StaleLoopEngine()
    new_loop_calls = 0
    closed_loops: list[bool] = []

    real_new_event_loop = asyncio.new_event_loop

    def _counting_new_loop() -> asyncio.AbstractEventLoop:
        nonlocal new_loop_calls
        new_loop_calls += 1
        loop = real_new_event_loop()
        # Hook close() so we can verify each loop was closed properly
        # before the next one was created.
        original_close = loop.close

        def _tracked_close() -> None:
            closed_loops.append(True)
            original_close()

        loop.close = _tracked_close  # type: ignore[method-assign]
        return loop

    async def _body() -> None:
        # Loop is alive and current at body execution time.
        running = asyncio.get_running_loop()
        assert not running.is_closed()

    with (
        _patch_shared_engine(stub),
        patch.object(asyncio, "new_event_loop", _counting_new_loop),
    ):
        for _ in range(3):
            celery_tasks_pkg.run_async_celery_task(_body)

    assert new_loop_calls == 3
    assert closed_loops == [True, True, True]
    # Each invocation disposed twice (before + after).
    assert len(stub.dispose_loop_ids) == 6


def test_runner_disposes_engine_even_when_body_raises() -> None:
    """Cleanup MUST run on the failure path too — otherwise stale
    connections leak into the next task and cause the original hang.
    """
    from app.tasks.celery_tasks import run_async_celery_task

    stub = _StaleLoopEngine()

    class _BoomError(RuntimeError):
        pass

    async def _body() -> None:
        raise _BoomError("kaboom")

    with _patch_shared_engine(stub), pytest.raises(_BoomError):
        run_async_celery_task(_body)

    assert len(stub.dispose_loop_ids) == 2  # before + after still ran


def test_runner_swallows_dispose_errors() -> None:
    """A flaky engine.dispose() must NEVER take down a celery task.

    Production scenario: the very first dispose (before the body runs)
    might hit a partially-initialised engine; the helper logs and
    moves on. The task body still runs; the result is still returned.
    """
    from app.tasks.celery_tasks import run_async_celery_task

    class _AngryEngine:
        def __init__(self) -> None:
            self.calls = 0

        async def dispose(self) -> None:
            self.calls += 1
            raise RuntimeError("dispose() blew up")

    stub = _AngryEngine()

    async def _body() -> int:
        return 42

    with _patch_shared_engine(stub):
        assert run_async_celery_task(_body) == 42

    assert stub.calls == 2  # before + after both attempted


def test_runner_propagates_value_from_async_body() -> None:
    """Sanity: pass-through of any pickleable celery return value."""
    from app.tasks.celery_tasks import run_async_celery_task

    stub = _StaleLoopEngine()

    async def _body() -> dict[str, object]:
        return {"status": "ready", "video_presentation_id": 19}

    with _patch_shared_engine(stub):
        out = run_async_celery_task(_body)

    assert out == {"status": "ready", "video_presentation_id": 19}


def test_video_presentation_task_uses_runner_helper() -> None:
    """Defence-in-depth: confirm the celery task module imports
    ``run_async_celery_task``. If a future refactor inlines a
    ``loop = asyncio.new_event_loop(); ... loop.close()`` block again,
    the original hang will return.
    """
    # The module's task body should not contain a manual new_event_loop
    # call — that's exactly what the helper exists to centralise.
    import inspect

    from app.tasks.celery_tasks import video_presentation_tasks

    src = inspect.getsource(video_presentation_tasks)
    assert "run_async_celery_task" in src, (
        "video_presentation_tasks.py must use run_async_celery_task; "
        "manual asyncio.new_event_loop() in a celery task hangs on the "
        "shared SQLAlchemy pool when reused across tasks."
    )
    assert "asyncio.new_event_loop" not in src, (
        "video_presentation_tasks.py contains a raw asyncio.new_event_loop "
        "call — route every async task through run_async_celery_task to "
        "avoid the stale-pool hang."
    )


def test_podcast_task_uses_runner_helper() -> None:
    """Symmetric assertion for the podcast task — same root cause, same
    fix, same regression risk.
    """
    import inspect

    from app.tasks.celery_tasks import podcast_tasks

    src = inspect.getsource(podcast_tasks)
    assert "run_async_celery_task" in src
    assert "asyncio.new_event_loop" not in src


def test_runner_runs_shutdown_asyncgens_before_close() -> None:
    """If the task body created any async generators that didn't get
    fully iterated, we must still call ``loop.shutdown_asyncgens()``
    before closing — otherwise we leak event-loop-bound resources
    that re-emerge as ``RuntimeError: Event loop is closed`` later.
    """
    from app.tasks.celery_tasks import run_async_celery_task

    stub = _StaleLoopEngine()

    async def _agen():
        try:
            yield 1
            yield 2
        finally:
            pass

    async def _body() -> None:
        # Iterate the agen partially, then leave it dangling — exactly
        # the situation shutdown_asyncgens() is designed to clean up.
        async for v in _agen():
            if v == 1:
                break

    with _patch_shared_engine(stub):
        run_async_celery_task(_body)

    # By the time the helper returns, garbage collection + shutdown_asyncgens
    # should have ensured no live async-gen references remain. We don't
    # assert agen.closed directly (it depends on GC ordering); the real
    # contract is "no warnings, no event-loop-closed errors". A successful
    # second invocation proves the loop was cleaned up properly.
    with _patch_shared_engine(stub):
        run_async_celery_task(_body)

    # Force a GC pass to surface any 'coroutine was never awaited'
    # warnings that would indicate the cleanup is broken.
    gc.collect()


def test_runner_uses_proactor_loop_on_windows() -> None:
    """On Windows the celery worker preselects a Proactor policy so
    subprocess (ffmpeg) calls work. The helper must not silently fall
    back to a Selector loop and re-break video/podcast generation.
    """
    if not sys.platform.startswith("win"):
        pytest.skip("Windows-specific event-loop policy assertion")

    from app.tasks.celery_tasks import run_async_celery_task

    stub = _StaleLoopEngine()

    # Mirror the policy set at the top of every Windows celery task.
    asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())

    observed: list[str] = []

    async def _body() -> None:
        observed.append(type(asyncio.get_running_loop()).__name__)

    with _patch_shared_engine(stub):
        run_async_celery_task(_body)

    assert observed == ["ProactorEventLoop"]
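

# Hedged sketch of ``run_async_celery_task``, reconstructed from the
# assertions above (fresh loop per call, dispose before AND after the body,
# dispose errors swallowed, shutdown_asyncgens before close). Not the
# shipped implementation; see app.tasks.celery_tasks for that.
def _sketch_run_async_celery_task(body):
    import logging

    async def _dispose_shared_engine() -> None:
        try:
            from app.db import engine as shared_engine

            await shared_engine.dispose()
        except Exception:  # a flaky dispose must never kill the task
            logging.getLogger(__name__).warning("engine dispose failed", exc_info=True)

    loop = asyncio.new_event_loop()
    try:
        asyncio.set_event_loop(loop)

        async def _wrapped():
            await _dispose_shared_engine()  # defend against a crashed predecessor
            try:
                return await body()
            finally:
                await _dispose_shared_engine()  # release this loop's connections

        return loop.run_until_complete(_wrapped())
    finally:
        loop.run_until_complete(loop.shutdown_asyncgens())
        loop.close()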
388
surfsense_backend/tests/unit/tasks/test_podcast_billing.py
Normal file
@ -0,0 +1,388 @@
"""Unit tests for podcast Celery task billing integration.
|
||||
|
||||
Validates ``_generate_content_podcast`` correctly wraps
|
||||
``podcaster_graph.ainvoke`` in a ``billable_call`` envelope, propagates the
|
||||
search-space owner's billing decision, and degrades cleanly when the
|
||||
resolver fails or premium credit is exhausted.
|
||||
|
||||
Coverage:
|
||||
|
||||
* Happy-path free config: resolver → ``billable_call`` enters with
|
||||
``usage_type='podcast_generation'`` and the configured reserve override,
|
||||
graph runs, podcast row flips to ``READY``.
|
||||
* Happy-path premium config: same wiring with ``billing_tier='premium'``.
|
||||
* Quota denial: ``billable_call`` raises ``QuotaInsufficientError`` →
|
||||
graph is *not* invoked, podcast row flips to ``FAILED``, return dict
|
||||
carries ``reason='premium_quota_exhausted'``.
|
||||
* Resolver failure: ``ValueError`` from the resolver → podcast row flips
|
||||
to ``FAILED``, return dict carries ``reason='billing_resolution_failed'``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fakes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _FakeExecResult:
|
||||
def __init__(self, obj):
|
||||
self._obj = obj
|
||||
|
||||
def scalars(self):
|
||||
return self
|
||||
|
||||
def first(self):
|
||||
return self._obj
|
||||
|
||||
def filter(self, *_args, **_kwargs):
|
||||
return self
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
def __init__(self, podcast):
|
||||
self._podcast = podcast
|
||||
self.commit_count = 0
|
||||
|
||||
async def execute(self, _stmt):
|
||||
return _FakeExecResult(self._podcast)
|
||||
|
||||
async def commit(self):
|
||||
self.commit_count += 1
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
return None
|
||||
|
||||
|
||||
class _FakeSessionMaker:
|
||||
def __init__(self, session: _FakeSession):
|
||||
self._session = session
|
||||
|
||||
def __call__(self):
|
||||
return self._session
|
||||
|
||||
|
||||
def _make_podcast(podcast_id: int = 7, thread_id: int = 99) -> SimpleNamespace:
|
||||
"""Stand-in for a ``Podcast`` row. Importing ``PodcastStatus`` lazily
|
||||
inside helpers keeps this fixture cheap."""
|
||||
return SimpleNamespace(
|
||||
id=podcast_id,
|
||||
title="Test Podcast",
|
||||
thread_id=thread_id,
|
||||
status=None,
|
||||
podcast_transcript=None,
|
||||
file_location=None,
|
||||
)
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def _ok_billable_call(**kwargs):
|
||||
"""Stand-in for ``billable_call`` that records its kwargs and yields a
|
||||
no-op accumulator-shaped object."""
|
||||
_CALL_LOG.append(kwargs)
|
||||
yield SimpleNamespace()
|
||||
|
||||
|
||||
_CALL_LOG: list[dict[str, Any]] = []
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def _denying_billable_call(**kwargs):
|
||||
from app.services.billable_calls import QuotaInsufficientError
|
||||
|
||||
_CALL_LOG.append(kwargs)
|
||||
raise QuotaInsufficientError(
|
||||
usage_type=kwargs.get("usage_type", "?"),
|
||||
used_micros=5_000_000,
|
||||
limit_micros=5_000_000,
|
||||
remaining_micros=0,
|
||||
)
|
||||
yield SimpleNamespace() # pragma: no cover — for grammar only
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def _settlement_failing_billable_call(**kwargs):
|
||||
from app.services.billable_calls import BillingSettlementError
|
||||
|
||||
_CALL_LOG.append(kwargs)
|
||||
yield SimpleNamespace()
|
||||
raise BillingSettlementError(
|
||||
usage_type=kwargs.get("usage_type", "?"),
|
||||
user_id=kwargs["user_id"],
|
||||
cause=RuntimeError("finalize failed"),
|
||||
)
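
# For orientation, a minimal sketch of the wrap pattern these fakes stand in
# for (an assumption pieced together from the kwargs asserted below, not the
# real podcast_tasks implementation):
#
#     user_id, tier, model = await _resolve_agent_billing_for_search_space(
#         session, search_space_id, thread_id=podcast.thread_id
#     )
#     async with billable_call(
#         billable_session_factory=get_celery_session_maker(),
#         user_id=user_id,
#         search_space_id=search_space_id,
#         billing_tier=tier,  # "free" | "premium"
#         base_model=model,
#         usage_type="podcast_generation",
#         quota_reserve_micros_override=config.QUOTA_DEFAULT_PODCAST_RESERVE_MICROS,
#         call_details={
#             "podcast_id": podcast.id,
#             "title": podcast.title,
#             "thread_id": podcast.thread_id,
#         },
#     ):
#         result = await podcaster_graph.ainvoke(state, graph_config)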


# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------


@pytest.fixture(autouse=True)
def _reset_call_log():
    _CALL_LOG.clear()
    yield
    _CALL_LOG.clear()


@pytest.mark.asyncio
async def test_billable_call_invoked_with_correct_kwargs_for_free_config(monkeypatch):
    """Happy path: free billing tier still wraps the graph call so the
    audit row is recorded. Verifies kwargs threading."""
    from app.config import config as app_config
    from app.db import PodcastStatus
    from app.tasks.celery_tasks import podcast_tasks

    podcast = _make_podcast(podcast_id=7, thread_id=99)
    session = _FakeSession(podcast)
    monkeypatch.setattr(
        podcast_tasks,
        "get_celery_session_maker",
        lambda: _FakeSessionMaker(session),
    )

    user_id = uuid4()

    async def _fake_resolver(sess, search_space_id, *, thread_id=None):
        assert search_space_id == 555
        assert thread_id == 99
        return user_id, "free", "openrouter/some-free-model"

    monkeypatch.setattr(
        podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver
    )
    monkeypatch.setattr(podcast_tasks, "billable_call", _ok_billable_call)

    async def _fake_graph_invoke(state, config):
        return {
            "podcast_transcript": [
                SimpleNamespace(speaker_id=0, dialog="Hi"),
                SimpleNamespace(speaker_id=1, dialog="Hello"),
            ],
            "final_podcast_file_path": "/tmp/podcast.wav",
        }

    monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke)

    result = await podcast_tasks._generate_content_podcast(
        podcast_id=7,
        source_content="hello world",
        search_space_id=555,
        user_prompt="make it short",
    )

    assert result["status"] == "ready"
    assert result["podcast_id"] == 7
    assert podcast.status == PodcastStatus.READY
    assert podcast.file_location == "/tmp/podcast.wav"

    assert len(_CALL_LOG) == 1
    call = _CALL_LOG[0]
    assert call["user_id"] == user_id
    assert call["search_space_id"] == 555
    assert call["billing_tier"] == "free"
    assert call["base_model"] == "openrouter/some-free-model"
    assert call["usage_type"] == "podcast_generation"
    assert (
        call["quota_reserve_micros_override"]
        == app_config.QUOTA_DEFAULT_PODCAST_RESERVE_MICROS
    )
    # Background artifact audit rows intentionally omit the TokenUsage.thread_id
    # FK to avoid coupling Celery audit commits to an active chat transaction.
    assert "thread_id" not in call
    assert call["call_details"] == {
        "podcast_id": 7,
        "title": "Test Podcast",
        "thread_id": 99,
    }
    assert callable(call["billable_session_factory"])


@pytest.mark.asyncio
async def test_billable_call_invoked_with_premium_tier(monkeypatch):
    """Premium resolution flows through to ``billable_call`` so the
    reserve/finalize path triggers."""
    from app.tasks.celery_tasks import podcast_tasks

    podcast = _make_podcast()
    session = _FakeSession(podcast)
    monkeypatch.setattr(
        podcast_tasks,
        "get_celery_session_maker",
        lambda: _FakeSessionMaker(session),
    )

    user_id = uuid4()

    async def _fake_resolver(sess, search_space_id, *, thread_id=None):
        return user_id, "premium", "gpt-5.4"

    monkeypatch.setattr(
        podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver
    )
    monkeypatch.setattr(podcast_tasks, "billable_call", _ok_billable_call)

    async def _fake_graph_invoke(state, config):
        return {"podcast_transcript": [], "final_podcast_file_path": "x.wav"}

    monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke)

    await podcast_tasks._generate_content_podcast(
        podcast_id=7,
        source_content="hi",
        search_space_id=555,
        user_prompt=None,
    )

    assert _CALL_LOG[0]["billing_tier"] == "premium"
    assert _CALL_LOG[0]["base_model"] == "gpt-5.4"


@pytest.mark.asyncio
async def test_quota_insufficient_marks_podcast_failed_and_skips_graph(monkeypatch):
    """When ``billable_call`` denies the reservation, the graph never
    runs and the podcast row flips to FAILED with the documented reason
    code."""
    from app.db import PodcastStatus
    from app.tasks.celery_tasks import podcast_tasks

    podcast = _make_podcast(podcast_id=8)
    session = _FakeSession(podcast)
    monkeypatch.setattr(
        podcast_tasks,
        "get_celery_session_maker",
        lambda: _FakeSessionMaker(session),
    )

    async def _fake_resolver(sess, search_space_id, *, thread_id=None):
        return uuid4(), "premium", "gpt-5.4"

    monkeypatch.setattr(
        podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver
    )
    monkeypatch.setattr(podcast_tasks, "billable_call", _denying_billable_call)

    graph_invoked = []

    async def _fake_graph_invoke(state, config):
        graph_invoked.append(True)
        return {}

    monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke)

    result = await podcast_tasks._generate_content_podcast(
        podcast_id=8,
        source_content="hi",
        search_space_id=555,
        user_prompt=None,
    )

    assert result == {
        "status": "failed",
        "podcast_id": 8,
        "reason": "premium_quota_exhausted",
    }
    assert podcast.status == PodcastStatus.FAILED
    assert graph_invoked == []  # Graph never ran on denied reservation.


@pytest.mark.asyncio
async def test_billing_settlement_failure_marks_podcast_failed(monkeypatch):
    from app.db import PodcastStatus
    from app.tasks.celery_tasks import podcast_tasks

    podcast = _make_podcast(podcast_id=10)
    session = _FakeSession(podcast)
    monkeypatch.setattr(
        podcast_tasks,
        "get_celery_session_maker",
        lambda: _FakeSessionMaker(session),
    )

    async def _fake_resolver(sess, search_space_id, *, thread_id=None):
        return uuid4(), "premium", "gpt-5.4"

    monkeypatch.setattr(
        podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver
    )
    monkeypatch.setattr(
        podcast_tasks, "billable_call", _settlement_failing_billable_call
    )

    async def _fake_graph_invoke(state, config):
        return {"podcast_transcript": [], "final_podcast_file_path": "x.wav"}

    monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke)

    result = await podcast_tasks._generate_content_podcast(
        podcast_id=10,
        source_content="hi",
        search_space_id=555,
        user_prompt=None,
    )

    assert result == {
        "status": "failed",
        "podcast_id": 10,
        "reason": "billing_settlement_failed",
    }
    assert podcast.status == PodcastStatus.FAILED


@pytest.mark.asyncio
async def test_resolver_failure_marks_podcast_failed(monkeypatch):
    """If the resolver raises (e.g. search-space deleted), the task fails
    cleanly without invoking the graph."""
    from app.db import PodcastStatus
    from app.tasks.celery_tasks import podcast_tasks

    podcast = _make_podcast(podcast_id=9)
    session = _FakeSession(podcast)
    monkeypatch.setattr(
        podcast_tasks,
        "get_celery_session_maker",
        lambda: _FakeSessionMaker(session),
    )

    async def _failing_resolver(sess, search_space_id, *, thread_id=None):
        raise ValueError("Search space 555 not found")

    monkeypatch.setattr(
        podcast_tasks, "_resolve_agent_billing_for_search_space", _failing_resolver
    )

    graph_invoked = []

    async def _fake_graph_invoke(state, config):
        graph_invoked.append(True)
        return {}

    monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke)

    result = await podcast_tasks._generate_content_podcast(
        podcast_id=9,
        source_content="hi",
        search_space_id=555,
        user_prompt=None,
    )

    assert result == {
        "status": "failed",
        "podcast_id": 9,
        "reason": "billing_resolution_failed",
    }
    assert podcast.status == PodcastStatus.FAILED
    assert graph_invoked == []

@@ -0,0 +1,119 @@
"""Predicate-level test for the chat streaming safety net.

The safety net in ``stream_new_chat`` rejects an image turn early with
a friendly ``MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT`` SSE error when the
selected model is *known* to be text-only. The earlier round of this
work used a strict opt-in flag (``supports_image_input`` defaulting to
False on every YAML entry) which blocked vision-capable Azure GPT-5.x
deployments — this is the regression we're fixing.

The new predicate is :func:`is_known_text_only_chat_model`, which
returns True only when LiteLLM's authoritative model map *explicitly*
sets ``supports_vision=False``. Anything else (vision True, missing
key, exception) returns False so the request flows through to the
provider.

We exercise the predicate directly here rather than driving the full
``stream_new_chat`` generator — covering the gate in isolation keeps
the test focused on the regression while the generator's wider behavior
is exercised by the integration suite.
"""

from __future__ import annotations

import pytest

from app.services.provider_capabilities import is_known_text_only_chat_model

pytestmark = pytest.mark.unit
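
# A minimal sketch of the decision table these tests pin down (an assumption
# about the helper's shape; the real implementation lives in
# app/services/provider_capabilities.py and also folds in the provider and
# custom_provider arguments):
#
#     def is_known_text_only_chat_model(*, model_name, base_model=None, **_):
#         try:
#             info = litellm.get_model_info(model=base_model or model_name)
#         except Exception:
#             return False  # unknown model: default-pass
#         return info.get("supports_vision") is False  # only explicit False blocks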


def test_safety_net_does_not_fire_for_azure_gpt_4o():
    """Regression: ``azure/gpt-4o`` (and the GPT-5.x variants) is
    vision-capable per LiteLLM's model map. The previous round's
    blanket-False default blocked it; the new predicate must NOT mark
    it text-only."""
    assert (
        is_known_text_only_chat_model(
            provider="AZURE_OPENAI",
            model_name="my-azure-deployment",
            base_model="gpt-4o",
        )
        is False
    )


def test_safety_net_does_not_fire_for_unknown_model():
    """Default-pass on unknown — the safety net only blocks definitive
    text-only confirmations. A freshly added third-party model that
    LiteLLM doesn't know about must flow through to the provider."""
    assert (
        is_known_text_only_chat_model(
            provider="CUSTOM",
            custom_provider="brand_new_proxy",
            model_name="brand-new-model-x9",
        )
        is False
    )


def test_safety_net_does_not_fire_when_lookup_raises(monkeypatch):
    """Transient ``litellm.get_model_info`` exception ≠ block. The
    helper swallows the error and treats it as 'unknown' → False."""
    import app.services.provider_capabilities as pc

    def _raise(**_kwargs):
        raise RuntimeError("intentional test failure")

    monkeypatch.setattr(pc.litellm, "get_model_info", _raise)

    assert (
        is_known_text_only_chat_model(
            provider="OPENAI",
            model_name="gpt-4o",
        )
        is False
    )


def test_safety_net_fires_only_on_explicit_false(monkeypatch):
    """Stub LiteLLM to assert the only path that returns True is the
    explicit ``supports_vision=False`` case. Anything else (True,
    None, missing key) returns False from the predicate."""
    import app.services.provider_capabilities as pc

    def _info_explicit_false(**_kwargs):
        return {"supports_vision": False, "max_input_tokens": 8192}

    monkeypatch.setattr(pc.litellm, "get_model_info", _info_explicit_false)
    assert (
        is_known_text_only_chat_model(
            provider="OPENAI",
            model_name="text-only-stub",
        )
        is True
    )

    def _info_true(**_kwargs):
        return {"supports_vision": True}

    monkeypatch.setattr(pc.litellm, "get_model_info", _info_true)
    assert (
        is_known_text_only_chat_model(
            provider="OPENAI",
            model_name="vision-stub",
        )
        is False
    )

    def _info_missing(**_kwargs):
        return {"max_input_tokens": 8192}

    monkeypatch.setattr(pc.litellm, "get_model_info", _info_missing)
    assert (
        is_known_text_only_chat_model(
            provider="OPENAI",
            model_name="missing-key-stub",
        )
        is False
    )

@@ -0,0 +1,398 @@
"""Unit tests for video-presentation Celery task billing integration.

Mirrors ``test_podcast_billing.py`` for the video-presentation task.
Validates the same wrap-graph-in-billable_call pattern and ensures the
larger ``QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS`` reservation is
threaded through.

Coverage:

* Free config: graph runs, ``billable_call`` invoked with the video
  reserve override.
* Premium config: same wiring with ``billing_tier='premium'``.
* Quota denial: graph not invoked, row → FAILED, reason code surfaced.
* Resolver failure: row → FAILED with ``billing_resolution_failed``.
"""

from __future__ import annotations

import contextlib
from types import SimpleNamespace
from typing import Any
from uuid import uuid4

import pytest

pytestmark = pytest.mark.unit


# ---------------------------------------------------------------------------
# Fakes
# ---------------------------------------------------------------------------


class _FakeExecResult:
    def __init__(self, obj):
        self._obj = obj

    def scalars(self):
        return self

    def first(self):
        return self._obj

    def filter(self, *_args, **_kwargs):
        return self


class _FakeSession:
    def __init__(self, video):
        self._video = video
        self.commit_count = 0

    async def execute(self, _stmt):
        return _FakeExecResult(self._video)

    async def commit(self):
        self.commit_count += 1

    async def __aenter__(self):
        return self

    async def __aexit__(self, *args):
        return None


class _FakeSessionMaker:
    def __init__(self, session: _FakeSession):
        self._session = session

    def __call__(self):
        return self._session


def _make_video(video_id: int = 11, thread_id: int = 99) -> SimpleNamespace:
    return SimpleNamespace(
        id=video_id,
        title="Test Presentation",
        thread_id=thread_id,
        status=None,
        slides=None,
        scene_codes=None,
    )


_CALL_LOG: list[dict[str, Any]] = []


@contextlib.asynccontextmanager
async def _ok_billable_call(**kwargs):
    _CALL_LOG.append(kwargs)
    yield SimpleNamespace()


@contextlib.asynccontextmanager
async def _denying_billable_call(**kwargs):
    from app.services.billable_calls import QuotaInsufficientError

    _CALL_LOG.append(kwargs)
    raise QuotaInsufficientError(
        usage_type=kwargs.get("usage_type", "?"),
        used_micros=5_000_000,
        limit_micros=5_000_000,
        remaining_micros=0,
    )
    yield SimpleNamespace()  # pragma: no cover


@contextlib.asynccontextmanager
async def _settlement_failing_billable_call(**kwargs):
    from app.services.billable_calls import BillingSettlementError

    _CALL_LOG.append(kwargs)
    yield SimpleNamespace()
    raise BillingSettlementError(
        usage_type=kwargs.get("usage_type", "?"),
        user_id=kwargs["user_id"],
        cause=RuntimeError("finalize failed"),
    )
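
# Per the assertions below, the billing deltas versus the podcast task are
# just the usage type and the (larger) reserve override; everything else in
# the wrap pattern sketched in test_podcast_billing.py carries over unchanged:
#
#     usage_type="video_presentation_generation"
#     quota_reserve_micros_override=(
#         config.QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS
#     )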


# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------


@pytest.fixture(autouse=True)
def _reset_call_log():
    _CALL_LOG.clear()
    yield
    _CALL_LOG.clear()


@pytest.mark.asyncio
async def test_billable_call_invoked_with_correct_kwargs_for_free_config(monkeypatch):
    from app.config import config as app_config
    from app.db import VideoPresentationStatus
    from app.tasks.celery_tasks import video_presentation_tasks

    video = _make_video(video_id=11, thread_id=99)
    session = _FakeSession(video)
    monkeypatch.setattr(
        video_presentation_tasks,
        "get_celery_session_maker",
        lambda: _FakeSessionMaker(session),
    )

    user_id = uuid4()

    async def _fake_resolver(sess, search_space_id, *, thread_id=None):
        assert search_space_id == 777
        assert thread_id == 99
        return user_id, "free", "openrouter/some-free-model"

    monkeypatch.setattr(
        video_presentation_tasks,
        "_resolve_agent_billing_for_search_space",
        _fake_resolver,
    )
    monkeypatch.setattr(video_presentation_tasks, "billable_call", _ok_billable_call)

    async def _fake_graph_invoke(state, config):
        return {"slides": [], "slide_audio_results": [], "slide_scene_codes": []}

    monkeypatch.setattr(
        video_presentation_tasks.video_presentation_graph,
        "ainvoke",
        _fake_graph_invoke,
    )

    result = await video_presentation_tasks._generate_video_presentation(
        video_presentation_id=11,
        source_content="content",
        search_space_id=777,
        user_prompt=None,
    )

    assert result["status"] == "ready"
    assert result["video_presentation_id"] == 11
    assert video.status == VideoPresentationStatus.READY

    assert len(_CALL_LOG) == 1
    call = _CALL_LOG[0]
    assert call["user_id"] == user_id
    assert call["search_space_id"] == 777
    assert call["billing_tier"] == "free"
    assert call["base_model"] == "openrouter/some-free-model"
    assert call["usage_type"] == "video_presentation_generation"
    assert (
        call["quota_reserve_micros_override"]
        == app_config.QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS
    )
    # Background artifact audit rows intentionally omit the TokenUsage.thread_id
    # FK to avoid coupling Celery audit commits to an active chat transaction.
    assert "thread_id" not in call
    assert call["call_details"] == {
        "video_presentation_id": 11,
        "title": "Test Presentation",
        "thread_id": 99,
    }
    assert callable(call["billable_session_factory"])


@pytest.mark.asyncio
async def test_billable_call_invoked_with_premium_tier(monkeypatch):
    from app.tasks.celery_tasks import video_presentation_tasks

    video = _make_video()
    session = _FakeSession(video)
    monkeypatch.setattr(
        video_presentation_tasks,
        "get_celery_session_maker",
        lambda: _FakeSessionMaker(session),
    )

    user_id = uuid4()

    async def _fake_resolver(sess, search_space_id, *, thread_id=None):
        return user_id, "premium", "gpt-5.4"

    monkeypatch.setattr(
        video_presentation_tasks,
        "_resolve_agent_billing_for_search_space",
        _fake_resolver,
    )
    monkeypatch.setattr(video_presentation_tasks, "billable_call", _ok_billable_call)

    async def _fake_graph_invoke(state, config):
        return {"slides": [], "slide_audio_results": [], "slide_scene_codes": []}

    monkeypatch.setattr(
        video_presentation_tasks.video_presentation_graph,
        "ainvoke",
        _fake_graph_invoke,
    )

    await video_presentation_tasks._generate_video_presentation(
        video_presentation_id=11,
        source_content="content",
        search_space_id=777,
        user_prompt=None,
    )

    assert _CALL_LOG[0]["billing_tier"] == "premium"
    assert _CALL_LOG[0]["base_model"] == "gpt-5.4"


@pytest.mark.asyncio
async def test_quota_insufficient_marks_video_failed_and_skips_graph(monkeypatch):
    from app.db import VideoPresentationStatus
    from app.tasks.celery_tasks import video_presentation_tasks

    video = _make_video(video_id=12)
    session = _FakeSession(video)
    monkeypatch.setattr(
        video_presentation_tasks,
        "get_celery_session_maker",
        lambda: _FakeSessionMaker(session),
    )

    async def _fake_resolver(sess, search_space_id, *, thread_id=None):
        return uuid4(), "premium", "gpt-5.4"

    monkeypatch.setattr(
        video_presentation_tasks,
        "_resolve_agent_billing_for_search_space",
        _fake_resolver,
    )
    monkeypatch.setattr(
        video_presentation_tasks, "billable_call", _denying_billable_call
    )

    graph_invoked = []

    async def _fake_graph_invoke(state, config):
        graph_invoked.append(True)
        return {}

    monkeypatch.setattr(
        video_presentation_tasks.video_presentation_graph,
        "ainvoke",
        _fake_graph_invoke,
    )

    result = await video_presentation_tasks._generate_video_presentation(
        video_presentation_id=12,
        source_content="content",
        search_space_id=777,
        user_prompt=None,
    )

    assert result == {
        "status": "failed",
        "video_presentation_id": 12,
        "reason": "premium_quota_exhausted",
    }
    assert video.status == VideoPresentationStatus.FAILED
    assert graph_invoked == []


@pytest.mark.asyncio
async def test_billing_settlement_failure_marks_video_failed(monkeypatch):
    from app.db import VideoPresentationStatus
    from app.tasks.celery_tasks import video_presentation_tasks

    video = _make_video(video_id=14)
    session = _FakeSession(video)
    monkeypatch.setattr(
        video_presentation_tasks,
        "get_celery_session_maker",
        lambda: _FakeSessionMaker(session),
    )

    async def _fake_resolver(sess, search_space_id, *, thread_id=None):
        return uuid4(), "premium", "gpt-5.4"

    monkeypatch.setattr(
        video_presentation_tasks,
        "_resolve_agent_billing_for_search_space",
        _fake_resolver,
    )
    monkeypatch.setattr(
        video_presentation_tasks,
        "billable_call",
        _settlement_failing_billable_call,
    )

    async def _fake_graph_invoke(state, config):
        return {"slides": [], "slide_audio_results": [], "slide_scene_codes": []}

    monkeypatch.setattr(
        video_presentation_tasks.video_presentation_graph,
        "ainvoke",
        _fake_graph_invoke,
    )

    result = await video_presentation_tasks._generate_video_presentation(
        video_presentation_id=14,
        source_content="content",
        search_space_id=777,
        user_prompt=None,
    )

    assert result == {
        "status": "failed",
        "video_presentation_id": 14,
        "reason": "billing_settlement_failed",
    }
    assert video.status == VideoPresentationStatus.FAILED


@pytest.mark.asyncio
async def test_resolver_failure_marks_video_failed(monkeypatch):
    from app.db import VideoPresentationStatus
    from app.tasks.celery_tasks import video_presentation_tasks

    video = _make_video(video_id=13)
    session = _FakeSession(video)
    monkeypatch.setattr(
        video_presentation_tasks,
        "get_celery_session_maker",
        lambda: _FakeSessionMaker(session),
    )

    async def _failing_resolver(sess, search_space_id, *, thread_id=None):
        raise ValueError("Search space 777 not found")

    monkeypatch.setattr(
        video_presentation_tasks,
        "_resolve_agent_billing_for_search_space",
        _failing_resolver,
    )

    graph_invoked = []

    async def _fake_graph_invoke(state, config):
        graph_invoked.append(True)
        return {}

    monkeypatch.setattr(
        video_presentation_tasks.video_presentation_graph,
        "ainvoke",
        _fake_graph_invoke,
    )

    result = await video_presentation_tasks._generate_video_presentation(
        video_presentation_id=13,
        source_content="content",
        search_space_id=777,
        user_prompt=None,
    )

    assert result == {
        "status": "failed",
        "video_presentation_id": 13,
        "reason": "billing_resolution_failed",
    }
    assert video.status == VideoPresentationStatus.FAILED
    assert graph_invoked == []

@@ -1,9 +1,21 @@
import inspect
import json
import logging
import re
from pathlib import Path

import pytest

import app.tasks.chat.stream_new_chat as stream_new_chat_module
from app.agents.new_chat.errors import BusyError
from app.agents.new_chat.middleware.busy_mutex import request_cancel, reset_cancel
from app.tasks.chat.stream_new_chat import (
    StreamResult,
    _classify_stream_exception,
    _contract_enforcement_active,
    _evaluate_file_contract_outcome,
    _extract_resolved_file_path,
    _log_chat_stream_error,
    _tool_output_has_error,
)

@@ -17,6 +29,39 @@ def test_tool_output_error_detection():
    assert not _tool_output_has_error({"result": "Updated file /notes.md"})


def test_extract_resolved_file_path_prefers_structured_path():
    assert (
        _extract_resolved_file_path(
            tool_name="write_file",
            tool_output={"status": "completed", "path": "/docs/note.md"},
            tool_input=None,
        )
        == "/docs/note.md"
    )


def test_extract_resolved_file_path_falls_back_to_tool_input():
    assert (
        _extract_resolved_file_path(
            tool_name="edit_file",
            tool_output={"status": "completed", "result": "updated"},
            tool_input={"file_path": "/docs/edited.md"},
        )
        == "/docs/edited.md"
    )


def test_extract_resolved_file_path_does_not_parse_result_text():
    assert (
        _extract_resolved_file_path(
            tool_name="write_file",
            tool_output={"result": "Updated file /docs/from-text.md"},
            tool_input=None,
        )
        is None
    )


def test_file_write_contract_outcome_reasons():
    result = StreamResult(intent_detected="file_write")
    passed, reason = _evaluate_file_contract_outcome(result)

@@ -45,3 +90,507 @@ def test_contract_enforcement_local_only():

    result.filesystem_mode = "cloud"
    assert not _contract_enforcement_active(result)


def _extract_chat_stream_payload(record_message: str) -> dict:
    prefix = "[chat_stream_error] "
    assert record_message.startswith(prefix)
    return json.loads(record_message[len(prefix) :])


def test_unified_chat_stream_error_log_schema(caplog):
    with caplog.at_level(logging.INFO, logger="app.tasks.chat.stream_new_chat"):
        _log_chat_stream_error(
            flow="new",
            error_kind="server_error",
            error_code="SERVER_ERROR",
            severity="warn",
            is_expected=False,
            request_id="req-123",
            thread_id=101,
            search_space_id=202,
            user_id="user-1",
            message="Error during chat: boom",
        )

    record = next(r for r in caplog.records if "[chat_stream_error]" in r.message)
    payload = _extract_chat_stream_payload(record.message)

    required_keys = {
        "event",
        "flow",
        "error_kind",
        "error_code",
        "severity",
        "is_expected",
        "request_id",
        "thread_id",
        "search_space_id",
        "user_id",
        "message",
    }
    assert required_keys.issubset(payload.keys())
    assert payload["event"] == "chat_stream_error"
    assert payload["flow"] == "new"
    assert payload["error_code"] == "SERVER_ERROR"


def test_premium_quota_uses_unified_chat_stream_log_shape(caplog):
    with caplog.at_level(logging.INFO, logger="app.tasks.chat.stream_new_chat"):
        _log_chat_stream_error(
            flow="resume",
            error_kind="premium_quota_exhausted",
            error_code="PREMIUM_QUOTA_EXHAUSTED",
            severity="info",
            is_expected=True,
            request_id="req-premium",
            thread_id=303,
            search_space_id=404,
            user_id="user-2",
            message="Buy more tokens to continue with this model, or switch to a free model",
            extra={"auto_fallback": False},
        )

    record = next(r for r in caplog.records if "[chat_stream_error]" in r.message)
    payload = _extract_chat_stream_payload(record.message)
    assert payload["event"] == "chat_stream_error"
    assert payload["error_kind"] == "premium_quota_exhausted"
    assert payload["error_code"] == "PREMIUM_QUOTA_EXHAUSTED"
    assert payload["flow"] == "resume"
    assert payload["is_expected"] is True
    assert payload["auto_fallback"] is False


def test_stream_error_emission_keeps_machine_error_codes():
    source = inspect.getsource(stream_new_chat_module)
    format_error_calls = re.findall(r"format_error\(", source)
    emitted_error_codes = set(re.findall(r'error_code="([A-Z_]+)"', source))

    # All stream paths should route through one shared terminal error emitter.
    assert len(format_error_calls) == 1
    assert {
        "PREMIUM_QUOTA_EXHAUSTED",
        "SERVER_ERROR",
    }.issubset(emitted_error_codes)
    assert 'flow: Literal["new", "regenerate"] = "new"' in source
    assert "_emit_stream_terminal_error" in source
    assert "flow=flow" in source
    assert 'flow="resume"' in source


def test_stream_exception_classifies_rate_limited():
    exc = Exception(
        '{"error":{"type":"rate_limit_error","message":"Rate limited. Please try again later."}}'
    )
    kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
        exc, flow_label="chat"
    )
    assert kind == "rate_limited"
    assert code == "RATE_LIMITED"
    assert severity == "warn"
    assert is_expected is True
    assert "temporarily rate-limited" in user_message
    assert extra is None


def test_stream_exception_classifies_openrouter_429_payload():
    exc = Exception(
        'OpenrouterException - {"error":{"message":"Provider returned error","code":429,'
        '"metadata":{"raw":"foo is temporarily rate-limited upstream"}}}'
    )
    kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
        exc, flow_label="chat"
    )
    assert kind == "rate_limited"
    assert code == "RATE_LIMITED"
    assert severity == "warn"
    assert is_expected is True
    assert "temporarily rate-limited" in user_message
    assert extra is None


@pytest.mark.asyncio
async def test_preflight_swallows_non_rate_limit_errors_and_re_raises_429(monkeypatch):
    """``_preflight_llm`` is best-effort.

    - On rate-limit shaped exceptions (provider 429) it MUST re-raise so the
      caller can drive the cooldown/repin branch.
    - On any other transient failure it MUST swallow the error so the normal
      stream path continues without surfacing preflight noise to the user.
    """
    from types import SimpleNamespace

    from app.tasks.chat.stream_new_chat import _preflight_llm

    class _RateLimitedError(Exception):
        """Class-name carries 'RateLimit' so _is_provider_rate_limited triggers."""

    rate_calls: list[dict] = []
    other_calls: list[dict] = []

    async def _fake_acompletion_429(**kwargs):
        rate_calls.append(kwargs)
        raise _RateLimitedError("simulated 429")

    async def _fake_acompletion_other(**kwargs):
        other_calls.append(kwargs)
        raise RuntimeError("some unrelated transient failure")

    fake_llm = SimpleNamespace(
        model="openrouter/google/gemma-4-31b-it:free",
        api_key="test",
        api_base=None,
    )

    import litellm  # type: ignore[import-not-found]

    monkeypatch.setattr(litellm, "acompletion", _fake_acompletion_429)
    with pytest.raises(_RateLimitedError):
        await _preflight_llm(fake_llm)
    assert len(rate_calls) == 1
    assert rate_calls[0]["max_tokens"] == 1
    assert rate_calls[0]["stream"] is False

    monkeypatch.setattr(litellm, "acompletion", _fake_acompletion_other)
    # MUST NOT raise: non-rate-limit failures are swallowed.
    await _preflight_llm(fake_llm)
    assert len(other_calls) == 1
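
# Shape of the helper under test, as a hedged sketch (an assumption; the real
# code in app/tasks/chat/stream_new_chat.py may differ in detail):
#
#     async def _preflight_llm(llm) -> None:
#         if llm.model == "auto":  # the router owns rate-limit accounting
#             return
#         try:
#             await litellm.acompletion(
#                 model=llm.model,
#                 api_key=llm.api_key,
#                 api_base=llm.api_base,
#                 messages=[{"role": "user", "content": "ping"}],
#                 max_tokens=1,
#                 stream=False,
#             )
#         except Exception as exc:
#             if _is_provider_rate_limited(exc):
#                 raise  # the caller drives the cooldown/repin branch
#             # best-effort: swallow everything else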


@pytest.mark.asyncio
async def test_preflight_skipped_for_auto_router_model():
    """Router-mode ``model='auto'`` has no single deployment to ping; the
    LiteLLM router itself owns per-deployment rate-limit accounting, so the
    preflight helper must short-circuit instead of issuing a probe."""
    from types import SimpleNamespace

    from app.tasks.chat.stream_new_chat import _preflight_llm

    fake_llm = SimpleNamespace(model="auto", api_key="x", api_base=None)
    # Should return without raising or making any LiteLLM call.
    await _preflight_llm(fake_llm)


@pytest.mark.asyncio
async def test_settle_speculative_agent_build_swallows_exceptions():
    """``_settle_speculative_agent_build`` MUST always return cleanly so the
    caller can safely re-touch the request-scoped session afterwards.

    The helper guards the parallel preflight + agent-build path: when the
    speculative build is being discarded (429 or non-429 preflight failure)
    we await it solely to release any in-flight ``AsyncSession`` usage —
    the build's outcome is irrelevant. Any exception (including
    ``CancelledError``) leaking out would skip the caller's recovery flow
    and re-introduce the very session-concurrency hazard the helper exists
    to prevent.
    """
    import asyncio

    from app.tasks.chat.stream_new_chat import _settle_speculative_agent_build

    async def _raises() -> None:
        raise RuntimeError("speculative build crashed")

    async def _succeeds() -> str:
        return "agent"

    async def _slow() -> None:
        await asyncio.sleep(0.05)

    for coro in (_raises(), _succeeds(), _slow()):
        task = asyncio.create_task(coro)
        await _settle_speculative_agent_build(task)
        assert task.done()


@pytest.mark.asyncio
async def test_settle_speculative_agent_build_handles_already_done_task():
    """Done tasks (success or failure) must still be settled without raising."""
    import asyncio

    from app.tasks.chat.stream_new_chat import _settle_speculative_agent_build

    async def _ok() -> str:
        return "ok"

    async def _bad() -> None:
        raise ValueError("nope")

    ok_task = asyncio.create_task(_ok())
    bad_task = asyncio.create_task(_bad())
    # Drive both to completion before settling.
    await asyncio.sleep(0)
    await asyncio.sleep(0)

    await _settle_speculative_agent_build(ok_task)
    await _settle_speculative_agent_build(bad_task)
    assert ok_task.result() == "ok"
    # ``bad_task`` exception was consumed by the settle helper; calling
    # ``.exception()`` after the fact must still return the original error
    # (the helper observes it but doesn't clear it).
    assert isinstance(bad_task.exception(), ValueError)
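
# Hedged sketch of the settle helper's contract (an assumption, not the
# actual implementation):
#
#     async def _settle_speculative_agent_build(task: asyncio.Task) -> None:
#         try:
#             await task  # releases any in-flight AsyncSession usage
#         except BaseException:  # includes CancelledError
#             pass  # the outcome is irrelevant; never propagate
#
# Awaiting a failed task re-raises its exception but does not clear it, which
# is why the already-done test above can still read task.exception() afterwards.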


def test_stream_exception_classifies_thread_busy():
    exc = BusyError(request_id="thread-123")
    kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
        exc, flow_label="chat"
    )
    assert kind == "thread_busy"
    assert code == "THREAD_BUSY"
    assert severity == "warn"
    assert is_expected is True
    assert "still finishing for this thread" in user_message
    assert extra is None


def test_stream_exception_classifies_thread_busy_from_message():
    exc = Exception("Thread is busy with another request")
    kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
        exc, flow_label="chat"
    )
    assert kind == "thread_busy"
    assert code == "THREAD_BUSY"
    assert severity == "warn"
    assert is_expected is True
    assert "still finishing for this thread" in user_message
    assert extra is None


def test_stream_exception_classifies_turn_cancelling_when_cancel_requested():
    thread_id = "thread-cancelling-1"
    reset_cancel(thread_id)
    request_cancel(thread_id)
    exc = BusyError(request_id=thread_id)
    kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
        exc, flow_label="chat"
    )
    assert kind == "thread_busy"
    assert code == "TURN_CANCELLING"
    assert severity == "info"
    assert is_expected is True
    assert "stopping" in user_message
    assert isinstance(extra, dict)
    assert "retry_after_ms" in extra


def test_premium_classification_is_error_code_driven():
    classifier_path = (
        Path(__file__).resolve().parents[3]
        / "surfsense_web/lib/chat/chat-error-classifier.ts"
    )
    source = classifier_path.read_text(encoding="utf-8")

    assert "PREMIUM_KEYWORDS" not in source
    assert "RATE_LIMIT_KEYWORDS" not in source
    assert "normalized.includes(" not in source
    assert 'if (errorCode === "PREMIUM_QUOTA_EXHAUSTED") {' in source


def test_stream_terminal_error_handler_has_pre_accept_soft_rollback_hook():
    page_path = (
        Path(__file__).resolve().parents[3]
        / "surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx"
    )
    source = page_path.read_text(encoding="utf-8")

    assert "onPreAcceptFailure?: () => Promise<void>;" in source
    assert "if (!accepted) {" in source
    assert "await onPreAcceptFailure?.();" in source
    assert "await onAcceptedStreamError?.();" in source
    assert "setMessages((prev) => prev.filter((m) => m.id !== userMsgId));" in source
    assert "setMessageDocumentsMap((prev) => {" in source


def test_toast_only_pre_accept_policy_has_no_inline_failed_marker():
    user_message_path = (
        Path(__file__).resolve().parents[3]
        / "surfsense_web/components/assistant-ui/user-message.tsx"
    )
    source = user_message_path.read_text(encoding="utf-8")

    assert "Not sent. Edit and retry." not in source
    assert "failed_pre_accept" not in source


def test_network_send_failures_use_unified_retry_toast_message():
    classifier_path = (
        Path(__file__).resolve().parents[3]
        / "surfsense_web/lib/chat/chat-error-classifier.ts"
    )
    classifier_source = classifier_path.read_text(encoding="utf-8")
    request_errors_path = (
        Path(__file__).resolve().parents[3]
        / "surfsense_web/lib/chat/chat-request-errors.ts"
    )
    request_errors_source = request_errors_path.read_text(encoding="utf-8")

    assert '"send_failed_pre_accept"' in classifier_source
    assert 'errorCode === "SEND_FAILED_PRE_ACCEPT"' in classifier_source
    assert 'errorCode === "TURN_CANCELLING"' in classifier_source
    assert "if (withCode.code) return withCode.code;" in classifier_source
    assert 'userMessage: "Message not sent. Please retry."' in classifier_source
    assert 'userMessage: "Connection issue. Please try again."' in classifier_source
    assert "const passthroughCodes = new Set([" in request_errors_source
    assert '"PREMIUM_QUOTA_EXHAUSTED"' in request_errors_source
    assert '"THREAD_BUSY"' in request_errors_source
    assert '"TURN_CANCELLING"' in request_errors_source
    assert '"AUTH_EXPIRED"' in request_errors_source
    assert '"UNAUTHORIZED"' in request_errors_source
    assert '"RATE_LIMITED"' in request_errors_source
    assert '"NETWORK_ERROR"' in request_errors_source
    assert '"STREAM_PARSE_ERROR"' in request_errors_source
    assert '"TOOL_EXECUTION_ERROR"' in request_errors_source
    assert '"PERSIST_MESSAGE_FAILED"' in request_errors_source
    assert '"SERVER_ERROR"' in request_errors_source
    assert "passthroughCodes.has(existingCode)" in request_errors_source
    assert 'errorCode: "SEND_FAILED_PRE_ACCEPT"' in request_errors_source
    assert 'errorCode: "NETWORK_ERROR"' not in request_errors_source
    assert "Failed to start chat. Please try again." not in classifier_source


def test_pre_post_accept_abort_contract_exists_for_new_resume_regenerate_flows():
    page_path = (
        Path(__file__).resolve().parents[3]
        / "surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx"
    )
    source = page_path.read_text(encoding="utf-8")

    # Each flow tracks its accepted boundary and passes it into shared
    # terminal handling. The acceptance boundary is still meaningful
    # post-refactor: it gates local-state cleanup (the onPreAcceptFailure
    # path) and lets the shared terminal handler distinguish pre-accept
    # aborts from in-stream errors.
    assert "let newAccepted = false;" in source
    assert "let resumeAccepted = false;" in source
    assert "let regenerateAccepted = false;" in source
    assert "accepted: newAccepted," in source
    assert "accepted: resumeAccepted," in source
    assert "accepted: regenerateAccepted," in source

    # NOTE: The FE-side persistence guards previously asserted here
    # ("if (!resumeAccepted) return;", "if (!regenerateAccepted) return;",
    # "if (newAccepted && !userPersisted) {") have been intentionally
    # removed by the SSE-based message-id handshake refactor. Persistence
    # is now server-authoritative: persist_user_turn / persist_assistant_shell
    # run inside stream_new_chat / stream_resume_chat unconditionally and
    # the FE consumes data-user-message-id / data-assistant-message-id
    # SSE events to learn the canonical primary keys. There is therefore
    # no FE call-site to guard, and the shared terminal handler relies
    # purely on the `accepted` field above (forwarded to onAbort /
    # onAcceptedStreamError) to drive UI cleanup. See
    # tests/integration/chat/test_message_id_sse.py for the new
    # cross-tier ID coherence guarantees.

    # The TURN_CANCELLING / THREAD_BUSY retry plumbing is independent
    # of the persistence refactor and must still exist on every
    # start-stream fetch.
    assert "const fetchWithTurnCancellingRetry = useCallback(" in source
    assert "computeFallbackTurnCancellingRetryDelay" in source
    assert 'withMeta.errorCode === "TURN_CANCELLING"' in source
    assert 'withMeta.errorCode === "THREAD_BUSY"' in source
    assert "await fetchWithTurnCancellingRetry(() =>" in source


def test_cancel_active_turn_route_contract_exists():
    routes_path = (
        Path(__file__).resolve().parents[3]
        / "surfsense_backend/app/routes/new_chat_routes.py"
    )
    source = routes_path.read_text(encoding="utf-8")

    assert '@router.post(\n    "/threads/{thread_id}/cancel-active-turn",' in source
    assert "response_model=CancelActiveTurnResponse" in source
    assert 'status="cancelling",' in source
    assert 'error_code="TURN_CANCELLING",' in source
    assert "retry_after_ms=retry_after_ms if retry_after_ms > 0 else None," in source
    assert "retry_after_at=" in source
    assert 'status="idle",' in source
    assert 'error_code="NO_ACTIVE_TURN",' in source


def test_turn_status_route_contract_exists():
    routes_path = (
        Path(__file__).resolve().parents[3]
        / "surfsense_backend/app/routes/new_chat_routes.py"
    )
    source = routes_path.read_text(encoding="utf-8")

    assert '@router.get(\n    "/threads/{thread_id}/turn-status",' in source
    assert "response_model=TurnStatusResponse" in source
    assert "_build_turn_status_payload(thread_id)" in source
    assert "Permission.CHATS_READ.value" in source
    assert "_raise_if_thread_busy_for_start(" in source


def test_turn_cancelling_retry_policy_contract_exists():
    routes_path = (
        Path(__file__).resolve().parents[3]
        / "surfsense_backend/app/routes/new_chat_routes.py"
    )
    source = routes_path.read_text(encoding="utf-8")

    assert "TURN_CANCELLING_INITIAL_DELAY_MS = 200" in source
    assert "TURN_CANCELLING_BACKOFF_FACTOR = 2" in source
    assert "TURN_CANCELLING_MAX_DELAY_MS = 1500" in source
    assert "def _compute_turn_cancelling_retry_delay(" in source
    assert "retry-after-ms" in source
    assert '"Retry-After"' in source
    assert '"errorCode": "TURN_CANCELLING"' in source
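
# The backoff these constants imply (a sketch; the contract above pins only
# the constant values and the helper name, not the exact formula):
#
#     def _compute_turn_cancelling_retry_delay(attempt: int) -> int:
#         delay = TURN_CANCELLING_INITIAL_DELAY_MS * (
#             TURN_CANCELLING_BACKOFF_FACTOR**attempt
#         )
#         return min(delay, TURN_CANCELLING_MAX_DELAY_MS)
#
# which yields 200, 400, 800, 1500, 1500, ... ms for attempts 0, 1, 2, 3, 4.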


def test_turn_status_sse_contract_exists():
    stream_source = (
        Path(__file__).resolve().parents[3]
        / "surfsense_backend/app/tasks/chat/stream_new_chat.py"
    ).read_text(encoding="utf-8")
    state_source = (
        Path(__file__).resolve().parents[3]
        / "surfsense_web/lib/chat/streaming-state.ts"
    ).read_text(encoding="utf-8")
    pipeline_source = (
        Path(__file__).resolve().parents[3]
        / "surfsense_web/lib/chat/stream-pipeline.ts"
    ).read_text(encoding="utf-8")

    assert '"turn-status"' in stream_source
    assert '"status": "busy"' in stream_source
    assert '"status": "idle"' in stream_source
    assert 'type: "data-turn-status"' in state_source
    assert 'case "data-turn-status":' in pipeline_source
    assert "end_turn(str(chat_id))" in stream_source
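
# Hedged sketch of the cross-tier turn-status event this contract pins down
# (only the quoted literals above are confirmed; the surrounding field names
# are assumptions):
#
#     backend SSE payload:  {"type": "turn-status", "status": "busy" | "idle"}
#     FE stream part:       { type: "data-turn-status", data: { status } }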


def test_chat_deepagent_forwards_resolved_model_name_to_both_builders():
    """Regression guard: both system-prompt builders in chat_deepagent.py
    must receive ``model_name=_resolve_prompt_model_name(...)`` so the
    provider-variant dispatch can render the right ``<provider_hints>``
    block. Without this the prompt silently falls back to the empty
    ``"default"`` variant — the original bug being fixed.

    This test mirrors :func:`test_stream_error_emission_keeps_machine_error_codes`
    in style: it inspects module source text + a regex to enforce the
    call-site shape, not just the wrapper layer (the wrappers already
    forward ``model_name`` correctly, so testing them would not catch
    the actual missed plumbing).
    """
    import app.agents.new_chat.chat_deepagent as chat_deepagent_module

    source = inspect.getsource(chat_deepagent_module)

    # Helper itself must be defined.
    assert "def _resolve_prompt_model_name(" in source

    # Both builder calls must forward the resolved model name. Match
    # across newlines + whitespace because the kwargs are split over
    # multiple lines.
    pattern = re.compile(
        r"build_(?:surfsense|configurable)_system_prompt\([^)]*"
        r"model_name=_resolve_prompt_model_name\(",
        re.DOTALL,
    )
    matches = pattern.findall(source)
    assert len(matches) == 2, (
        "Expected both system-prompt builder call sites to forward "
        "`model_name=_resolve_prompt_model_name(...)`, found "
        f"{len(matches)}"
    )