Merge upstream/dev into feature/multi-agent

This commit is contained in:
CREDO23 2026-05-05 01:44:46 +02:00
commit 5119915f4f
278 changed files with 34669 additions and 8970 deletions

@@ -226,6 +226,31 @@ class TestCompose:
# Default block should NOT be present
assert "<knowledge_base_only_policy>" not in prompt
def test_provider_hints_render_with_custom_system_instructions(
self, fixed_today: datetime
) -> None:
"""Regression guard for the always-append decision: provider hints
append AFTER a custom system prompt.
Provider hints are stylistic nudges (parallel tool-call rules,
formatting guidance, etc.) that help the model regardless of
what the system instructions say. Suppressing them when a
custom prompt is set would partially defeat the per-family
prompt machinery.
"""
prompt = compose_system_prompt(
today=fixed_today,
custom_system_instructions="You are a custom assistant.",
model_name="anthropic/claude-3-5-sonnet",
)
assert "You are a custom assistant." in prompt
assert "<provider_hints>" in prompt
# The custom prompt must come BEFORE the provider hints so the
# user's framing isn't drowned out by the stylistic nudges.
assert prompt.index("You are a custom assistant.") < prompt.index(
"<provider_hints>"
)
def test_use_default_false_with_no_custom_yields_no_system_block(
self, fixed_today: datetime
) -> None:

@@ -0,0 +1,268 @@
"""Regression tests for the compiled-agent cache.
Covers the cache primitive itself (TTL, LRU, in-flight de-duplication,
build-failure non-caching) and the cache-key signature helpers that
``create_surfsense_deep_agent`` relies on. The integration with
``create_surfsense_deep_agent`` is covered separately by the streaming
contract tests; this module focuses on the primitives so a regression
in the cache implementation is caught before it reaches the agent
factory.
"""
from __future__ import annotations
import asyncio
from dataclasses import dataclass
import pytest
from app.agents.new_chat.agent_cache import (
flags_signature,
reload_for_tests,
stable_hash,
system_prompt_hash,
tools_signature,
)
pytestmark = pytest.mark.unit
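# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the suite): roughly how a caller might
# assemble a cache key from the signature helpers and reuse a compiled agent
# across requests. The key layout ("v1" plus the signature ordering) and the
# builder body are assumptions for illustration; only the helper names and the
# ``get_or_build(key, builder=...)`` shape exercised below are real.
# ---------------------------------------------------------------------------
async def _example_cached_build(tools, flags, system_prompt: str) -> object:
    cache = reload_for_tests(maxsize=32, ttl_seconds=300.0)
    key = stable_hash(
        "v1",
        tools_signature(
            tools, available_connectors=["NOTION"], available_document_types=["FILE"]
        ),
        flags_signature(flags),
        system_prompt_hash(system_prompt),
    )

    async def builder() -> object:
        # Production code would compile the deep agent here; object() stands in.
        return object()

    # Identical (tools, flags, prompt) surfaces hash to the same key, so a
    # repeat call with the same inputs returns the cached instance.
    return await cache.get_or_build(key, builder=builder)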
# ---------------------------------------------------------------------------
# stable_hash + signature helpers
# ---------------------------------------------------------------------------
def test_stable_hash_is_deterministic_across_calls() -> None:
a = stable_hash("v1", 42, "thread-9", None, ["x", "y"])
b = stable_hash("v1", 42, "thread-9", None, ["x", "y"])
assert a == b
def test_stable_hash_changes_when_any_part_changes() -> None:
base = stable_hash("v1", 42, "thread-9")
assert stable_hash("v1", 42, "thread-10") != base
assert stable_hash("v2", 42, "thread-9") != base
assert stable_hash("v1", 43, "thread-9") != base
def test_tools_signature_keys_on_name_and_description_not_identity() -> None:
"""Two tool lists with the same surface must hash identically.
The cache key MUST NOT change when the underlying ``BaseTool``
instances are different Python objects (a fresh request constructs
fresh tool instances every time). Hashing on ``(name, description)``
keeps the cache hot across requests with identical tool surfaces.
"""
@dataclass
class FakeTool:
name: str
description: str
tools_a = [FakeTool("alpha", "does alpha"), FakeTool("beta", "does beta")]
tools_b = [FakeTool("beta", "does beta"), FakeTool("alpha", "does alpha")]
sig_a = tools_signature(
tools_a, available_connectors=["NOTION"], available_document_types=["FILE"]
)
sig_b = tools_signature(
tools_b, available_connectors=["NOTION"], available_document_types=["FILE"]
)
assert sig_a == sig_b, "tool order must not affect the signature"
# Adding a tool rotates the key.
tools_c = [*tools_a, FakeTool("gamma", "does gamma")]
sig_c = tools_signature(
tools_c, available_connectors=["NOTION"], available_document_types=["FILE"]
)
assert sig_c != sig_a
def test_tools_signature_rotates_when_connector_set_changes() -> None:
@dataclass
class FakeTool:
name: str
description: str
tools = [FakeTool("a", "x")]
base = tools_signature(
tools, available_connectors=["NOTION"], available_document_types=["FILE"]
)
added = tools_signature(
tools,
available_connectors=["NOTION", "SLACK"],
available_document_types=["FILE"],
)
assert base != added, "adding a connector must rotate the cache key"
def test_flags_signature_changes_when_flag_flips() -> None:
@dataclass(frozen=True)
class Flags:
a: bool = True
b: bool = False
base = flags_signature(Flags())
flipped = flags_signature(Flags(b=True))
assert base != flipped
def test_system_prompt_hash_is_stable_and_distinct() -> None:
p1 = "You are a helpful assistant."
p2 = "You are a helpful assistant!" # one-character delta
assert system_prompt_hash(p1) == system_prompt_hash(p1)
assert system_prompt_hash(p1) != system_prompt_hash(p2)
# ---------------------------------------------------------------------------
# _AgentCache: hit / miss / TTL / LRU / coalescing / failure-not-cached
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_cache_hit_returns_same_instance_on_second_call() -> None:
cache = reload_for_tests(maxsize=8, ttl_seconds=60.0)
builds = 0
async def builder() -> object:
nonlocal builds
builds += 1
return object()
a = await cache.get_or_build("k", builder=builder)
b = await cache.get_or_build("k", builder=builder)
assert a is b, "cache must return the SAME object across hits"
assert builds == 1, "builder must run exactly once"
@pytest.mark.asyncio
async def test_cache_different_keys_get_different_instances() -> None:
cache = reload_for_tests(maxsize=8, ttl_seconds=60.0)
async def builder() -> object:
return object()
a = await cache.get_or_build("k1", builder=builder)
b = await cache.get_or_build("k2", builder=builder)
assert a is not b
@pytest.mark.asyncio
async def test_cache_stale_entries_get_rebuilt() -> None:
# ttl=0 means every read sees the entry as immediately stale.
cache = reload_for_tests(maxsize=8, ttl_seconds=0.0)
builds = 0
async def builder() -> object:
nonlocal builds
builds += 1
return object()
a = await cache.get_or_build("k", builder=builder)
b = await cache.get_or_build("k", builder=builder)
assert a is not b, "stale entry must rebuild a fresh instance"
assert builds == 2
@pytest.mark.asyncio
async def test_cache_evicts_lru_when_full() -> None:
cache = reload_for_tests(maxsize=2, ttl_seconds=60.0)
async def builder() -> object:
return object()
a = await cache.get_or_build("a", builder=builder)
_ = await cache.get_or_build("b", builder=builder)
# Re-touch "a" so "b" is now the LRU victim.
a_again = await cache.get_or_build("a", builder=builder)
assert a_again is a
# Inserting "c" should evict "b" (LRU), not "a".
_ = await cache.get_or_build("c", builder=builder)
assert cache.stats()["size"] == 2
# Confirm "a" is still hot (no rebuild) and "b" is gone (rebuild).
a_hit = await cache.get_or_build("a", builder=builder)
assert a_hit is a, "LRU must keep the most-recently-used 'a' entry"
@pytest.mark.asyncio
async def test_cache_concurrent_misses_coalesce_to_single_build() -> None:
"""Two concurrent get_or_build calls on the same key must share one builder."""
cache = reload_for_tests(maxsize=8, ttl_seconds=60.0)
build_started = asyncio.Event()
builds = 0
async def slow_builder() -> object:
nonlocal builds
builds += 1
build_started.set()
# Yield control so the second waiter can race against us.
await asyncio.sleep(0.05)
return object()
task_a = asyncio.create_task(cache.get_or_build("k", builder=slow_builder))
# Wait until the first builder has started, then race a second waiter.
await build_started.wait()
task_b = asyncio.create_task(cache.get_or_build("k", builder=slow_builder))
a, b = await asyncio.gather(task_a, task_b)
assert a is b, "coalesced waiters must observe the same value"
assert builds == 1, "concurrent cold misses must collapse to ONE build"
@pytest.mark.asyncio
async def test_cache_does_not_store_failed_builds() -> None:
"""A builder that raises must NOT poison the cache.
The next caller for the same key must run the builder again (not
re-raise the cached exception).
"""
cache = reload_for_tests(maxsize=8, ttl_seconds=60.0)
attempts = 0
async def flaky_builder() -> object:
nonlocal attempts
attempts += 1
if attempts == 1:
raise RuntimeError("transient")
return object()
with pytest.raises(RuntimeError, match="transient"):
await cache.get_or_build("k", builder=flaky_builder)
# Second call must retry — not re-raise the cached exception.
value = await cache.get_or_build("k", builder=flaky_builder)
assert value is not None
assert attempts == 2
@pytest.mark.asyncio
async def test_cache_invalidate_drops_entry() -> None:
cache = reload_for_tests(maxsize=8, ttl_seconds=60.0)
async def builder() -> object:
return object()
a = await cache.get_or_build("k", builder=builder)
assert cache.invalidate("k") is True
b = await cache.get_or_build("k", builder=builder)
assert a is not b, "post-invalidation lookup must rebuild"
@pytest.mark.asyncio
async def test_cache_invalidate_prefix_drops_matching_entries() -> None:
cache = reload_for_tests(maxsize=16, ttl_seconds=60.0)
async def builder() -> object:
return object()
await cache.get_or_build("user:1:thread:1", builder=builder)
await cache.get_or_build("user:1:thread:2", builder=builder)
await cache.get_or_build("user:2:thread:1", builder=builder)
removed = cache.invalidate_prefix("user:1:")
assert removed == 2
assert cache.stats()["size"] == 1
# The user:2 entry must still be hot (no rebuild).
survivor_value = await cache.get_or_build("user:2:thread:1", builder=builder)
assert survivor_value is not None

@@ -7,7 +7,9 @@ import pytest
from app.agents.new_chat.errors import BusyError
from app.agents.new_chat.middleware.busy_mutex import (
BusyMutexMiddleware,
end_turn,
get_cancel_event,
is_cancel_requested,
manager,
request_cancel,
reset_cancel,
@@ -88,3 +90,65 @@ async def test_no_thread_id_skipped_when_not_required() -> None:
def test_reset_cancel_idempotent() -> None:
# Should not raise even if event was never created
reset_cancel("never-seen")
def test_request_cancel_creates_event_for_unseen_thread() -> None:
thread_id = "never-seen-cancel"
reset_cancel(thread_id)
assert request_cancel(thread_id) is True
assert get_cancel_event(thread_id).is_set()
assert is_cancel_requested(thread_id) is True
@pytest.mark.asyncio
async def test_end_turn_force_clears_lock_and_cancel_state() -> None:
thread_id = "forced-end-turn"
mw = BusyMutexMiddleware()
runtime = _Runtime(thread_id)
await mw.abefore_agent({}, runtime)
assert manager.lock_for(thread_id).locked()
request_cancel(thread_id)
assert is_cancel_requested(thread_id) is True
end_turn(thread_id)
assert not manager.lock_for(thread_id).locked()
assert not get_cancel_event(thread_id).is_set()
assert is_cancel_requested(thread_id) is False
@pytest.mark.asyncio
async def test_busy_mutex_stale_aafter_does_not_release_new_attempt_lock() -> None:
"""A stale aafter call from attempt A must not unlock attempt B.
Repro flow:
1) attempt A acquires thread lock
2) forced end_turn clears A so retry can proceed
3) attempt B acquires same thread lock
4) stale attempt-A aafter runs late
Expected: B lock remains held.
"""
thread_id = "stale-aafter-lock"
runtime = _Runtime(thread_id)
attempt_a = BusyMutexMiddleware()
attempt_b = BusyMutexMiddleware()
await attempt_a.abefore_agent({}, runtime)
lock = manager.lock_for(thread_id)
assert lock.locked()
end_turn(thread_id)
assert not lock.locked()
await attempt_b.abefore_agent({}, runtime)
assert lock.locked()
# Stale cleanup from attempt A must not release attempt B's lock.
await attempt_a.aafter_agent({}, runtime)
assert lock.locked()
await attempt_b.aafter_agent({}, runtime)

@@ -31,18 +31,45 @@ def _clear_all(monkeypatch: pytest.MonkeyPatch) -> None:
"SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE",
"SURFSENSE_ENABLE_ACTION_LOG",
"SURFSENSE_ENABLE_REVERT_ROUTE",
"SURFSENSE_ENABLE_STREAM_PARITY_V2",
"SURFSENSE_ENABLE_PLUGIN_LOADER",
"SURFSENSE_ENABLE_OTEL",
"SURFSENSE_ENABLE_AGENT_CACHE",
"SURFSENSE_ENABLE_AGENT_CACHE_SHARE_GP_SUBAGENT",
]:
monkeypatch.delenv(name, raising=False)
def test_defaults_all_off(monkeypatch: pytest.MonkeyPatch) -> None:
def test_defaults_match_shipped_agent_stack(monkeypatch: pytest.MonkeyPatch) -> None:
_clear_all(monkeypatch)
flags = reload_for_tests()
assert isinstance(flags, AgentFeatureFlags)
assert flags.disable_new_agent_stack is False
assert flags.any_new_middleware_enabled() is False
assert flags.enable_context_editing is True
assert flags.enable_compaction_v2 is True
assert flags.enable_retry_after is True
assert flags.enable_model_fallback is False
assert flags.enable_model_call_limit is True
assert flags.enable_tool_call_limit is True
assert flags.enable_tool_call_repair is True
assert flags.enable_doom_loop is True
assert flags.enable_permission is True
assert flags.enable_busy_mutex is True
assert flags.enable_llm_tool_selector is False
assert flags.enable_skills is True
assert flags.enable_specialized_subagents is True
assert flags.enable_kb_planner_runnable is True
assert flags.enable_action_log is True
assert flags.enable_revert_route is True
assert flags.enable_stream_parity_v2 is True
assert flags.enable_plugin_loader is False
assert flags.enable_otel is False
# Phase 2: agent cache is now default-on (the prerequisite tool
# ``db_session`` refactor landed). The companion gp-subagent share
# flag stays default-off pending data on cold-miss frequency.
assert flags.enable_agent_cache is True
assert flags.enable_agent_cache_share_gp_subagent is False
assert flags.any_new_middleware_enabled() is True
def test_master_kill_switch_overrides_individual_flags(
@@ -100,21 +127,13 @@ def test_each_flag_can_be_set_independently(monkeypatch: pytest.MonkeyPatch) ->
"enable_kb_planner_runnable": "SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE",
"enable_action_log": "SURFSENSE_ENABLE_ACTION_LOG",
"enable_revert_route": "SURFSENSE_ENABLE_REVERT_ROUTE",
"enable_stream_parity_v2": "SURFSENSE_ENABLE_STREAM_PARITY_V2",
"enable_plugin_loader": "SURFSENSE_ENABLE_PLUGIN_LOADER",
"enable_otel": "SURFSENSE_ENABLE_OTEL",
}
# `enable_otel` is intentionally orthogonal — it does NOT count toward
# ``any_new_middleware_enabled`` because OTel is observability-only and
# ships under its own ``OTEL_EXPORTER_OTLP_ENDPOINT`` requirement.
counts_toward_middleware = {k for k in flag_to_env if k != "enable_otel"}
for attr, env_name in flag_to_env.items():
_clear_all(monkeypatch)
monkeypatch.setenv(env_name, "true")
monkeypatch.setenv(env_name, "false")
flags = reload_for_tests()
assert getattr(flags, attr) is True, f"{attr} did not flip on for {env_name}"
if attr in counts_toward_middleware:
assert flags.any_new_middleware_enabled() is True
else:
assert flags.any_new_middleware_enabled() is False
assert getattr(flags, attr) is False, f"{attr} did not flip off for {env_name}"

@@ -0,0 +1,344 @@
"""Tests for ``FlattenSystemMessageMiddleware``.
The middleware exists to defend against Anthropic's "Found 5 cache_control
blocks" 400 when our deepagent middleware stack stacks 5+ text blocks on
the system message and the OpenRouterAnthropic adapter redistributes
``cache_control`` across all of them. The flattening collapses every
all-text system content list to a single string before the LLM call.
"""
from __future__ import annotations
from typing import Any
from unittest.mock import MagicMock
import pytest
from langchain_core.messages import HumanMessage, SystemMessage
from app.agents.new_chat.middleware.flatten_system import (
FlattenSystemMessageMiddleware,
_flatten_text_blocks,
_flattened_request,
)
pytestmark = pytest.mark.unit
# ---------------------------------------------------------------------------
# _flatten_text_blocks — pure helper, the heart of the middleware.
# ---------------------------------------------------------------------------
class TestFlattenTextBlocks:
def test_joins_text_blocks_with_double_newline(self) -> None:
blocks = [
{"type": "text", "text": "<surfsense base>"},
{"type": "text", "text": "<filesystem section>"},
{"type": "text", "text": "<skills section>"},
]
assert (
_flatten_text_blocks(blocks)
== "<surfsense base>\n\n<filesystem section>\n\n<skills section>"
)
def test_handles_single_text_block(self) -> None:
blocks = [{"type": "text", "text": "only one"}]
assert _flatten_text_blocks(blocks) == "only one"
def test_handles_empty_list(self) -> None:
assert _flatten_text_blocks([]) == ""
def test_passes_through_bare_string_blocks(self) -> None:
# LangChain content can mix bare strings and dict blocks.
blocks = ["raw string", {"type": "text", "text": "dict block"}]
assert _flatten_text_blocks(blocks) == "raw string\n\ndict block"
def test_returns_none_for_image_block(self) -> None:
# System messages with images are rare — but we never want to
# silently lose the image payload by joining as text.
blocks = [
{"type": "text", "text": "look at this"},
{"type": "image_url", "image_url": {"url": "data:image/png..."}},
]
assert _flatten_text_blocks(blocks) is None
def test_returns_none_for_non_dict_non_str_block(self) -> None:
blocks = [{"type": "text", "text": "hi"}, 42] # type: ignore[list-item]
assert _flatten_text_blocks(blocks) is None
def test_returns_none_when_text_field_missing(self) -> None:
blocks = [{"type": "text"}] # no ``text`` key
assert _flatten_text_blocks(blocks) is None
def test_returns_none_when_text_is_not_string(self) -> None:
blocks = [{"type": "text", "text": ["nested", "list"]}]
assert _flatten_text_blocks(blocks) is None
def test_drops_cache_control_from_inner_blocks(self) -> None:
# The whole point: existing cache_control on inner blocks is
# discarded so LiteLLM's ``cache_control_injection_points`` can
# re-attach exactly one breakpoint after flattening.
blocks = [
{"type": "text", "text": "first"},
{
"type": "text",
"text": "second",
"cache_control": {"type": "ephemeral"},
},
]
flattened = _flatten_text_blocks(blocks)
assert flattened == "first\n\nsecond"
assert "cache_control" not in flattened # type: ignore[operator]
# ---------------------------------------------------------------------------
# _flattened_request — decides when to override and when to no-op.
# ---------------------------------------------------------------------------
def _make_request(system_message: SystemMessage | None) -> Any:
"""Build a minimal ModelRequest stub. We only need .system_message
and .override(system_message=...); the middleware never touches
other fields.
"""
request = MagicMock()
request.system_message = system_message
def override(**kwargs: Any) -> Any:
new_request = MagicMock()
new_request.system_message = kwargs.get(
"system_message", request.system_message
)
new_request.messages = kwargs.get("messages", getattr(request, "messages", []))
new_request.tools = kwargs.get("tools", getattr(request, "tools", []))
return new_request
request.override = override
return request
class TestFlattenedRequest:
def test_collapses_multi_block_system_to_string(self) -> None:
sys = SystemMessage(
content=[
{"type": "text", "text": "<base>"},
{"type": "text", "text": "<todo>"},
{"type": "text", "text": "<filesystem>"},
{"type": "text", "text": "<skills>"},
{"type": "text", "text": "<subagents>"},
]
)
request = _make_request(sys)
flattened = _flattened_request(request)
assert flattened is not None
assert isinstance(flattened.system_message, SystemMessage)
assert flattened.system_message.content == (
"<base>\n\n<todo>\n\n<filesystem>\n\n<skills>\n\n<subagents>"
)
def test_no_op_for_string_content(self) -> None:
sys = SystemMessage(content="already a string")
request = _make_request(sys)
assert _flattened_request(request) is None
def test_no_op_for_single_block_list(self) -> None:
# One block already produces one breakpoint — no need to flatten.
sys = SystemMessage(content=[{"type": "text", "text": "single"}])
request = _make_request(sys)
assert _flattened_request(request) is None
def test_no_op_when_system_message_missing(self) -> None:
request = _make_request(None)
assert _flattened_request(request) is None
def test_no_op_when_list_contains_non_text_block(self) -> None:
sys = SystemMessage(
content=[
{"type": "text", "text": "look"},
{"type": "image_url", "image_url": {"url": "data:..."}},
]
)
request = _make_request(sys)
assert _flattened_request(request) is None
def test_preserves_additional_kwargs_and_metadata(self) -> None:
# Defensive: nothing in the current chain sets these on a system
# message, but losing them silently when something does in the
# future would be a regression. ``name`` in particular is the only
# ``additional_kwargs`` field that ChatLiteLLM's
# ``_convert_message_to_dict`` propagates onto the wire.
sys = SystemMessage(
content=[
{"type": "text", "text": "a"},
{"type": "text", "text": "b"},
],
additional_kwargs={"name": "surfsense_system", "x": 1},
response_metadata={"tokens": 42},
)
sys.id = "sys-msg-1"
request = _make_request(sys)
flattened = _flattened_request(request)
assert flattened is not None
assert flattened.system_message.content == "a\n\nb"
assert flattened.system_message.additional_kwargs == {
"name": "surfsense_system",
"x": 1,
}
assert flattened.system_message.response_metadata == {"tokens": 42}
assert flattened.system_message.id == "sys-msg-1"
def test_idempotent_when_run_twice(self) -> None:
sys = SystemMessage(
content=[
{"type": "text", "text": "a"},
{"type": "text", "text": "b"},
]
)
request = _make_request(sys)
first = _flattened_request(request)
assert first is not None
# Second pass on the already-flattened request should be a no-op.
# We re-wrap in a request stub since the helper inspects
# ``request.system_message.content``.
second_request = _make_request(first.system_message)
assert _flattened_request(second_request) is None
# ---------------------------------------------------------------------------
# Middleware integration — verify the handler sees a flattened request.
# ---------------------------------------------------------------------------
class TestMiddlewareWrap:
@pytest.mark.asyncio
async def test_async_passes_flattened_request_to_handler(self) -> None:
sys = SystemMessage(
content=[
{"type": "text", "text": "alpha"},
{"type": "text", "text": "beta"},
]
)
request = _make_request(sys)
captured: dict[str, Any] = {}
async def handler(req: Any) -> str:
captured["request"] = req
return "ok"
mw = FlattenSystemMessageMiddleware()
result = await mw.awrap_model_call(request, handler)
assert result == "ok"
assert isinstance(captured["request"].system_message, SystemMessage)
assert captured["request"].system_message.content == "alpha\n\nbeta"
@pytest.mark.asyncio
async def test_async_passes_through_when_already_string(self) -> None:
sys = SystemMessage(content="just a string")
request = _make_request(sys)
captured: dict[str, Any] = {}
async def handler(req: Any) -> str:
captured["request"] = req
return "ok"
mw = FlattenSystemMessageMiddleware()
await mw.awrap_model_call(request, handler)
# Same request object: no override happened.
assert captured["request"] is request
def test_sync_passes_flattened_request_to_handler(self) -> None:
sys = SystemMessage(
content=[
{"type": "text", "text": "alpha"},
{"type": "text", "text": "beta"},
]
)
request = _make_request(sys)
captured: dict[str, Any] = {}
def handler(req: Any) -> str:
captured["request"] = req
return "ok"
mw = FlattenSystemMessageMiddleware()
result = mw.wrap_model_call(request, handler)
assert result == "ok"
assert captured["request"].system_message.content == "alpha\n\nbeta"
def test_sync_passes_through_when_no_system_message(self) -> None:
request = _make_request(None)
captured: dict[str, Any] = {}
def handler(req: Any) -> str:
captured["request"] = req
return "ok"
mw = FlattenSystemMessageMiddleware()
mw.wrap_model_call(request, handler)
assert captured["request"] is request
# ---------------------------------------------------------------------------
# Regression guard — pin the worst-case shape that triggered the
# "Found 5" 400 in production. Confirms we collapse 5 blocks to 1 so the
# downstream cache_control_injection_points can only place 1 breakpoint
# on the system message regardless of provider redistribution quirks.
# ---------------------------------------------------------------------------
def test_regression_five_block_system_collapses_to_one_block() -> None:
sys = SystemMessage(
content=[
{"type": "text", "text": "<surfsense base + BASE_AGENT_PROMPT>"},
{"type": "text", "text": "<TodoListMiddleware section>"},
{"type": "text", "text": "<SurfSenseFilesystemMiddleware section>"},
{"type": "text", "text": "<SkillsMiddleware section>"},
{"type": "text", "text": "<SubAgentMiddleware section>"},
]
)
request = _make_request(sys)
flattened = _flattened_request(request)
assert flattened is not None
assert isinstance(flattened.system_message.content, str)
# The exact join doesn't matter for the cache_control accounting —
# only that there is exactly ONE content block when LiteLLM's
# AnthropicCacheControlHook later targets ``role: system``.
assert "<surfsense base" in flattened.system_message.content
assert "<SubAgentMiddleware" in flattened.system_message.content
def test_regression_human_message_not_modified() -> None:
# Sanity: the middleware MUST NOT touch user messages — only the
# system message. Multi-block user content is the path that carries
# image attachments and would lose its image_url block on
# accidental flatten.
sys = SystemMessage(
content=[
{"type": "text", "text": "a"},
{"type": "text", "text": "b"},
]
)
user = HumanMessage(
content=[
{"type": "text", "text": "look at this"},
{"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
]
)
request = _make_request(sys)
request.messages = [user]
flattened = _flattened_request(request)
assert flattened is not None
# System flattened to string …
assert isinstance(flattened.system_message.content, str)
# … user message is untouched (the helper does not even look at it).
assert flattened.messages == [user]
assert isinstance(user.content, list)
assert len(user.content) == 2

@@ -27,6 +27,7 @@ class TestDefaultAutoApprovedToolsList:
expected = {
"create_gmail_draft",
"update_gmail_draft",
"create_calendar_event",
"create_notion_page",
"create_confluence_page",
"create_google_drive_file",
@@ -41,13 +42,12 @@
assert isinstance(DEFAULT_AUTO_APPROVED_TOOLS, frozenset)
def test_send_tools_are_not_auto_approved(self) -> None:
# External-broadcast tools must always prompt.
# External-broadcast / destructive tools must always prompt.
for tool_name in (
"send_gmail_email",
"send_discord_message",
"send_teams_message",
"delete_notion_page",
"create_calendar_event",
"delete_calendar_event",
):
assert tool_name not in DEFAULT_AUTO_APPROVED_TOOLS, (

@@ -0,0 +1,370 @@
r"""Tests for ``apply_litellm_prompt_caching`` in
:mod:`app.agents.new_chat.prompt_caching`.
The helper replaces the legacy ``AnthropicPromptCachingMiddleware`` (which
never activated for our LiteLLM stack) with LiteLLM-native multi-provider
prompt caching. It mutates ``llm.model_kwargs`` so the kwargs flow to
``litellm.completion(...)``. The tests below pin its public contract:
1. Always sets BOTH ``index: 0`` and ``index: -1`` injection points so
savings compound across multi-turn conversations on Anthropic-family
providers. ``index: 0`` is used (rather than ``role: system``) because
the deepagent stack accumulates multiple ``SystemMessage``\ s in
``state["messages"]`` and ``role: system`` would tag every one of
them, blowing past Anthropic's 4-block ``cache_control`` cap.
2. Adds ``prompt_cache_key``/``prompt_cache_retention`` only for
single-model OPENAI/DEEPSEEK/XAI configs (where OpenAI's automatic
prompt-cache surface is available).
3. Treats ``ChatLiteLLMRouter`` (auto-mode) as universal-only: no
OpenAI-only kwargs because the router fans out across providers.
4. Idempotent: user-supplied values in ``model_kwargs`` are preserved.
5. Defensive: LLMs without a writable ``model_kwargs`` are silently
skipped rather than raising.
"""
from __future__ import annotations
from typing import Any
import pytest
from app.agents.new_chat.llm_config import AgentConfig
from app.agents.new_chat.prompt_caching import apply_litellm_prompt_caching
pytestmark = pytest.mark.unit
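# Worked example of the contract above, for orientation: after the helper runs
# against an OpenAI-family config with ``thread_id=42``, ``llm.model_kwargs``
# is expected to carry the two universal breakpoints plus the OpenAI-only
# extras. The dict literal is illustrative; the individual values are exactly
# the ones the tests below pin.
_EXPECTED_OPENAI_KWARGS_SHAPE = {
    "cache_control_injection_points": [
        {"location": "message", "index": 0},   # stable prefix: first message
        {"location": "message", "index": -1},  # moving target: latest message
    ],
    "prompt_cache_key": "surfsense-thread-42",  # per-thread routing affinity
    "prompt_cache_retention": "24h",            # extended OpenAI cache lifetime
}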
# ---------------------------------------------------------------------------
# Test doubles
# ---------------------------------------------------------------------------
class _FakeLLM:
"""Stand-in for ``ChatLiteLLM``/``SanitizedChatLiteLLM``.
The helper only inspects ``getattr(llm, "model_kwargs", None)``,
``getattr(llm, "model", None)``, and ``type(llm).__name__``. A simple
object suffices we don't need to spin up real LangChain/LiteLLM
machinery for unit tests of the helper's logic.
"""
def __init__(
self,
model: str = "openai/gpt-4o",
model_kwargs: dict[str, Any] | None = None,
) -> None:
self.model = model
self.model_kwargs: dict[str, Any] = dict(model_kwargs) if model_kwargs else {}
class ChatLiteLLMRouter:
"""Class-name-only impostor of the real router.
The helper's router gate is ``type(llm).__name__ == "ChatLiteLLMRouter"``
(a deliberate stringly-typed check to avoid an import cycle with
``app.services.llm_router_service``). Reusing the same class name here
triggers the same code path without instantiating a real ``Router``.
"""
def __init__(self) -> None:
self.model = "auto"
self.model_kwargs: dict[str, Any] = {}
def _make_cfg(**overrides: Any) -> AgentConfig:
"""Build an ``AgentConfig`` with sensible defaults for the helper test."""
defaults: dict[str, Any] = {
"provider": "OPENAI",
"model_name": "gpt-4o",
"api_key": "k",
}
return AgentConfig(**{**defaults, **overrides})
# ---------------------------------------------------------------------------
# (a) Universal injection points
# ---------------------------------------------------------------------------
def test_sets_both_cache_control_injection_points_with_no_config() -> None:
"""Bare call (no agent_config, no thread_id) still sets the two
universal breakpoints; these cost nothing on providers that don't
consume them and unlock caching on every supported provider."""
llm = _FakeLLM()
apply_litellm_prompt_caching(llm)
points = llm.model_kwargs["cache_control_injection_points"]
assert {"location": "message", "index": 0} in points
assert {"location": "message", "index": -1} in points
assert len(points) == 2
def test_does_not_inject_role_system_breakpoint() -> None:
"""Regression: deliberately AVOID ``role: system`` so we don't tag
every SystemMessage the deepagent ``before_agent`` injectors push
into ``state["messages"]`` (priority, tree, memory, file-intent,
anonymous-doc). Tagging all of them overflows Anthropic's 4-block
``cache_control`` cap and surfaces as
``OpenrouterException: A maximum of 4 blocks with cache_control may
be provided. Found N`` 400s.
"""
llm = _FakeLLM()
apply_litellm_prompt_caching(llm)
points = llm.model_kwargs["cache_control_injection_points"]
assert all(p.get("role") != "system" for p in points), (
f"Expected no role=system breakpoint, got: {points}"
)
def test_injection_points_set_for_anthropic_config() -> None:
"""Anthropic-family configs need the marker — verify it lands."""
cfg = _make_cfg(provider="ANTHROPIC", model_name="claude-3-5-sonnet")
llm = _FakeLLM(model="anthropic/claude-3-5-sonnet")
apply_litellm_prompt_caching(llm, agent_config=cfg)
assert "cache_control_injection_points" in llm.model_kwargs
# ---------------------------------------------------------------------------
# (b) Idempotency / user override wins
# ---------------------------------------------------------------------------
def test_does_not_overwrite_user_supplied_cache_control_injection_points() -> None:
"""Users who set their own injection points (e.g. with ``ttl: "1h"``
via ``litellm_params``) keep them; the helper merges, never
clobbers."""
user_points = [
{"location": "message", "role": "system", "ttl": "1h"},
]
llm = _FakeLLM(
model_kwargs={"cache_control_injection_points": user_points},
)
apply_litellm_prompt_caching(llm)
assert llm.model_kwargs["cache_control_injection_points"] is user_points
def test_idempotent_when_called_multiple_times() -> None:
"""Build-time + thread-time double-call must be a no-op the second time."""
cfg = _make_cfg(provider="OPENAI")
llm = _FakeLLM()
apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=1)
snapshot = {
"cache_control_injection_points": list(
llm.model_kwargs["cache_control_injection_points"]
),
"prompt_cache_key": llm.model_kwargs["prompt_cache_key"],
"prompt_cache_retention": llm.model_kwargs["prompt_cache_retention"],
}
apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=1)
assert (
llm.model_kwargs["cache_control_injection_points"]
== snapshot["cache_control_injection_points"]
)
assert llm.model_kwargs["prompt_cache_key"] == snapshot["prompt_cache_key"]
assert (
llm.model_kwargs["prompt_cache_retention"] == snapshot["prompt_cache_retention"]
)
def test_does_not_overwrite_user_supplied_prompt_cache_key() -> None:
"""A pre-set ``prompt_cache_key`` (e.g. tenant-aware override via
``litellm_params``) wins over our default per-thread key."""
cfg = _make_cfg(provider="OPENAI")
llm = _FakeLLM(model_kwargs={"prompt_cache_key": "tenant-abc"})
apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=42)
assert llm.model_kwargs["prompt_cache_key"] == "tenant-abc"
# ---------------------------------------------------------------------------
# (c) OpenAI-family extras (OPENAI / DEEPSEEK / XAI)
# ---------------------------------------------------------------------------
@pytest.mark.parametrize("provider", ["OPENAI", "DEEPSEEK", "XAI"])
def test_sets_openai_family_extras(provider: str) -> None:
"""OpenAI-style providers gain ``prompt_cache_key`` (raises hit rate
via routing affinity) and ``prompt_cache_retention="24h"`` (extends
cache TTL beyond the default 5-10 min)."""
cfg = _make_cfg(provider=provider)
llm = _FakeLLM()
apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=42)
assert llm.model_kwargs["prompt_cache_key"] == "surfsense-thread-42"
assert llm.model_kwargs["prompt_cache_retention"] == "24h"
def test_skips_prompt_cache_key_when_no_thread_id() -> None:
"""Without a thread id we can't construct a per-thread key. Retention
is still useful so we set it (it's free)."""
cfg = _make_cfg(provider="OPENAI")
llm = _FakeLLM()
apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=None)
assert "prompt_cache_key" not in llm.model_kwargs
assert llm.model_kwargs["prompt_cache_retention"] == "24h"
@pytest.mark.parametrize(
"provider",
["ANTHROPIC", "BEDROCK", "VERTEX_AI", "GOOGLE_AI_STUDIO", "GROQ", "MOONSHOT"],
)
def test_no_openai_extras_for_other_providers(provider: str) -> None:
"""Non-OpenAI-family providers don't expose ``prompt_cache_key`` —
skip it. ``cache_control_injection_points`` is still set (universal)."""
cfg = _make_cfg(provider=provider)
llm = _FakeLLM()
apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=42)
assert "prompt_cache_key" not in llm.model_kwargs
assert "prompt_cache_retention" not in llm.model_kwargs
assert "cache_control_injection_points" in llm.model_kwargs
def test_no_openai_extras_in_auto_mode() -> None:
"""Auto-mode fans out across mixed providers — we can't statically
target OpenAI-only kwargs."""
cfg = AgentConfig.from_auto_mode()
llm = _FakeLLM()
apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=42)
assert "prompt_cache_key" not in llm.model_kwargs
assert "prompt_cache_retention" not in llm.model_kwargs
assert "cache_control_injection_points" in llm.model_kwargs
def test_no_openai_extras_for_custom_provider() -> None:
"""Custom providers route through arbitrary user-supplied prefixes —
we don't try to infer OpenAI-family compatibility."""
cfg = _make_cfg(provider="OPENAI", custom_provider="my_proxy")
llm = _FakeLLM()
apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=42)
assert "prompt_cache_key" not in llm.model_kwargs
assert "prompt_cache_retention" not in llm.model_kwargs
# ---------------------------------------------------------------------------
# (d) ChatLiteLLMRouter — universal injection points only
# ---------------------------------------------------------------------------
def test_router_llm_gets_only_universal_injection_points() -> None:
"""Even with an OpenAI-flavoured config, a ``ChatLiteLLMRouter`` must
receive only the universal injection points: its requests dispatch
across provider deployments, and OpenAI-only kwargs would be wasted
(or stripped by ``drop_params``) on non-OpenAI legs."""
router = ChatLiteLLMRouter()
cfg = _make_cfg(provider="OPENAI")
apply_litellm_prompt_caching(router, agent_config=cfg, thread_id=42)
assert "cache_control_injection_points" in router.model_kwargs
assert "prompt_cache_key" not in router.model_kwargs
assert "prompt_cache_retention" not in router.model_kwargs
# ---------------------------------------------------------------------------
# (e) Defensive paths
# ---------------------------------------------------------------------------
def test_handles_llm_with_no_writable_model_kwargs() -> None:
"""Some LLM implementations (e.g. fakes / minimal subclasses) don't
expose a writable ``model_kwargs``. The helper must skip silently;
raising would crash the entire LLM build path on a non-critical
optimisation."""
class _ImmutableLLM:
# ``__slots__`` blocks attribute creation, so ``setattr`` raises.
__slots__ = ("model",)
def __init__(self) -> None:
self.model = "openai/gpt-4o"
llm = _ImmutableLLM()
apply_litellm_prompt_caching(llm)
def test_initialises_missing_model_kwargs_dict() -> None:
"""When ``model_kwargs`` is present-but-None (Pydantic v2 default
pattern when no factory is set), the helper initialises it to an
empty dict before mutating."""
class _LazyLLM:
def __init__(self) -> None:
self.model = "openai/gpt-4o"
self.model_kwargs: dict[str, Any] | None = None
llm = _LazyLLM()
apply_litellm_prompt_caching(llm)
assert isinstance(llm.model_kwargs, dict)
assert "cache_control_injection_points" in llm.model_kwargs
def test_falls_back_to_llm_model_prefix_when_no_agent_config() -> None:
"""Direct caller path (e.g. ``create_chat_litellm_from_config`` for
YAML configs without a structured ``AgentConfig``): without
``agent_config``, the helper sets only the universal injection points:
no OpenAI-family extras, even if the prefix says ``openai/``.
Conservative: we'd rather miss the speedup than silently misroute."""
llm = _FakeLLM(model="openai/gpt-4o")
apply_litellm_prompt_caching(llm, agent_config=None, thread_id=99)
assert "cache_control_injection_points" in llm.model_kwargs
assert "prompt_cache_key" not in llm.model_kwargs
assert "prompt_cache_retention" not in llm.model_kwargs
# ---------------------------------------------------------------------------
# (f) drop_params safety net (regression guard for #19346)
# ---------------------------------------------------------------------------
def test_litellm_drop_params_is_globally_enabled() -> None:
"""``litellm.drop_params=True`` is set globally in
:mod:`app.services.llm_service` so any ``prompt_cache_key`` /
``prompt_cache_retention`` we set on an OpenAI-family config is
auto-stripped if the request later routes to a non-supporting
provider (e.g. via auto-mode router fallback). This test pins that
invariant; losing it would mean Bedrock/Vertex 400s on ``prompt_cache_key``.
"""
import litellm
import app.services.llm_service # noqa: F401 (side-effect: sets globals)
assert litellm.drop_params is True
# ---------------------------------------------------------------------------
# Regression note: LiteLLM #15696 (multi-content-block last message)
# ---------------------------------------------------------------------------
#
# Before LiteLLM 1.81 a list-form last message ``[block_a, block_b]``
# would get ``cache_control`` applied to *every* content block instead
# of only the last one — wasting cache breakpoints and triggering 400s
# on Anthropic when it exceeded the 4-breakpoint limit. Fixed in
# https://github.com/BerriAI/litellm/pull/15699.
#
# We pin ``litellm>=1.83.7`` in ``pyproject.toml`` (well past the fix).
# An end-to-end behavioural test would need to run ``litellm.completion``
# through the Anthropic transformer, which is integration territory and
# better covered by LiteLLM's own test suite. The unit guard here is the
# version pin plus the build-time ``model_kwargs`` shape we verify above.

@@ -0,0 +1,117 @@
"""Tests for ``_resolve_prompt_model_name`` in :mod:`app.agents.new_chat.chat_deepagent`.
The helper picks the model id fed to ``detect_provider_variant`` so the
right ``<provider_hints>`` block lands in the system prompt. The tests
below pin its preference order:
1. ``agent_config.litellm_params["base_model"]`` (Azure-correct).
2. ``agent_config.model_name``.
3. ``getattr(llm, "model", None)``.
Without (1) an Azure deployment named e.g. ``"prod-chat-001"`` would
silently miss every provider regex.
"""
from __future__ import annotations
import pytest
from app.agents.new_chat.chat_deepagent import _resolve_prompt_model_name
from app.agents.new_chat.llm_config import AgentConfig
pytestmark = pytest.mark.unit
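# Minimal sketch of the preference order listed above; an assumption of roughly
# how the real helper behaves, not its source. Only the three-step fallback and
# the blank / non-string ``base_model`` guards that the tests below exercise
# are implied here.
def _sketch_resolve_prompt_model_name(agent_config, llm) -> str | None:
    if agent_config is not None:
        params = getattr(agent_config, "litellm_params", None) or {}
        base_model = params.get("base_model")
        if isinstance(base_model, str) and base_model.strip():
            return base_model  # 1. Azure-correct underlying model family
        model_name = getattr(agent_config, "model_name", None)
        if model_name:
            return model_name  # 2. configured model name
    return getattr(llm, "model", None)  # 3. whatever the LLM object reports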
def _make_cfg(**overrides) -> AgentConfig:
"""Build an ``AgentConfig`` with sensible defaults for the helper test."""
defaults = {
"provider": "OPENAI",
"model_name": "x",
"api_key": "k",
}
return AgentConfig(**{**defaults, **overrides})
class _FakeLLM:
"""Stand-in for a ``ChatLiteLLM`` / ``ChatLiteLLMRouter`` instance.
The resolver only reads the ``.model`` attribute via ``getattr``,
matching the established idiom in ``knowledge_search.py`` /
``stream_new_chat.py`` / ``document_summarizer.py``.
"""
def __init__(self, model: str | None) -> None:
self.model = model
def test_prefers_litellm_params_base_model_over_deployment_name() -> None:
"""Azure deployment slug must NOT shadow the underlying model family.
This is the failure mode the helper exists to prevent: a deployment
named ``"azure/prod-chat-001"`` would not match any provider regex
on its own, but the family ``"gpt-4o"`` lives in
``litellm_params["base_model"]`` and routes to ``openai_classic``.
"""
cfg = _make_cfg(
model_name="azure/prod-chat-001",
litellm_params={"base_model": "gpt-4o"},
)
assert _resolve_prompt_model_name(cfg, _FakeLLM("azure/prod-chat-001")) == "gpt-4o"
def test_falls_back_to_model_name_when_litellm_params_is_none() -> None:
cfg = _make_cfg(
model_name="anthropic/claude-3-5-sonnet",
litellm_params=None,
)
got = _resolve_prompt_model_name(cfg, _FakeLLM("anthropic/claude-3-5-sonnet"))
assert got == "anthropic/claude-3-5-sonnet"
def test_handles_litellm_params_without_base_model_key() -> None:
cfg = _make_cfg(
model_name="openai/gpt-4o",
litellm_params={"temperature": 0.5},
)
assert _resolve_prompt_model_name(cfg, _FakeLLM("openai/gpt-4o")) == "openai/gpt-4o"
def test_ignores_blank_base_model() -> None:
"""Whitespace-only ``base_model`` must not shadow ``model_name``."""
cfg = _make_cfg(
model_name="openai/gpt-4o",
litellm_params={"base_model": " "},
)
assert _resolve_prompt_model_name(cfg, _FakeLLM("openai/gpt-4o")) == "openai/gpt-4o"
def test_ignores_non_string_base_model() -> None:
"""Defensive: a non-string ``base_model`` should not crash the resolver."""
cfg = _make_cfg(
model_name="openai/gpt-4o",
litellm_params={"base_model": 42},
)
assert _resolve_prompt_model_name(cfg, _FakeLLM("openai/gpt-4o")) == "openai/gpt-4o"
def test_falls_back_to_llm_model_when_no_agent_config() -> None:
"""No ``agent_config`` -> use ``llm.model`` directly. Defensive path
for direct callers; production callers always supply a config."""
assert (
_resolve_prompt_model_name(None, _FakeLLM("openai/gpt-4o-mini"))
== "openai/gpt-4o-mini"
)
def test_returns_none_when_nothing_available() -> None:
"""``compose_system_prompt`` treats ``None`` as the ``"default"``
variant and emits no provider block."""
assert _resolve_prompt_model_name(None, _FakeLLM(None)) is None
def test_auto_mode_resolves_to_auto_string() -> None:
"""Auto mode -> ``"auto"``. ``detect_provider_variant("auto")``
returns ``"default"``, which is correct: the child model isn't
known until the LiteLLM Router dispatches."""
cfg = AgentConfig.from_auto_mode()
assert _resolve_prompt_model_name(cfg, _FakeLLM("auto")) == "auto"

@@ -475,3 +475,190 @@ class TestKBSearchPlanSchema:
)
)
assert plan.is_recency_query is False
# ── mentioned_document_ids cross-turn drain ────────────────────────────
class TestKnowledgePriorityMentionDrain:
"""Regression tests for the cross-turn ``mentioned_document_ids`` drain.
The compiled-agent cache reuses a single :class:`KnowledgeBaseSearchMiddleware`
instance across turns of the same thread. ``mentioned_document_ids``
can therefore enter the middleware via two paths:
1. The constructor closure (``__init__(mentioned_document_ids=...)``)
seeded by the cache-miss build on turn 1.
2. ``runtime.context.mentioned_document_ids`` supplied freshly per
turn by the streaming task.
Without the drain fix, an empty ``runtime.context.mentioned_document_ids``
on turn 2 would fall through to the closure (because ``[]`` is falsy in
Python) and replay turn 1's mentions. This class pins down the
correct behaviour: the runtime path is authoritative even when empty,
and the closure is drained the first time the runtime path fires so
no later turn can ever resurrect stale state (see the sketch after this class).
"""
@staticmethod
def _make_runtime(mention_ids: list[int]):
"""Minimal runtime stub exposing only ``runtime.context.mentioned_document_ids``."""
from types import SimpleNamespace
return SimpleNamespace(
context=SimpleNamespace(mentioned_document_ids=mention_ids),
)
@staticmethod
def _planner_llm() -> "FakeLLM":
# Planner returns a stable, non-recency plan so we always land in
# the hybrid-search branch (where ``fetch_mentioned_documents`` is
# invoked alongside the main search).
return FakeLLM(
json.dumps(
{
"optimized_query": "follow up question",
"start_date": None,
"end_date": None,
"is_recency_query": False,
}
)
)
async def test_runtime_context_overrides_closure_and_drains_it(self, monkeypatch):
"""Turn 1 with mentions in BOTH closure and runtime context: the
runtime path wins AND the closure is drained so a future turn
cannot replay it.
"""
fetched_ids: list[list[int]] = []
async def fake_fetch_mentioned_documents(*, document_ids, search_space_id):
fetched_ids.append(list(document_ids))
return []
async def fake_search_knowledge_base(**_kwargs):
return []
monkeypatch.setattr(
"app.agents.new_chat.middleware.knowledge_search.fetch_mentioned_documents",
fake_fetch_mentioned_documents,
)
monkeypatch.setattr(
"app.agents.new_chat.middleware.knowledge_search.search_knowledge_base",
fake_search_knowledge_base,
)
middleware = KnowledgeBaseSearchMiddleware(
llm=self._planner_llm(),
search_space_id=42,
mentioned_document_ids=[1, 2, 3],
)
await middleware.abefore_agent(
{"messages": [HumanMessage(content="what is in those docs?")]},
runtime=self._make_runtime([1, 2, 3]),
)
assert fetched_ids == [[1, 2, 3]], (
"runtime.context mentions must be the source of truth on turn 1"
)
assert middleware.mentioned_document_ids == [], (
"closure must be drained the first time the runtime path fires "
"so no later turn can replay stale mentions"
)
async def test_empty_runtime_context_does_not_replay_closure_mentions(
self, monkeypatch
):
"""Regression: turn 2 with NO mentions must not surface turn 1's
mentions from the constructor closure.
Before the fix, ``if ctx_mentions:`` treated an empty list as
absent and fell through to ``elif self.mentioned_document_ids:``,
replaying turn 1's mentions. This test pins down the corrected
behaviour.
"""
fetched_ids: list[list[int]] = []
async def fake_fetch_mentioned_documents(*, document_ids, search_space_id):
fetched_ids.append(list(document_ids))
return []
async def fake_search_knowledge_base(**_kwargs):
return []
monkeypatch.setattr(
"app.agents.new_chat.middleware.knowledge_search.fetch_mentioned_documents",
fake_fetch_mentioned_documents,
)
monkeypatch.setattr(
"app.agents.new_chat.middleware.knowledge_search.search_knowledge_base",
fake_search_knowledge_base,
)
# Simulate a cached middleware instance whose closure was seeded
# by a previous turn's cache-miss build (mentions=[1,2,3]).
middleware = KnowledgeBaseSearchMiddleware(
llm=self._planner_llm(),
search_space_id=42,
mentioned_document_ids=[1, 2, 3],
)
# Turn 2: streaming task supplies an EMPTY mention list (no
# mentions on this follow-up turn).
await middleware.abefore_agent(
{"messages": [HumanMessage(content="what about the next steps?")]},
runtime=self._make_runtime([]),
)
assert fetched_ids == [], (
"fetch_mentioned_documents must NOT be called when the runtime "
"context says there are no mentions for this turn"
)
async def test_legacy_path_fires_only_when_runtime_context_absent(
self, monkeypatch
):
"""Backward-compat: if a caller doesn't supply runtime.context (old
non-streaming code path), the closure-injected mentions are still
honoured exactly once and then drained.
"""
fetched_ids: list[list[int]] = []
async def fake_fetch_mentioned_documents(*, document_ids, search_space_id):
fetched_ids.append(list(document_ids))
return []
async def fake_search_knowledge_base(**_kwargs):
return []
monkeypatch.setattr(
"app.agents.new_chat.middleware.knowledge_search.fetch_mentioned_documents",
fake_fetch_mentioned_documents,
)
monkeypatch.setattr(
"app.agents.new_chat.middleware.knowledge_search.search_knowledge_base",
fake_search_knowledge_base,
)
middleware = KnowledgeBaseSearchMiddleware(
llm=self._planner_llm(),
search_space_id=42,
mentioned_document_ids=[7, 8],
)
# First call: no runtime → legacy path uses the closure.
await middleware.abefore_agent(
{"messages": [HumanMessage(content="initial question")]},
runtime=None,
)
# Second call: still no runtime — closure already drained, so no replay.
await middleware.abefore_agent(
{"messages": [HumanMessage(content="follow up")]},
runtime=None,
)
assert fetched_ids == [[7, 8]], (
"legacy path must honour the closure exactly once and then drain it"
)
assert middleware.mentioned_document_ids == []
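# Minimal sketch of the drain rule the class above pins down; an assumption of
# the intended shape, not the middleware's real code. The runtime-context list
# is authoritative whenever it is present (even when empty), and the closure
# seeded at construction time is consumed at most once.
def _sketch_resolve_mentions_for_turn(middleware, runtime) -> list[int]:
    ctx = getattr(getattr(runtime, "context", None), "mentioned_document_ids", None)
    if ctx is not None:
        middleware.mentioned_document_ids = []  # drain the closure permanently
        return list(ctx)  # an empty list means "no mentions this turn", not "unknown"
    drained = middleware.mentioned_document_ids
    middleware.mentioned_document_ids = []
    return drained  # legacy path: closure honoured exactly once, then empty forever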

@@ -0,0 +1,110 @@
"""Unit tests for ``supports_image_input`` derivation on BYOK chat config
endpoints (``GET /new-llm-configs`` list, ``GET /new-llm-configs/{id}``).
There is no DB column for ``supports_image_input`` on
``NewLLMConfig``; the value is resolved at the API boundary by
``derive_supports_image_input`` so the new-chat selector / streaming
task can read the same field shape regardless of source (BYOK vs YAML
vs OpenRouter dynamic). Default-allow on unknown so we don't lock the
user out of their own model choice.
"""
from __future__ import annotations
from datetime import UTC, datetime
from types import SimpleNamespace
from uuid import uuid4
import pytest
from app.db import LiteLLMProvider
from app.routes import new_llm_config_routes
pytestmark = pytest.mark.unit
def _byok_row(
*,
id_: int,
model_name: str,
base_model: str | None = None,
provider: LiteLLMProvider = LiteLLMProvider.OPENAI,
custom_provider: str | None = None,
) -> object:
"""Mimic the SQLAlchemy row's attribute surface; ``model_validate``
walks ``from_attributes=True`` so a ``SimpleNamespace`` is enough.
``provider`` is a real ``LiteLLMProvider`` enum value so Pydantic's
enum validator accepts it, the same as the real ORM row would carry."""
return SimpleNamespace(
id=id_,
name=f"BYOK-{id_}",
description=None,
provider=provider,
custom_provider=custom_provider,
model_name=model_name,
api_key="sk-byok",
api_base=None,
litellm_params={"base_model": base_model} if base_model else None,
system_instructions="",
use_default_system_instructions=True,
citations_enabled=True,
created_at=datetime.now(tz=UTC),
search_space_id=42,
user_id=uuid4(),
)
def test_serialize_byok_known_vision_model_resolves_true():
"""The catalog resolver consults LiteLLM's map for ``gpt-4o`` ->
True. The serialized row carries that value through to the
``NewLLMConfigRead`` schema."""
row = _byok_row(id_=1, model_name="gpt-4o")
serialized = new_llm_config_routes._serialize_byok_config(row)
assert serialized.supports_image_input is True
assert serialized.id == 1
assert serialized.model_name == "gpt-4o"
def test_serialize_byok_unknown_model_default_allows():
"""Unknown / unmapped: default-allow. The streaming-task safety net
is the actual block, and it requires LiteLLM to *explicitly* say
text-only so a brand new BYOK model should not be pre-judged."""
row = _byok_row(
id_=2,
model_name="brand-new-model-x9-unmapped",
provider=LiteLLMProvider.CUSTOM,
custom_provider="brand_new_proxy",
)
serialized = new_llm_config_routes._serialize_byok_config(row)
assert serialized.supports_image_input is True
def test_serialize_byok_uses_base_model_when_present():
"""Azure-style: ``model_name`` is the deployment id, ``base_model``
inside ``litellm_params`` is the canonical sku LiteLLM knows. The
helper must consult ``base_model`` first or unrecognised deployment
ids would shadow the real capability."""
row = _byok_row(
id_=3,
model_name="my-azure-deployment-id-no-litellm-knows-this",
base_model="gpt-4o",
provider=LiteLLMProvider.AZURE_OPENAI,
)
serialized = new_llm_config_routes._serialize_byok_config(row)
assert serialized.supports_image_input is True
def test_serialize_byok_returns_pydantic_read_model():
"""The route now returns ``NewLLMConfigRead`` (not the raw ORM) so
the schema additions are guaranteed to be present in the API
surface. This guards against a future regression where someone
deletes the augmentation step and falls back to ORM passthrough."""
from app.schemas import NewLLMConfigRead
row = _byok_row(id_=4, model_name="gpt-4o")
serialized = new_llm_config_routes._serialize_byok_config(row)
assert isinstance(serialized, NewLLMConfigRead)

@@ -0,0 +1,184 @@
"""Unit tests for ``is_premium`` derivation on the global image-gen and
vision-LLM list endpoints.
Chat globals (``GET /global-llm-configs``) already emit
``is_premium = (billing_tier == "premium")``. Image and vision did not,
which made the new-chat ``model-selector`` render the Free/Premium badge
on the Chat tab but skip it on the Image and Vision tabs (the selector
keys its badge logic off ``is_premium``). These tests pin parity:
* YAML free entry ``is_premium=False``
* YAML premium entry ``is_premium=True``
* OpenRouter dynamic premium entry ``is_premium=True``
* Auto stub (always emitted when at least one config is present)
``is_premium=False``
"""
from __future__ import annotations
import pytest
pytestmark = pytest.mark.unit
_IMAGE_FIXTURE: list[dict] = [
{
"id": -1,
"name": "DALL-E 3",
"provider": "OPENAI",
"model_name": "dall-e-3",
"api_key": "sk-test",
"billing_tier": "free",
},
{
"id": -2,
"name": "GPT-Image 1 (premium)",
"provider": "OPENAI",
"model_name": "gpt-image-1",
"api_key": "sk-test",
"billing_tier": "premium",
},
{
"id": -20_001,
"name": "google/gemini-2.5-flash-image (OpenRouter)",
"provider": "OPENROUTER",
"model_name": "google/gemini-2.5-flash-image",
"api_key": "sk-or-test",
"api_base": "https://openrouter.ai/api/v1",
"billing_tier": "premium",
},
]
_VISION_FIXTURE: list[dict] = [
{
"id": -1,
"name": "GPT-4o Vision",
"provider": "OPENAI",
"model_name": "gpt-4o",
"api_key": "sk-test",
"billing_tier": "free",
},
{
"id": -2,
"name": "Claude 3.5 Sonnet (premium)",
"provider": "ANTHROPIC",
"model_name": "claude-3-5-sonnet",
"api_key": "sk-ant-test",
"billing_tier": "premium",
},
{
"id": -30_001,
"name": "openai/gpt-4o (OpenRouter)",
"provider": "OPENROUTER",
"model_name": "openai/gpt-4o",
"api_key": "sk-or-test",
"api_base": "https://openrouter.ai/api/v1",
"billing_tier": "premium",
},
]
# =============================================================================
# Image generation
# =============================================================================
@pytest.mark.asyncio
async def test_global_image_gen_configs_emit_is_premium(monkeypatch):
"""Each emitted config must carry ``is_premium`` derived server-side
from ``billing_tier``. The Auto stub is always free.
"""
from app.config import config
from app.routes import image_generation_routes
monkeypatch.setattr(
config, "GLOBAL_IMAGE_GEN_CONFIGS", _IMAGE_FIXTURE, raising=False
)
payload = await image_generation_routes.get_global_image_gen_configs(user=None)
by_id = {c["id"]: c for c in payload}
# Auto stub is always emitted when at least one global config exists,
# and it must always declare itself free (Auto-mode billing-tier
# surfacing is a separate follow-up).
assert 0 in by_id, "Auto stub should be emitted when at least one config exists"
assert by_id[0]["is_premium"] is False
assert by_id[0]["billing_tier"] == "free"
# YAML free entry — ``is_premium=False``
assert by_id[-1]["is_premium"] is False
assert by_id[-1]["billing_tier"] == "free"
# YAML premium entry — ``is_premium=True``
assert by_id[-2]["is_premium"] is True
assert by_id[-2]["billing_tier"] == "premium"
# OpenRouter dynamic premium entry — same field, same derivation
assert by_id[-20_001]["is_premium"] is True
assert by_id[-20_001]["billing_tier"] == "premium"
# Every emitted dict (including Auto) must have the field — never missing.
for cfg in payload:
assert "is_premium" in cfg, f"is_premium missing from {cfg.get('id')}"
assert isinstance(cfg["is_premium"], bool)
@pytest.mark.asyncio
async def test_global_image_gen_configs_no_globals_no_auto_stub(monkeypatch):
"""When there are no global configs at all, the endpoint emits an
    empty list (no Auto stub); Auto mode would have nothing to route to.
"""
from app.config import config
from app.routes import image_generation_routes
monkeypatch.setattr(config, "GLOBAL_IMAGE_GEN_CONFIGS", [], raising=False)
payload = await image_generation_routes.get_global_image_gen_configs(user=None)
assert payload == []
# =============================================================================
# Vision LLM
# =============================================================================
@pytest.mark.asyncio
async def test_global_vision_llm_configs_emit_is_premium(monkeypatch):
from app.config import config
from app.routes import vision_llm_routes
monkeypatch.setattr(
config, "GLOBAL_VISION_LLM_CONFIGS", _VISION_FIXTURE, raising=False
)
payload = await vision_llm_routes.get_global_vision_llm_configs(user=None)
by_id = {c["id"]: c for c in payload}
assert 0 in by_id, "Auto stub should be emitted when at least one config exists"
assert by_id[0]["is_premium"] is False
assert by_id[0]["billing_tier"] == "free"
assert by_id[-1]["is_premium"] is False
assert by_id[-1]["billing_tier"] == "free"
assert by_id[-2]["is_premium"] is True
assert by_id[-2]["billing_tier"] == "premium"
assert by_id[-30_001]["is_premium"] is True
assert by_id[-30_001]["billing_tier"] == "premium"
for cfg in payload:
assert "is_premium" in cfg, f"is_premium missing from {cfg.get('id')}"
assert isinstance(cfg["is_premium"], bool)
@pytest.mark.asyncio
async def test_global_vision_llm_configs_no_globals_no_auto_stub(monkeypatch):
from app.config import config
from app.routes import vision_llm_routes
monkeypatch.setattr(config, "GLOBAL_VISION_LLM_CONFIGS", [], raising=False)
payload = await vision_llm_routes.get_global_vision_llm_configs(user=None)
assert payload == []

View file

@ -0,0 +1,106 @@
"""Unit tests for ``supports_image_input`` derivation on the chat global
config endpoint (``GET /global-new-llm-configs``).
Resolution order (matches ``new_llm_config_routes.get_global_new_llm_configs``):
1. Explicit ``supports_image_input`` on the cfg dict (set by the YAML
loader for operator overrides, or by the OpenRouter integration from
``architecture.input_modalities``) wins.
2. Otherwise the ``derive_supports_image_input`` helper decides: default-allow
   on unknown models, only False when LiteLLM / OR modalities are definitive.
The flag is purely informational at the API boundary. The streaming
task safety net (``is_known_text_only_chat_model``) is the actual block,
and it requires LiteLLM to *explicitly* mark the model as text-only.
"""
from __future__ import annotations
import pytest
pytestmark = pytest.mark.unit
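# Minimal sketch of the resolution order the module docstring describes. The
# derivation helper is passed in as a callable so the sketch stays
# self-contained; the real helper lives in the routes layer and may differ.
def _expected_supports_image_input(cfg: dict, derive) -> bool:
    # Explicit flag on the cfg wins; otherwise the derivation helper decides.
    if "supports_image_input" in cfg:
        return bool(cfg["supports_image_input"])
    return bool(derive(cfg.get("model_name", "")))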
_FIXTURE: list[dict] = [
{
"id": -1,
"name": "GPT-4o (explicit true)",
"description": "vision-capable, explicit YAML override",
"provider": "OPENAI",
"model_name": "gpt-4o",
"api_key": "sk-test",
"billing_tier": "free",
"supports_image_input": True,
},
{
"id": -2,
"name": "DeepSeek V3 (explicit false)",
"description": "OpenRouter dynamic — modality-derived false",
"provider": "OPENROUTER",
"model_name": "deepseek/deepseek-v3.2-exp",
"api_key": "sk-or-test",
"api_base": "https://openrouter.ai/api/v1",
"billing_tier": "free",
"supports_image_input": False,
},
{
"id": -10_010,
"name": "Unannotated GPT-4o",
"description": "no flag set — resolver should derive True via LiteLLM",
"provider": "OPENAI",
"model_name": "gpt-4o",
"api_key": "sk-test",
"billing_tier": "free",
# supports_image_input intentionally absent
},
{
"id": -10_011,
"name": "Unannotated unknown model",
"description": "unmapped — default-allow True",
"provider": "CUSTOM",
"custom_provider": "brand_new_proxy",
"model_name": "brand-new-model-x9",
"api_key": "sk-test",
"billing_tier": "free",
},
]
@pytest.mark.asyncio
async def test_global_new_llm_configs_emit_supports_image_input(monkeypatch):
"""Each emitted chat config carries ``supports_image_input`` as a
bool. Explicit values win; unannotated entries are resolved via the
helper (default-allow True)."""
from app.config import config
from app.routes import new_llm_config_routes
monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", _FIXTURE, raising=False)
payload = await new_llm_config_routes.get_global_new_llm_configs(user=None)
by_id = {c["id"]: c for c in payload}
    # Auto stub: optimistic True so the user can keep Auto selected as long
    # as vision-capable deployments exist somewhere in the pool.
assert 0 in by_id, "Auto stub should be emitted when configs exist"
assert by_id[0]["supports_image_input"] is True
assert by_id[0]["is_auto_mode"] is True
# Explicit True is preserved.
assert by_id[-1]["supports_image_input"] is True
# Explicit False is preserved (the exact failure mode the safety net
# guards against — DeepSeek V3 over OpenRouter would 404 with "No
# endpoints found that support image input").
assert by_id[-2]["supports_image_input"] is False
# Unannotated GPT-4o: resolver consults LiteLLM, which says vision.
assert by_id[-10_010]["supports_image_input"] is True
# Unknown / unmapped model: default-allow rather than pre-judge.
assert by_id[-10_011]["supports_image_input"] is True
for cfg in payload:
assert "supports_image_input" in cfg, (
f"supports_image_input missing from {cfg.get('id')}"
)
assert isinstance(cfg["supports_image_input"], bool)

View file

@ -0,0 +1,138 @@
"""Unit tests for the image-generation route's billing-resolution helper.
End-to-end "POST /image-generations returns 402" coverage requires the
integration harness (real DB, real auth) and lives in
``tests/integration/document_upload/`` alongside the other quota tests.
This unit test focuses on the new ``_resolve_billing_for_image_gen``
helper which:
* Returns ``free`` for Auto mode, even when premium configs exist
(Auto-mode billing-tier surfacing is a follow-up).
* Returns ``free`` for user-owned BYOK configs (positive IDs).
* Returns the global config's ``billing_tier`` for negative IDs.
* Honours the per-config ``quota_reserve_micros`` override when present.
"""
from __future__ import annotations
from types import SimpleNamespace
import pytest
pytestmark = pytest.mark.unit
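# Hedged sketch of the tier ladder under test (helper and field names here are
# illustrative, not the route's actual identifiers): Auto mode and user BYOK
# are always free; negative ids take the global config's billing_tier.
def _expected_tier(config_id: int, global_cfg: dict | None) -> str:
    if config_id >= 0:  # 0 = Auto mode, positive = user-owned BYOK
        return "free"
    return (global_cfg or {}).get("billing_tier", "free")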
@pytest.mark.asyncio
async def test_resolve_billing_for_auto_mode(monkeypatch):
from app.routes import image_generation_routes
from app.services.billable_calls import DEFAULT_IMAGE_RESERVE_MICROS
search_space = SimpleNamespace(image_generation_config_id=None)
tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen(
session=None, # Not consumed on this code path.
config_id=0, # IMAGE_GEN_AUTO_MODE_ID
search_space=search_space,
)
assert tier == "free"
assert model == "auto"
assert reserve == DEFAULT_IMAGE_RESERVE_MICROS
@pytest.mark.asyncio
async def test_resolve_billing_for_premium_global_config(monkeypatch):
from app.config import config
from app.routes import image_generation_routes
monkeypatch.setattr(
config,
"GLOBAL_IMAGE_GEN_CONFIGS",
[
{
"id": -1,
"provider": "OPENAI",
"model_name": "gpt-image-1",
"billing_tier": "premium",
"quota_reserve_micros": 75_000,
},
{
"id": -2,
"provider": "OPENROUTER",
"model_name": "google/gemini-2.5-flash-image",
"billing_tier": "free",
},
],
raising=False,
)
search_space = SimpleNamespace(image_generation_config_id=None)
# Premium with override.
tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen(
session=None, config_id=-1, search_space=search_space
)
assert tier == "premium"
assert model == "openai/gpt-image-1"
assert reserve == 75_000
# Free, no override → falls back to default.
from app.services.billable_calls import DEFAULT_IMAGE_RESERVE_MICROS
tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen(
session=None, config_id=-2, search_space=search_space
)
assert tier == "free"
# Provider-prefixed model string for OpenRouter.
assert "google/gemini-2.5-flash-image" in model
assert reserve == DEFAULT_IMAGE_RESERVE_MICROS
@pytest.mark.asyncio
async def test_resolve_billing_for_user_owned_byok_is_free():
"""User-owned BYOK configs (positive IDs) cost the user nothing on
    our side; they pay the provider directly. Always free.
"""
from app.routes import image_generation_routes
from app.services.billable_calls import DEFAULT_IMAGE_RESERVE_MICROS
search_space = SimpleNamespace(image_generation_config_id=None)
tier, model, reserve = await image_generation_routes._resolve_billing_for_image_gen(
session=None, config_id=42, search_space=search_space
)
assert tier == "free"
assert model == "user_byok"
assert reserve == DEFAULT_IMAGE_RESERVE_MICROS
@pytest.mark.asyncio
async def test_resolve_billing_falls_back_to_search_space_default(monkeypatch):
"""When the request omits ``image_generation_config_id``, the helper
must consult the search space's default — so a search space pinned
to a premium global config still gates new requests by quota.
"""
from app.config import config
from app.routes import image_generation_routes
monkeypatch.setattr(
config,
"GLOBAL_IMAGE_GEN_CONFIGS",
[
{
"id": -7,
"provider": "OPENAI",
"model_name": "gpt-image-1",
"billing_tier": "premium",
}
],
raising=False,
)
search_space = SimpleNamespace(image_generation_config_id=-7)
(
tier,
model,
_reserve,
) = await image_generation_routes._resolve_billing_for_image_gen(
session=None, config_id=None, search_space=search_space
)
assert tier == "premium"
assert model == "openai/gpt-image-1"

View file

@ -0,0 +1,436 @@
"""Unit tests for ``_resolve_agent_billing_for_search_space``.
Validates the resolver used by Celery podcast/video tasks to compute
``(owner_user_id, billing_tier, base_model)`` from a search space and its
agent LLM config. The resolver mirrors chat's billing-resolution pattern at
``stream_new_chat.py:2294-2351`` and is the single integration point that
prevents Auto-mode podcast/video from leaking premium credit.
Coverage:
* Auto mode + ``thread_id`` set, pin resolves to a negative-id premium
  global → returns ``("premium", <base_model>)``.
* Auto mode + ``thread_id`` set, pin resolves to a negative-id free
  global → returns ``("free", <base_model>)``.
* Auto mode + ``thread_id`` set, pin resolves to a positive-id BYOK config
  → always ``"free"``.
* Auto mode + ``thread_id=None`` → fallback to ``("free", "auto")`` without
  hitting the pin service.
* Negative id (no Auto) → uses ``get_global_llm_config``'s
  ``billing_tier``.
* Positive id (user BYOK) → always ``"free"``.
* Search space not found → raises ``ValueError``.
* ``agent_llm_id`` is None → raises ``ValueError``.
"""
from __future__ import annotations
from dataclasses import dataclass
from types import SimpleNamespace
from uuid import UUID, uuid4
import pytest
pytestmark = pytest.mark.unit
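# Hedged sketch of how the Auto-mode branch classifies a resolved pin. The
# real resolver goes through the pin service and llm_service; this only
# illustrates the tier decision the tests below assert on.
def _classify_resolved_pin(pinned_id: int, global_cfg: dict | None) -> str:
    if pinned_id > 0:  # positive id: user BYOK pin, always free
        return "free"
    return (global_cfg or {}).get("billing_tier", "free")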
# ---------------------------------------------------------------------------
# Fakes
# ---------------------------------------------------------------------------
class _FakeExecResult:
def __init__(self, obj):
self._obj = obj
def scalars(self):
return self
def first(self):
return self._obj
class _FakeSession:
"""Tiny AsyncSession stub.
``responses`` is a list of objects to return from successive
``execute()`` calls (in order). The resolver makes at most two
``execute()`` calls (search-space lookup, then optionally NewLLMConfig
lookup), so two queued responses cover the matrix.
"""
def __init__(self, responses: list):
self._responses = list(responses)
async def execute(self, _stmt):
if not self._responses:
return _FakeExecResult(None)
return _FakeExecResult(self._responses.pop(0))
async def commit(self) -> None:
pass
@dataclass
class _FakePinResolution:
resolved_llm_config_id: int
resolved_tier: str = "premium"
from_existing_pin: bool = False
def _make_search_space(*, agent_llm_id: int | None, user_id: UUID) -> SimpleNamespace:
return SimpleNamespace(
id=42,
agent_llm_id=agent_llm_id,
user_id=user_id,
)
def _make_byok_config(
*, id_: int, base_model: str | None = None, model_name: str = "gpt-byok"
) -> SimpleNamespace:
return SimpleNamespace(
id=id_,
model_name=model_name,
litellm_params={"base_model": base_model} if base_model else {},
)
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_auto_mode_with_thread_id_resolves_to_premium_global(monkeypatch):
"""Auto + thread → pin service resolves to negative-id premium config →
resolver returns ``("premium", <base_model>)``."""
from app.services.billable_calls import _resolve_agent_billing_for_search_space
user_id = uuid4()
session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)])
# Mock the pin service to return a concrete premium config id.
async def _fake_resolve_pin(
sess,
*,
thread_id,
search_space_id,
user_id,
selected_llm_config_id,
force_repin_free=False,
):
assert selected_llm_config_id == 0
assert thread_id == 99
return _FakePinResolution(resolved_llm_config_id=-1, resolved_tier="premium")
# Mock global config lookup to return a premium entry.
def _fake_get_global(cfg_id):
if cfg_id == -1:
return {
"id": -1,
"model_name": "gpt-5.4",
"billing_tier": "premium",
"litellm_params": {"base_model": "gpt-5.4"},
}
return None
# Lazy imports inside the resolver — patch the *target* modules so the
# imported names resolve to our fakes.
import app.services.auto_model_pin_service as pin_module
import app.services.llm_service as llm_module
monkeypatch.setattr(
pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin
)
monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global)
owner, tier, base_model = await _resolve_agent_billing_for_search_space(
session, search_space_id=42, thread_id=99
)
assert owner == user_id
assert tier == "premium"
assert base_model == "gpt-5.4"
@pytest.mark.asyncio
async def test_auto_mode_with_thread_id_resolves_to_free_global(monkeypatch):
"""Auto + thread → pin returns negative-id free config → resolver
returns ``("free", <base_model>)``. Same path the pin service takes for
out-of-credit users (graceful degradation)."""
from app.services.billable_calls import _resolve_agent_billing_for_search_space
user_id = uuid4()
session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)])
async def _fake_resolve_pin(
sess,
*,
thread_id,
search_space_id,
user_id,
selected_llm_config_id,
force_repin_free=False,
):
return _FakePinResolution(resolved_llm_config_id=-3, resolved_tier="free")
def _fake_get_global(cfg_id):
if cfg_id == -3:
return {
"id": -3,
"model_name": "openrouter/free-model",
"billing_tier": "free",
"litellm_params": {"base_model": "openrouter/free-model"},
}
return None
import app.services.auto_model_pin_service as pin_module
import app.services.llm_service as llm_module
monkeypatch.setattr(
pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin
)
monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global)
owner, tier, base_model = await _resolve_agent_billing_for_search_space(
session, search_space_id=42, thread_id=99
)
assert owner == user_id
assert tier == "free"
assert base_model == "openrouter/free-model"
@pytest.mark.asyncio
async def test_auto_mode_with_thread_id_resolves_to_byok_is_free(monkeypatch):
"""Auto + thread → pin returns positive-id BYOK config → resolver
returns ``("free", ...)`` (BYOK is always free per
``AgentConfig.from_new_llm_config``)."""
from app.services.billable_calls import _resolve_agent_billing_for_search_space
user_id = uuid4()
search_space = _make_search_space(agent_llm_id=0, user_id=user_id)
byok_cfg = _make_byok_config(
id_=17, base_model="anthropic/claude-3-haiku", model_name="my-claude"
)
session = _FakeSession([search_space, byok_cfg])
async def _fake_resolve_pin(
sess,
*,
thread_id,
search_space_id,
user_id,
selected_llm_config_id,
force_repin_free=False,
):
return _FakePinResolution(resolved_llm_config_id=17, resolved_tier="free")
import app.services.auto_model_pin_service as pin_module
monkeypatch.setattr(
pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin
)
owner, tier, base_model = await _resolve_agent_billing_for_search_space(
session, search_space_id=42, thread_id=99
)
assert owner == user_id
assert tier == "free"
assert base_model == "anthropic/claude-3-haiku"
@pytest.mark.asyncio
async def test_auto_mode_without_thread_id_falls_back_to_free():
"""Auto + ``thread_id=None`` → ``("free", "auto")`` without invoking
the pin service. Forward-compat fallback for any future direct-API
entrypoint that doesn't have a chat thread."""
from app.services.billable_calls import _resolve_agent_billing_for_search_space
user_id = uuid4()
session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)])
owner, tier, base_model = await _resolve_agent_billing_for_search_space(
session, search_space_id=42, thread_id=None
)
assert owner == user_id
assert tier == "free"
assert base_model == "auto"
@pytest.mark.asyncio
async def test_auto_mode_pin_failure_falls_back_to_free(monkeypatch):
"""If the pin service raises ``ValueError`` (thread missing /
mismatched search space), the resolver should log and return free
rather than killing the whole task."""
from app.services.billable_calls import _resolve_agent_billing_for_search_space
user_id = uuid4()
session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)])
async def _fake_resolve_pin(*args, **kwargs):
raise ValueError("thread missing")
import app.services.auto_model_pin_service as pin_module
monkeypatch.setattr(
pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin
)
owner, tier, base_model = await _resolve_agent_billing_for_search_space(
session, search_space_id=42, thread_id=99
)
assert owner == user_id
assert tier == "free"
assert base_model == "auto"
@pytest.mark.asyncio
async def test_negative_id_premium_global_returns_premium(monkeypatch):
"""Explicit negative agent_llm_id → ``get_global_llm_config`` →
return its ``billing_tier``."""
from app.services.billable_calls import _resolve_agent_billing_for_search_space
user_id = uuid4()
session = _FakeSession([_make_search_space(agent_llm_id=-1, user_id=user_id)])
def _fake_get_global(cfg_id):
return {
"id": cfg_id,
"model_name": "gpt-5.4",
"billing_tier": "premium",
"litellm_params": {"base_model": "gpt-5.4"},
}
import app.services.llm_service as llm_module
monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global)
owner, tier, base_model = await _resolve_agent_billing_for_search_space(
session, search_space_id=42, thread_id=99
)
assert owner == user_id
assert tier == "premium"
assert base_model == "gpt-5.4"
@pytest.mark.asyncio
async def test_negative_id_free_global_returns_free(monkeypatch):
from app.services.billable_calls import _resolve_agent_billing_for_search_space
user_id = uuid4()
session = _FakeSession([_make_search_space(agent_llm_id=-2, user_id=user_id)])
def _fake_get_global(cfg_id):
return {
"id": cfg_id,
"model_name": "openrouter/some-free",
"billing_tier": "free",
"litellm_params": {"base_model": "openrouter/some-free"},
}
import app.services.llm_service as llm_module
monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global)
owner, tier, base_model = await _resolve_agent_billing_for_search_space(
session, search_space_id=42, thread_id=None
)
assert owner == user_id
assert tier == "free"
assert base_model == "openrouter/some-free"
@pytest.mark.asyncio
async def test_negative_id_missing_base_model_falls_back_to_model_name(monkeypatch):
"""When the global config has no ``litellm_params.base_model``, the
    resolver falls back to ``model_name``, matching chat's behavior."""
from app.services.billable_calls import _resolve_agent_billing_for_search_space
user_id = uuid4()
session = _FakeSession([_make_search_space(agent_llm_id=-5, user_id=user_id)])
def _fake_get_global(cfg_id):
return {
"id": cfg_id,
"model_name": "fallback-model",
"billing_tier": "premium",
# No litellm_params.
}
import app.services.llm_service as llm_module
monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global)
_, tier, base_model = await _resolve_agent_billing_for_search_space(
session, search_space_id=42
)
assert tier == "premium"
assert base_model == "fallback-model"
@pytest.mark.asyncio
async def test_positive_id_byok_is_always_free():
"""Positive agent_llm_id → user-owned BYOK NewLLMConfig → always free,
regardless of underlying provider tier."""
from app.services.billable_calls import _resolve_agent_billing_for_search_space
user_id = uuid4()
search_space = _make_search_space(agent_llm_id=23, user_id=user_id)
byok_cfg = _make_byok_config(id_=23, base_model="anthropic/claude-3.5-sonnet")
session = _FakeSession([search_space, byok_cfg])
owner, tier, base_model = await _resolve_agent_billing_for_search_space(
session, search_space_id=42
)
assert owner == user_id
assert tier == "free"
assert base_model == "anthropic/claude-3.5-sonnet"
@pytest.mark.asyncio
async def test_positive_id_byok_missing_returns_free_with_empty_base_model():
"""If the BYOK config row is missing/deleted but the search space still
points at it, the resolver still returns free (no debit) with an empty
    base_model; billable_call's premium path is skipped, no harm done."""
from app.services.billable_calls import _resolve_agent_billing_for_search_space
user_id = uuid4()
session = _FakeSession([_make_search_space(agent_llm_id=99, user_id=user_id)])
owner, tier, base_model = await _resolve_agent_billing_for_search_space(
session, search_space_id=42
)
assert owner == user_id
assert tier == "free"
assert base_model == ""
@pytest.mark.asyncio
async def test_search_space_not_found_raises_value_error():
from app.services.billable_calls import _resolve_agent_billing_for_search_space
session = _FakeSession([None])
with pytest.raises(ValueError, match="Search space"):
await _resolve_agent_billing_for_search_space(session, search_space_id=999)
@pytest.mark.asyncio
async def test_agent_llm_id_none_raises_value_error():
from app.services.billable_calls import _resolve_agent_billing_for_search_space
user_id = uuid4()
session = _FakeSession([_make_search_space(agent_llm_id=None, user_id=user_id)])
with pytest.raises(ValueError, match="agent_llm_id"):
await _resolve_agent_billing_for_search_space(session, search_space_id=42)

File diff suppressed because it is too large

View file

@ -0,0 +1,286 @@
"""Image-aware extension of the Auto-pin resolver.
When the current chat turn carries an ``image_url`` block, the pin
resolver must:
1. Filter the candidate pool to vision-capable cfgs so a freshly
selected pin can never be text-only.
2. Treat any existing pin whose capability is False as invalid (force
re-pin), even when it would otherwise be reused as the thread's
stable model.
3. Raise ``ValueError`` (mapped to the friendly
``MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT`` SSE error in the streaming
task) when no vision-capable cfg is available instead of silently
pinning text-only and 404-ing at the provider.
"""
from __future__ import annotations
from dataclasses import dataclass
from types import SimpleNamespace
import pytest
from app.services.auto_model_pin_service import (
clear_healthy,
clear_runtime_cooldown,
resolve_or_get_pinned_llm_config_id,
)
pytestmark = pytest.mark.unit
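# Minimal sketch of the image-aware candidate filtering the docstring
# describes. The pool shape and error text are assumptions grounded in the
# assertions below; the real resolver also handles pins, cooldowns and quota.
def _filter_for_image_turn(cfgs: list[dict], requires_image_input: bool) -> list[dict]:
    if not requires_image_input:
        return list(cfgs)
    pool = [c for c in cfgs if c.get("supports_image_input", True)]
    if not pool:
        raise ValueError("No vision-capable model is available for this turn")
    return pool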
@pytest.fixture(autouse=True)
def _reset_caches():
clear_runtime_cooldown()
clear_healthy()
yield
clear_runtime_cooldown()
clear_healthy()
@dataclass
class _FakeQuotaResult:
allowed: bool
class _FakeExecResult:
def __init__(self, thread):
self._thread = thread
def unique(self):
return self
def scalar_one_or_none(self):
return self._thread
class _FakeSession:
def __init__(self, thread):
self.thread = thread
self.commit_count = 0
async def execute(self, _stmt):
return _FakeExecResult(self.thread)
async def commit(self):
self.commit_count += 1
def _thread(*, pinned: int | None = None):
return SimpleNamespace(id=1, search_space_id=10, pinned_llm_config_id=pinned)
def _vision_cfg(id_: int, *, tier: str = "free", quality: int = 80) -> dict:
return {
"id": id_,
"provider": "OPENAI",
"model_name": f"vision-{id_}",
"api_key": "k",
"billing_tier": tier,
"supports_image_input": True,
"auto_pin_tier": "A",
"quality_score": quality,
}
def _text_only_cfg(id_: int, *, tier: str = "free", quality: int = 90) -> dict:
return {
"id": id_,
"provider": "OPENAI",
"model_name": f"text-{id_}",
"api_key": "k",
"billing_tier": tier,
# Higher quality than the vision cfgs — so a bug that ignores
# the image flag would surface as the resolver picking this one.
"supports_image_input": False,
"auto_pin_tier": "A",
"quality_score": quality,
}
async def _premium_allowed(*_args, **_kwargs):
return _FakeQuotaResult(allowed=True)
@pytest.mark.asyncio
async def test_image_turn_filters_out_text_only_candidates(monkeypatch):
from app.config import config
session = _FakeSession(_thread())
monkeypatch.setattr(
config,
"GLOBAL_LLM_CONFIGS",
[_text_only_cfg(-1), _vision_cfg(-2)],
)
monkeypatch.setattr(
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
_premium_allowed,
)
result = await resolve_or_get_pinned_llm_config_id(
session,
thread_id=1,
search_space_id=10,
user_id=None,
selected_llm_config_id=0,
requires_image_input=True,
)
assert result.resolved_llm_config_id == -2
# The thread should be pinned to the vision cfg even though the
# text-only cfg has a higher quality score.
assert session.thread.pinned_llm_config_id == -2
@pytest.mark.asyncio
async def test_image_turn_force_repins_stale_text_only_pin(monkeypatch):
"""An existing text-only pin must be invalidated when the next turn
requires image input. The non-image path would happily reuse it."""
from app.config import config
session = _FakeSession(_thread(pinned=-1))
monkeypatch.setattr(
config,
"GLOBAL_LLM_CONFIGS",
[_text_only_cfg(-1), _vision_cfg(-2)],
)
monkeypatch.setattr(
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
_premium_allowed,
)
result = await resolve_or_get_pinned_llm_config_id(
session,
thread_id=1,
search_space_id=10,
user_id=None,
selected_llm_config_id=0,
requires_image_input=True,
)
assert result.resolved_llm_config_id == -2
assert result.from_existing_pin is False
assert session.thread.pinned_llm_config_id == -2
@pytest.mark.asyncio
async def test_image_turn_reuses_existing_vision_pin(monkeypatch):
"""If the thread is already pinned to a vision-capable cfg, reuse it
same as the non-image path. Image-aware filtering must not force
spurious re-pins."""
from app.config import config
session = _FakeSession(_thread(pinned=-2))
monkeypatch.setattr(
config,
"GLOBAL_LLM_CONFIGS",
[_text_only_cfg(-1), _vision_cfg(-2), _vision_cfg(-3, quality=70)],
)
monkeypatch.setattr(
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
_premium_allowed,
)
result = await resolve_or_get_pinned_llm_config_id(
session,
thread_id=1,
search_space_id=10,
user_id=None,
selected_llm_config_id=0,
requires_image_input=True,
)
assert result.resolved_llm_config_id == -2
assert result.from_existing_pin is True
@pytest.mark.asyncio
async def test_image_turn_with_no_vision_candidates_raises(monkeypatch):
"""The friendly-error path: no vision-capable cfg in the pool -> raise
``ValueError`` whose message contains ``vision-capable`` so the
streaming task can map it to ``MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT``."""
from app.config import config
session = _FakeSession(_thread())
monkeypatch.setattr(
config,
"GLOBAL_LLM_CONFIGS",
[_text_only_cfg(-1), _text_only_cfg(-2)],
)
monkeypatch.setattr(
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
_premium_allowed,
)
with pytest.raises(ValueError, match="vision-capable"):
await resolve_or_get_pinned_llm_config_id(
session,
thread_id=1,
search_space_id=10,
user_id=None,
selected_llm_config_id=0,
requires_image_input=True,
)
@pytest.mark.asyncio
async def test_non_image_turn_keeps_text_only_in_pool(monkeypatch):
"""Regression guard: the image flag must default False and not affect
a normal text-only turn text-only cfgs remain selectable."""
from app.config import config
session = _FakeSession(_thread())
monkeypatch.setattr(
config,
"GLOBAL_LLM_CONFIGS",
[_text_only_cfg(-1)],
)
monkeypatch.setattr(
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
_premium_allowed,
)
result = await resolve_or_get_pinned_llm_config_id(
session,
thread_id=1,
search_space_id=10,
user_id=None,
selected_llm_config_id=0,
)
assert result.resolved_llm_config_id == -1
@pytest.mark.asyncio
async def test_image_turn_unannotated_cfg_resolves_via_helper(monkeypatch):
"""A YAML cfg that omits ``supports_image_input`` falls through to
``derive_supports_image_input`` (LiteLLM-driven). For ``gpt-4o``
that returns True, so the cfg should be a valid candidate."""
from app.config import config
session = _FakeSession(_thread())
cfg_unannotated_vision = {
"id": -2,
"provider": "OPENAI",
"model_name": "gpt-4o", # known vision model in LiteLLM map
"api_key": "k",
"billing_tier": "free",
"auto_pin_tier": "A",
"quality_score": 80,
# NOTE: no supports_image_input key
}
monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [cfg_unannotated_vision])
monkeypatch.setattr(
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
_premium_allowed,
)
result = await resolve_or_get_pinned_llm_config_id(
session,
thread_id=1,
search_space_id=10,
user_id=None,
selected_llm_config_id=0,
requires_image_input=True,
)
assert result.resolved_llm_config_id == -2

View file

@ -0,0 +1,559 @@
"""Unit tests for the ``billable_call`` async context manager.
Covers the per-call premium-credit lifecycle for image generation and
vision LLM extraction:
* Free configs bypass reserve/finalize but still write an audit row.
* Premium reserve denial raises ``QuotaInsufficientError`` (HTTP 402 in the
route layer).
* Successful premium calls reserve, yield the accumulator, then finalize
with the LiteLLM-reported actual cost and write an audit row.
* Failed premium calls release the reservation so credit isn't leaked.
* All quota DB ops happen inside their OWN ``shielded_async_session``,
isolating them from the caller's transaction (issue A).
"""
from __future__ import annotations
import asyncio
import contextlib
from typing import Any
from uuid import uuid4
import pytest
pytestmark = pytest.mark.unit
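# Hedged sketch of the premium lifecycle exercised below. ``reserve``,
# ``finalize`` and ``release`` stand in for the TokenQuotaService calls the
# fakes intercept; the real ``billable_call`` also writes audit rows and
# isolates each quota op in its own shielded session.
@contextlib.asynccontextmanager
async def _premium_lifecycle_sketch(reserve, finalize, release, reserved_micros: int):
    if not (await reserve(reserved_micros)).allowed:
        raise RuntimeError("quota insufficient")  # mapped to HTTP 402 upstream
    try:
        yield
    except BaseException:
        await release(reserved_micros)  # failed call: give the credit back
        raise
    else:
        await finalize(reserved_micros)  # success: settle against actual cost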
# ---------------------------------------------------------------------------
# Fakes
# ---------------------------------------------------------------------------
class _FakeQuotaResult:
def __init__(
self,
*,
allowed: bool,
used: int = 0,
limit: int = 5_000_000,
remaining: int = 5_000_000,
) -> None:
self.allowed = allowed
self.used = used
self.limit = limit
self.remaining = remaining
class _FakeSession:
"""Minimal AsyncSession stub — record commits for assertion."""
def __init__(self) -> None:
self.committed = False
self.added: list[Any] = []
def add(self, obj: Any) -> None:
self.added.append(obj)
async def commit(self) -> None:
self.committed = True
async def rollback(self) -> None:
pass
async def close(self) -> None:
pass
@contextlib.asynccontextmanager
async def _fake_shielded_session():
s = _FakeSession()
_SESSIONS_USED.append(s)
yield s
_SESSIONS_USED: list[_FakeSession] = []
def _patch_isolation_layer(
monkeypatch, *, reserve_result, finalize_result=None, finalize_exc=None
):
"""Wire fake reserve/finalize/release/session helpers."""
_SESSIONS_USED.clear()
reserve_calls: list[dict[str, Any]] = []
finalize_calls: list[dict[str, Any]] = []
release_calls: list[dict[str, Any]] = []
async def _fake_reserve(*, db_session, user_id, request_id, reserve_micros):
reserve_calls.append(
{
"user_id": user_id,
"reserve_micros": reserve_micros,
"request_id": request_id,
}
)
return reserve_result
async def _fake_finalize(
*, db_session, user_id, request_id, actual_micros, reserved_micros
):
if finalize_exc is not None:
raise finalize_exc
finalize_calls.append(
{
"user_id": user_id,
"actual_micros": actual_micros,
"reserved_micros": reserved_micros,
}
)
return finalize_result or _FakeQuotaResult(allowed=True)
async def _fake_release(*, db_session, user_id, reserved_micros):
release_calls.append({"user_id": user_id, "reserved_micros": reserved_micros})
record_calls: list[dict[str, Any]] = []
async def _fake_record(session, **kwargs):
record_calls.append(kwargs)
return object()
monkeypatch.setattr(
"app.services.billable_calls.TokenQuotaService.premium_reserve",
_fake_reserve,
raising=False,
)
monkeypatch.setattr(
"app.services.billable_calls.TokenQuotaService.premium_finalize",
_fake_finalize,
raising=False,
)
monkeypatch.setattr(
"app.services.billable_calls.TokenQuotaService.premium_release",
_fake_release,
raising=False,
)
monkeypatch.setattr(
"app.services.billable_calls.shielded_async_session",
_fake_shielded_session,
raising=False,
)
monkeypatch.setattr(
"app.services.billable_calls.record_token_usage",
_fake_record,
raising=False,
)
return {
"reserve": reserve_calls,
"finalize": finalize_calls,
"release": release_calls,
"record": record_calls,
}
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_free_path_skips_reserve_but_writes_audit_row(monkeypatch):
from app.services.billable_calls import billable_call
spies = _patch_isolation_layer(
monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
)
user_id = uuid4()
async with billable_call(
user_id=user_id,
search_space_id=42,
billing_tier="free",
base_model="openai/gpt-image-1",
usage_type="image_generation",
) as acc:
# Simulate a captured cost — the accumulator is fed by the LiteLLM
# callback in real life, here we add it manually.
acc.add(
model="openai/gpt-image-1",
prompt_tokens=0,
completion_tokens=0,
total_tokens=0,
cost_micros=37_000,
call_kind="image_generation",
)
assert spies["reserve"] == []
assert spies["finalize"] == []
assert spies["release"] == []
# Free still audits.
assert len(spies["record"]) == 1
assert spies["record"][0]["usage_type"] == "image_generation"
assert spies["record"][0]["cost_micros"] == 37_000
@pytest.mark.asyncio
async def test_premium_reserve_denied_raises_quota_insufficient(monkeypatch):
from app.services.billable_calls import (
QuotaInsufficientError,
billable_call,
)
spies = _patch_isolation_layer(
monkeypatch,
reserve_result=_FakeQuotaResult(
allowed=False, used=5_000_000, limit=5_000_000, remaining=0
),
)
user_id = uuid4()
with pytest.raises(QuotaInsufficientError) as exc_info:
async with billable_call(
user_id=user_id,
search_space_id=42,
billing_tier="premium",
base_model="openai/gpt-image-1",
quota_reserve_micros_override=50_000,
usage_type="image_generation",
):
pytest.fail("body should not run when reserve is denied")
err = exc_info.value
assert err.usage_type == "image_generation"
assert err.used_micros == 5_000_000
assert err.limit_micros == 5_000_000
assert err.remaining_micros == 0
# Reserve was attempted, but no finalize/release on a denied reserve
# — the reservation never actually held credit.
assert len(spies["reserve"]) == 1
assert spies["finalize"] == []
assert spies["release"] == []
# Denied premium calls do NOT create an audit row (no work happened).
assert spies["record"] == []
@pytest.mark.asyncio
async def test_premium_success_finalizes_with_actual_cost(monkeypatch):
from app.services.billable_calls import billable_call
spies = _patch_isolation_layer(
monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
)
user_id = uuid4()
async with billable_call(
user_id=user_id,
search_space_id=42,
billing_tier="premium",
base_model="openai/gpt-image-1",
quota_reserve_micros_override=50_000,
usage_type="image_generation",
) as acc:
# LiteLLM callback would normally fill this — simulate $0.04 image.
acc.add(
model="openai/gpt-image-1",
prompt_tokens=0,
completion_tokens=0,
total_tokens=0,
cost_micros=40_000,
call_kind="image_generation",
)
assert len(spies["reserve"]) == 1
assert spies["reserve"][0]["reserve_micros"] == 50_000
assert len(spies["finalize"]) == 1
assert spies["finalize"][0]["actual_micros"] == 40_000
assert spies["finalize"][0]["reserved_micros"] == 50_000
assert spies["release"] == []
# And audit row written with the actual debited cost.
assert spies["record"][0]["cost_micros"] == 40_000
# Each quota op opened its OWN session — proves session isolation.
assert len(_SESSIONS_USED) >= 3
    # finalize/reserve happen via TokenQuotaService.* which we stub — they
    # don't call commit on our fake sessions, but the audit session does,
    # so we only assert that at least one session committed.
assert any(s.committed for s in _SESSIONS_USED)
@pytest.mark.asyncio
async def test_premium_failure_releases_reservation(monkeypatch):
from app.services.billable_calls import billable_call
spies = _patch_isolation_layer(
monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
)
user_id = uuid4()
class _ProviderError(Exception):
pass
with pytest.raises(_ProviderError):
async with billable_call(
user_id=user_id,
search_space_id=42,
billing_tier="premium",
base_model="openai/gpt-image-1",
quota_reserve_micros_override=50_000,
usage_type="image_generation",
):
raise _ProviderError("OpenRouter 503")
assert len(spies["reserve"]) == 1
assert spies["finalize"] == []
# Failure path: release the held reservation.
assert len(spies["release"]) == 1
assert spies["release"][0]["reserved_micros"] == 50_000
@pytest.mark.asyncio
async def test_premium_uses_estimator_when_no_micros_override(monkeypatch):
"""When ``quota_reserve_micros_override`` is None we fall back to
``estimate_call_reserve_micros(base_model, quota_reserve_tokens)``.
Vision LLM calls take this path (token-priced models).
"""
from app.services.billable_calls import billable_call
spies = _patch_isolation_layer(
monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
)
captured_estimator_calls: list[dict[str, Any]] = []
def _fake_estimate(*, base_model, quota_reserve_tokens):
captured_estimator_calls.append(
{"base_model": base_model, "quota_reserve_tokens": quota_reserve_tokens}
)
return 12_345
monkeypatch.setattr(
"app.services.billable_calls.estimate_call_reserve_micros",
_fake_estimate,
raising=False,
)
user_id = uuid4()
async with billable_call(
user_id=user_id,
search_space_id=1,
billing_tier="premium",
base_model="openai/gpt-4o",
quota_reserve_tokens=4000,
usage_type="vision_extraction",
):
pass
assert captured_estimator_calls == [
{"base_model": "openai/gpt-4o", "quota_reserve_tokens": 4000}
]
assert spies["reserve"][0]["reserve_micros"] == 12_345
@pytest.mark.asyncio
async def test_premium_finalize_failure_propagates_and_releases(monkeypatch):
from app.services.billable_calls import BillingSettlementError, billable_call
class _FinalizeError(RuntimeError):
pass
spies = _patch_isolation_layer(
monkeypatch,
reserve_result=_FakeQuotaResult(allowed=True),
finalize_exc=_FinalizeError("db finalize failed"),
)
user_id = uuid4()
with pytest.raises(BillingSettlementError):
async with billable_call(
user_id=user_id,
search_space_id=42,
billing_tier="premium",
base_model="openai/gpt-image-1",
quota_reserve_micros_override=50_000,
usage_type="image_generation",
) as acc:
acc.add(
model="openai/gpt-image-1",
prompt_tokens=0,
completion_tokens=0,
total_tokens=0,
cost_micros=40_000,
call_kind="image_generation",
)
assert len(spies["reserve"]) == 1
assert len(spies["release"]) == 1
assert spies["record"] == []
@pytest.mark.asyncio
async def test_premium_audit_commit_hang_times_out_after_finalize(monkeypatch):
from app.services.billable_calls import billable_call
spies = _patch_isolation_layer(
monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
)
user_id = uuid4()
class _HangingCommitSession(_FakeSession):
async def commit(self) -> None:
await asyncio.sleep(60)
@contextlib.asynccontextmanager
async def _hanging_session_factory():
s = _HangingCommitSession()
_SESSIONS_USED.append(s)
yield s
async with billable_call(
user_id=user_id,
search_space_id=42,
billing_tier="premium",
base_model="openai/gpt-image-1",
quota_reserve_micros_override=50_000,
usage_type="image_generation",
billable_session_factory=_hanging_session_factory,
audit_timeout_seconds=0.01,
) as acc:
acc.add(
model="openai/gpt-image-1",
prompt_tokens=0,
completion_tokens=0,
total_tokens=0,
cost_micros=40_000,
call_kind="image_generation",
)
assert len(spies["reserve"]) == 1
assert len(spies["finalize"]) == 1
assert len(spies["record"]) == 1
assert spies["release"] == []
@pytest.mark.asyncio
async def test_free_audit_failure_is_best_effort(monkeypatch):
from app.services.billable_calls import billable_call
spies = _patch_isolation_layer(
monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
)
async def _failing_record(_session, **_kwargs):
raise RuntimeError("audit insert failed")
monkeypatch.setattr(
"app.services.billable_calls.record_token_usage",
_failing_record,
raising=False,
)
async with billable_call(
user_id=uuid4(),
search_space_id=42,
billing_tier="free",
base_model="openai/gpt-image-1",
usage_type="image_generation",
audit_timeout_seconds=0.01,
) as acc:
acc.add(
model="openai/gpt-image-1",
prompt_tokens=0,
completion_tokens=0,
total_tokens=0,
cost_micros=37_000,
call_kind="image_generation",
)
assert spies["reserve"] == []
assert spies["finalize"] == []
# ---------------------------------------------------------------------------
# Podcast / video-presentation usage_type coverage
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_free_podcast_path_audits_with_podcast_usage_type(monkeypatch):
"""Free podcast configs must skip reserve/finalize but still emit a
``TokenUsage`` row tagged ``usage_type='podcast_generation'`` so we
have full audit coverage of free-tier agent runs."""
from app.services.billable_calls import billable_call
spies = _patch_isolation_layer(
monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
)
user_id = uuid4()
async with billable_call(
user_id=user_id,
search_space_id=42,
billing_tier="free",
base_model="openrouter/some-free-model",
quota_reserve_micros_override=200_000,
usage_type="podcast_generation",
thread_id=99,
call_details={"podcast_id": 7, "title": "Test Podcast"},
) as acc:
# Two transcript LLM calls aggregated into one accumulator.
acc.add(
model="openrouter/some-free-model",
prompt_tokens=1500,
completion_tokens=8000,
total_tokens=9500,
cost_micros=0,
call_kind="chat",
)
assert spies["reserve"] == []
assert spies["finalize"] == []
assert spies["release"] == []
assert len(spies["record"]) == 1
row = spies["record"][0]
assert row["usage_type"] == "podcast_generation"
assert row["thread_id"] is None
assert row["search_space_id"] == 42
assert row["call_details"] == {"podcast_id": 7, "title": "Test Podcast"}
@pytest.mark.asyncio
async def test_premium_video_denial_raises_quota_insufficient(monkeypatch):
"""Premium video-presentation runs that hit a denied reservation must
raise ``QuotaInsufficientError`` *before* the graph runs and must not
emit an audit row (no work happened)."""
from app.services.billable_calls import (
QuotaInsufficientError,
billable_call,
)
spies = _patch_isolation_layer(
monkeypatch,
reserve_result=_FakeQuotaResult(
allowed=False, used=4_500_000, limit=5_000_000, remaining=500_000
),
)
user_id = uuid4()
with pytest.raises(QuotaInsufficientError) as exc_info:
async with billable_call(
user_id=user_id,
search_space_id=42,
billing_tier="premium",
base_model="gpt-5.4",
quota_reserve_micros_override=1_000_000,
usage_type="video_presentation_generation",
thread_id=99,
call_details={"video_presentation_id": 12, "title": "Test Video"},
):
pytest.fail("body should not run when reserve is denied")
err = exc_info.value
assert err.usage_type == "video_presentation_generation"
assert err.remaining_micros == 500_000
assert spies["reserve"][0]["reserve_micros"] == 1_000_000
assert spies["finalize"] == []
assert spies["release"] == []
assert spies["record"] == []

View file

@ -0,0 +1,177 @@
"""Defense-in-depth: image-gen call sites must not let an empty
``api_base`` fall through to LiteLLM's module-global ``litellm.api_base``.
The bug repro: an OpenRouter image-gen config ships
``api_base=""``. The pre-fix call site in
``image_generation_routes._execute_image_generation`` did
``if cfg.get("api_base"): kwargs["api_base"] = cfg["api_base"]`` which
silently dropped the empty string. LiteLLM then fell back to
``litellm.api_base`` (commonly inherited from ``AZURE_OPENAI_ENDPOINT``)
and OpenRouter's ``image_generation/transformation`` appended
``/chat/completions`` to it → 404 ``Resource not found``.
This test pins the post-fix behaviour: with an empty ``api_base`` in
the config, the call site MUST set ``api_base`` to OpenRouter's public
URL instead of leaving it unset.
"""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
pytestmark = pytest.mark.unit
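# Minimal sketch of the post-fix resolution these tests pin (the constant and
# helper names are illustrative; the real call sites may structure this
# differently): an empty or missing ``api_base`` on an OpenRouter config must
# resolve to the public URL instead of falling through to ``litellm.api_base``.
_OPENROUTER_PUBLIC_API_BASE = "https://openrouter.ai/api/v1"
def _resolved_openrouter_api_base(cfg: dict) -> str:
    return cfg.get("api_base") or _OPENROUTER_PUBLIC_API_BASE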
@pytest.mark.asyncio
async def test_global_openrouter_image_gen_sets_api_base_when_config_empty():
"""The global-config branch (``config_id < 0``) of
``_execute_image_generation`` must apply the resolver and pin
``api_base`` to OpenRouter when the config ships an empty string.
"""
from app.routes import image_generation_routes
cfg = {
"id": -20_001,
"name": "GPT Image 1 (OpenRouter)",
"provider": "OPENROUTER",
"model_name": "openai/gpt-image-1",
"api_key": "sk-or-test",
"api_base": "", # the original bug shape
"api_version": None,
"litellm_params": {},
}
captured: dict = {}
async def fake_aimage_generation(**kwargs):
captured.update(kwargs)
return MagicMock(model_dump=lambda: {"data": []}, _hidden_params={})
image_gen = MagicMock()
image_gen.image_generation_config_id = cfg["id"]
image_gen.prompt = "test"
image_gen.n = 1
image_gen.quality = None
image_gen.size = None
image_gen.style = None
image_gen.response_format = None
image_gen.model = None
search_space = MagicMock()
search_space.image_generation_config_id = cfg["id"]
session = MagicMock()
with (
patch.object(
image_generation_routes,
"_get_global_image_gen_config",
return_value=cfg,
),
patch.object(
image_generation_routes,
"aimage_generation",
side_effect=fake_aimage_generation,
),
):
await image_generation_routes._execute_image_generation(
session=session, image_gen=image_gen, search_space=search_space
)
# The whole point of the fix: even with empty ``api_base`` in the
# config, we forward OpenRouter's public URL so the call doesn't
# inherit an Azure endpoint.
assert captured.get("api_base") == "https://openrouter.ai/api/v1"
assert captured["model"] == "openrouter/openai/gpt-image-1"
@pytest.mark.asyncio
async def test_generate_image_tool_global_sets_api_base_when_config_empty():
"""Same defense at the agent tool entry point — both surfaces share
the same OpenRouter config payloads."""
from app.agents.new_chat.tools import generate_image as gi_module
cfg = {
"id": -20_001,
"name": "GPT Image 1 (OpenRouter)",
"provider": "OPENROUTER",
"model_name": "openai/gpt-image-1",
"api_key": "sk-or-test",
"api_base": "",
"api_version": None,
"litellm_params": {},
}
captured: dict = {}
async def fake_aimage_generation(**kwargs):
captured.update(kwargs)
response = MagicMock()
response.model_dump.return_value = {
"data": [{"url": "https://example.com/x.png"}]
}
response._hidden_params = {"model": "openrouter/openai/gpt-image-1"}
return response
search_space = MagicMock()
search_space.id = 1
search_space.image_generation_config_id = cfg["id"]
session_cm = AsyncMock()
session = AsyncMock()
session_cm.__aenter__.return_value = session
scalars = MagicMock()
scalars.first.return_value = search_space
exec_result = MagicMock()
exec_result.scalars.return_value = scalars
session.execute.return_value = exec_result
session.add = MagicMock()
session.commit = AsyncMock()
session.refresh = AsyncMock()
# ``refresh(db_image_gen)`` needs to populate ``id`` for token URL fallback.
async def _refresh(obj):
obj.id = 1
session.refresh.side_effect = _refresh
with (
patch.object(gi_module, "shielded_async_session", return_value=session_cm),
patch.object(gi_module, "_get_global_image_gen_config", return_value=cfg),
patch.object(
gi_module, "aimage_generation", side_effect=fake_aimage_generation
),
patch.object(
gi_module, "is_image_gen_auto_mode", side_effect=lambda cid: cid == 0
),
):
tool = gi_module.create_generate_image_tool(
search_space_id=1, db_session=MagicMock()
)
await tool.ainvoke({"prompt": "a cat", "n": 1})
assert captured.get("api_base") == "https://openrouter.ai/api/v1"
assert captured["model"] == "openrouter/openai/gpt-image-1"
def test_image_gen_router_deployment_sets_api_base_when_config_empty():
"""The Auto-mode router pool must also resolve ``api_base`` when an
OpenRouter config ships an empty string. The deployment dict is fed
straight to ``litellm.Router``, so a missing ``api_base`` would
leak the same way as the direct call sites.
"""
from app.services.image_gen_router_service import ImageGenRouterService
deployment = ImageGenRouterService._config_to_deployment(
{
"model_name": "openai/gpt-image-1",
"provider": "OPENROUTER",
"api_key": "sk-or-test",
"api_base": "",
}
)
assert deployment is not None
assert deployment["litellm_params"]["api_base"] == "https://openrouter.ai/api/v1"
assert deployment["litellm_params"]["model"] == "openrouter/openai/gpt-image-1"

View file

@ -0,0 +1,226 @@
"""LLMRouterService pool-filter / rebuild tests.
These tests focus on the *config plumbing* (which configs enter the router
pool, rebuild resets state correctly). They stub out the underlying
``litellm.Router`` so we don't need real API keys or network access.
"""
from __future__ import annotations
from unittest.mock import patch
import pytest
from app.services.llm_router_service import LLMRouterService
pytestmark = pytest.mark.unit
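# Hedged sketch of the Strategy-3 eligibility rule the tests below assert on
# (inferred from the fixtures, not copied from ``LLMRouterService``): curated
# YAML configs default into the pool, while dynamic OpenRouter configs carry
# an explicit ``router_pool_eligible`` flag.
def _enters_router_pool(cfg: dict) -> bool:
    return bool(cfg.get("router_pool_eligible", True))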
def _fake_yaml_config(
*,
id: int,
model_name: str,
billing_tier: str = "free",
) -> dict:
return {
"id": id,
"name": f"yaml-{id}",
"provider": "OPENAI",
"model_name": model_name,
"api_key": "sk-test",
"api_base": "",
"billing_tier": billing_tier,
"rpm": 100,
"tpm": 100_000,
"litellm_params": {},
}
def _fake_openrouter_config(
*,
id: int,
model_name: str,
billing_tier: str,
router_pool_eligible: bool | None = None,
) -> dict:
"""Build a synthetic dynamic-OR config dict for router-pool tests.
Defaults mirror Strategy 3: premium OR enters the pool, free OR stays
out. Callers can override ``router_pool_eligible`` to simulate legacy
configs or to regression-test the filter mechanics directly.
"""
if router_pool_eligible is None:
router_pool_eligible = billing_tier == "premium"
return {
"id": id,
"name": f"or-{id}",
"provider": "OPENROUTER",
"model_name": model_name,
"api_key": "sk-or-test",
"api_base": "",
"billing_tier": billing_tier,
"rpm": 20 if billing_tier == "free" else 200,
"tpm": 100_000 if billing_tier == "free" else 1_000_000,
"litellm_params": {},
"router_pool_eligible": router_pool_eligible,
}
def _reset_router_singleton() -> None:
instance = LLMRouterService.get_instance()
instance._initialized = False
instance._router = None
instance._model_list = []
instance._premium_model_strings = set()
def test_router_pool_includes_or_premium_excludes_or_free():
"""Strategy 3: premium OR joins the pool, free OR stays out.
Dynamic OpenRouter premium entries opt into load balancing alongside
curated YAML configs. Dynamic OR free entries are intentionally kept
out because OpenRouter's free tier enforces a single account-global
quota bucket that per-deployment router accounting can't represent.
"""
_reset_router_singleton()
configs = [
_fake_yaml_config(id=-1, model_name="gpt-4o", billing_tier="premium"),
_fake_yaml_config(id=-2, model_name="gpt-4o-mini", billing_tier="free"),
_fake_openrouter_config(
id=-10_001, model_name="openai/gpt-4o", billing_tier="premium"
),
_fake_openrouter_config(
id=-10_002,
model_name="meta-llama/llama-3.3-70b:free",
billing_tier="free",
),
]
with (
patch("app.services.llm_router_service.Router") as mock_router,
patch(
"app.services.llm_router_service.LLMRouterService._build_context_fallback_groups"
) as mock_ctx_fb,
):
mock_ctx_fb.side_effect = lambda ml: (ml, None)
mock_router.return_value = object()
LLMRouterService.initialize(configs)
pool_models = {
dep["litellm_params"]["model"]
for dep in LLMRouterService.get_instance()._model_list
}
# YAML premium + YAML free + dynamic OR premium are all in the pool.
# Dynamic OR free is NOT (shared-bucket rate limits can't be load-balanced).
assert pool_models == {
"openai/gpt-4o",
"openai/gpt-4o-mini",
"openrouter/openai/gpt-4o",
}
prem = LLMRouterService.get_instance()._premium_model_strings
# YAML premium is fingerprinted under both its model_string and its
# ``base_model`` form (existing behavior we don't want to regress).
assert "openai/gpt-4o" in prem
# Dynamic OR premium is now fingerprinted as premium so pool-level
# calls through the router are billed against premium quota.
assert "openrouter/openai/gpt-4o" in prem
assert LLMRouterService.is_premium_model("openrouter/openai/gpt-4o") is True
# Dynamic OR free never enters the pool, so it's never counted as premium.
assert (
LLMRouterService.is_premium_model("openrouter/meta-llama/llama-3.3-70b:free")
is False
)
def test_router_pool_filter_mechanics_respect_override():
"""The ``router_pool_eligible`` filter itself works independently of tier.
Regression guard: if a future refactor ever sets the flag False on a
premium config (e.g. for maintenance), that config MUST be skipped by
``initialize`` even though its tier is premium.
"""
_reset_router_singleton()
configs = [
_fake_yaml_config(id=-1, model_name="gpt-4o", billing_tier="premium"),
_fake_openrouter_config(
id=-10_001,
model_name="openai/gpt-4o",
billing_tier="premium",
router_pool_eligible=False, # opt out despite being premium
),
]
with (
patch("app.services.llm_router_service.Router") as mock_router,
patch(
"app.services.llm_router_service.LLMRouterService._build_context_fallback_groups"
) as mock_ctx_fb,
):
mock_ctx_fb.side_effect = lambda ml: (ml, None)
mock_router.return_value = object()
LLMRouterService.initialize(configs)
pool_models = {
dep["litellm_params"]["model"]
for dep in LLMRouterService.get_instance()._model_list
}
assert pool_models == {"openai/gpt-4o"}
assert LLMRouterService.is_premium_model("openrouter/openai/gpt-4o") is False
def test_rebuild_refreshes_pool_after_configs_change():
_reset_router_singleton()
configs_v1 = [
_fake_yaml_config(id=-1, model_name="gpt-4o", billing_tier="premium"),
]
configs_v2 = [
*configs_v1,
_fake_yaml_config(id=-2, model_name="gpt-4o-mini", billing_tier="free"),
]
with (
patch("app.services.llm_router_service.Router") as mock_router,
patch(
"app.services.llm_router_service.LLMRouterService._build_context_fallback_groups"
) as mock_ctx_fb,
):
mock_ctx_fb.side_effect = lambda ml: (ml, None)
mock_router.return_value = object()
LLMRouterService.initialize(configs_v1)
assert len(LLMRouterService.get_instance()._model_list) == 1
# ``initialize`` should be a no-op here (already initialized).
LLMRouterService.initialize(configs_v2)
assert len(LLMRouterService.get_instance()._model_list) == 1
# ``rebuild`` must clear the guard and re-run with the new configs.
LLMRouterService.rebuild(configs_v2)
assert len(LLMRouterService.get_instance()._model_list) == 2
def test_auto_model_pin_candidates_include_dynamic_openrouter():
"""Dynamic OR configs must remain Auto-mode thread-pin candidates.
Guards against a future regression where someone adds the
``router_pool_eligible`` filter to ``auto_model_pin_service._global_candidates``.
"""
from app.config import config
from app.services.auto_model_pin_service import _global_candidates
or_premium = _fake_openrouter_config(
id=-10_001, model_name="openai/gpt-4o", billing_tier="premium"
)
or_free = _fake_openrouter_config(
id=-10_002,
model_name="meta-llama/llama-3.3-70b:free",
billing_tier="free",
)
original = config.GLOBAL_LLM_CONFIGS
try:
config.GLOBAL_LLM_CONFIGS = [or_premium, or_free]
candidate_ids = {c["id"] for c in _global_candidates()}
assert candidate_ids == {-10_001, -10_002}
finally:
config.GLOBAL_LLM_CONFIGS = original

View file

@ -0,0 +1,380 @@
"""Unit tests for the dynamic OpenRouter integration."""
from __future__ import annotations
import pytest
from app.services.openrouter_integration_service import (
_OPENROUTER_DYNAMIC_MARKER,
_generate_configs,
_openrouter_tier,
_stable_config_id,
)
pytestmark = pytest.mark.unit
def _minimal_openrouter_model(
*,
model_id: str,
pricing: dict | None = None,
name: str | None = None,
) -> dict:
"""Return a synthetic OpenRouter /api/v1/models entry.
The real API payload includes a lot of fields; we only populate what
``_generate_configs`` actually inspects (architecture, tool support,
context, pricing, id).
"""
return {
"id": model_id,
"name": name or model_id,
"architecture": {"output_modalities": ["text"]},
"supported_parameters": ["tools"],
"context_length": 200_000,
"pricing": pricing or {"prompt": "0.000003", "completion": "0.000015"},
}
# ---------------------------------------------------------------------------
# _openrouter_tier
# ---------------------------------------------------------------------------
def test_openrouter_tier_free_suffix():
assert _openrouter_tier({"id": "foo/bar:free"}) == "free"
def test_openrouter_tier_zero_pricing():
model = {
"id": "foo/bar",
"pricing": {"prompt": "0", "completion": "0"},
}
assert _openrouter_tier(model) == "free"
def test_openrouter_tier_paid():
model = {
"id": "foo/bar",
"pricing": {"prompt": "0.000003", "completion": "0.000015"},
}
assert _openrouter_tier(model) == "premium"
def test_openrouter_tier_missing_pricing_is_premium():
assert _openrouter_tier({"id": "foo/bar"}) == "premium"
assert _openrouter_tier({"id": "foo/bar", "pricing": {}}) == "premium"
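# A minimal sketch of the tier decision the tests above pin down (an assumed
# reimplementation for illustration, not the production ``_openrouter_tier``):
# a ``:free`` id suffix or all-zero pricing means "free"; anything else,
# including missing pricing, is treated as "premium".
def _openrouter_tier_sketch(model: dict) -> str:
    if str(model.get("id", "")).endswith(":free"):
        return "free"
    pricing = model.get("pricing") or {}
    try:
        prompt = float(pricing.get("prompt", "1"))
        completion = float(pricing.get("completion", "1"))
    except (TypeError, ValueError):
        return "premium"
    return "free" if prompt == 0 and completion == 0 else "premium"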
# ---------------------------------------------------------------------------
# _stable_config_id
# ---------------------------------------------------------------------------
def test_stable_config_id_deterministic():
taken1: set[int] = set()
taken2: set[int] = set()
a = _stable_config_id("openai/gpt-4o", -10_000, taken1)
b = _stable_config_id("openai/gpt-4o", -10_000, taken2)
assert a == b
assert a < 0
def test_stable_config_id_collision_decrements():
"""When two model_ids hash to the same slot, the second should decrement."""
taken: set[int] = set()
a = _stable_config_id("openai/gpt-4o", -10_000, taken)
# Force a collision by pre-populating ``taken`` with a slot we know will be
# picked.
taken_forced = {a}
b = _stable_config_id("openai/gpt-4o", -10_000, taken_forced)
assert b != a
assert b == a - 1
assert b in taken_forced
def test_stable_config_id_different_models_different_ids():
taken: set[int] = set()
ids = {
_stable_config_id("openai/gpt-4o", -10_000, taken),
_stable_config_id("anthropic/claude-3.5-sonnet", -10_000, taken),
_stable_config_id("google/gemini-2.0-flash", -10_000, taken),
}
assert len(ids) == 3
def test_stable_config_id_survives_catalogue_churn():
"""Removing a model should not shift other models' IDs (the bug we fix)."""
taken1: set[int] = set()
id_a1 = _stable_config_id("openai/gpt-4o", -10_000, taken1)
_ = _stable_config_id("anthropic/claude-3-haiku", -10_000, taken1)
id_c1 = _stable_config_id("google/gemini-2.0-flash", -10_000, taken1)
taken2: set[int] = set()
id_a2 = _stable_config_id("openai/gpt-4o", -10_000, taken2)
id_c2 = _stable_config_id("google/gemini-2.0-flash", -10_000, taken2)
assert id_a1 == id_a2
assert id_c1 == id_c2
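# A minimal reference sketch of the id scheme the tests above describe (an
# assumption, not the production ``_stable_config_id``): derive a deterministic
# negative slot from the model id alone, and walk downward on collision so one
# model's removal never shifts another model's id.
def _stable_config_id_sketch(model_id: str, offset: int, taken: set[int]) -> int:
    import hashlib
    digest = int(hashlib.sha256(model_id.encode("utf-8")).hexdigest(), 16)
    candidate = offset - (digest % 9_000)  # assumed slot width below the offset
    while candidate in taken:
        candidate -= 1
    taken.add(candidate)
    return candidate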
# ---------------------------------------------------------------------------
# _generate_configs
# ---------------------------------------------------------------------------
_SETTINGS_BASE: dict = {
"api_key": "sk-or-test",
"id_offset": -10_000,
"rpm": 200,
"tpm": 1_000_000,
"free_rpm": 20,
"free_tpm": 100_000,
"anonymous_enabled_paid": False,
"anonymous_enabled_free": True,
"quota_reserve_tokens": 4000,
}
def test_generate_configs_respects_tier():
"""Premium OR models opt into the router pool; free OR models stay out.
Strategy-3 split: premium participates in LiteLLM Router load balancing,
free stays excluded because OpenRouter enforces a shared global free-tier
bucket that per-deployment router accounting can't represent.
"""
raw = [
_minimal_openrouter_model(model_id="openai/gpt-4o"),
_minimal_openrouter_model(
model_id="meta-llama/llama-3.3-70b-instruct:free",
pricing={"prompt": "0", "completion": "0"},
),
]
cfgs = _generate_configs(raw, dict(_SETTINGS_BASE))
by_model = {c["model_name"]: c for c in cfgs}
paid = by_model["openai/gpt-4o"]
assert paid["billing_tier"] == "premium"
assert paid["rpm"] == 200
assert paid["tpm"] == 1_000_000
assert paid["anonymous_enabled"] is False
assert paid["router_pool_eligible"] is True
assert paid[_OPENROUTER_DYNAMIC_MARKER] is True
free = by_model["meta-llama/llama-3.3-70b-instruct:free"]
assert free["billing_tier"] == "free"
assert free["rpm"] == 20
assert free["tpm"] == 100_000
assert free["anonymous_enabled"] is True
assert free["router_pool_eligible"] is False
def test_generate_configs_excludes_upstream_openrouter_free_router():
"""OpenRouter's own ``openrouter/free`` meta-router must never become a card.
The upstream API returns this as a first-class zero-priced model, so
without an explicit blocklist entry it would slip through every other
filter (text output, tool calling, 200k context, non-Amazon) and land
in the selector as a duplicate of the concrete ``:free`` cards. The
exclusion in ``_EXCLUDED_MODEL_IDS`` prevents that.
"""
raw = [
_minimal_openrouter_model(model_id="openai/gpt-4o"),
_minimal_openrouter_model(
model_id="openrouter/free",
pricing={"prompt": "0", "completion": "0"},
),
]
cfgs = _generate_configs(raw, dict(_SETTINGS_BASE))
model_names = {c["model_name"] for c in cfgs}
assert "openrouter/free" not in model_names
assert "openai/gpt-4o" in model_names
def test_generate_configs_drops_non_text_and_non_tool_models():
raw = [
_minimal_openrouter_model(model_id="openai/gpt-4o"),
{ # image-output model
"id": "openai/dall-e",
"architecture": {"output_modalities": ["image"]},
"supported_parameters": ["tools"],
"context_length": 200_000,
"pricing": {"prompt": "0.01", "completion": "0.01"},
},
{ # text but no tool calling
"id": "openai/completion-only",
"architecture": {"output_modalities": ["text"]},
"supported_parameters": [],
"context_length": 200_000,
"pricing": {"prompt": "0.01", "completion": "0.01"},
},
]
cfgs = _generate_configs(raw, dict(_SETTINGS_BASE))
model_names = [c["model_name"] for c in cfgs]
assert "openai/gpt-4o" in model_names
assert "openai/dall-e" not in model_names
assert "openai/completion-only" not in model_names
# ---------------------------------------------------------------------------
# _generate_image_gen_configs / _generate_vision_llm_configs
# ---------------------------------------------------------------------------
def test_generate_image_gen_configs_filters_by_image_output():
"""Only models with ``output_modalities`` containing ``image`` are emitted.
Tool-calling and context filters are intentionally NOT applied; image
generation has nothing to do with tool calls or context windows.
"""
from app.services.openrouter_integration_service import (
_generate_image_gen_configs,
)
raw = [
# Pure image-gen model (small context, no tools — should still emit).
{
"id": "openai/gpt-image-1",
"architecture": {"output_modalities": ["image"]},
"context_length": 4_000,
"pricing": {"prompt": "0", "completion": "0"},
},
# Multi-modal: text+image output (should still emit).
{
"id": "google/gemini-2.5-flash-image",
"architecture": {"output_modalities": ["text", "image"]},
"context_length": 1_000_000,
"pricing": {"prompt": "0.000001", "completion": "0.000004"},
},
# Pure text model — must NOT emit.
{
"id": "openai/gpt-4o",
"architecture": {"output_modalities": ["text"]},
"context_length": 128_000,
"pricing": {"prompt": "0.000005", "completion": "0.000015"},
},
]
cfgs = _generate_image_gen_configs(raw, dict(_SETTINGS_BASE))
model_names = {c["model_name"] for c in cfgs}
assert "openai/gpt-image-1" in model_names
assert "google/gemini-2.5-flash-image" in model_names
assert "openai/gpt-4o" not in model_names
# Each config must carry ``billing_tier`` for routing in image_generation_routes.
for c in cfgs:
assert c["billing_tier"] in {"free", "premium"}
assert c["provider"] == "OPENROUTER"
assert c[_OPENROUTER_DYNAMIC_MARKER] is True
# Defense-in-depth: emit the OpenRouter base URL at source so a
# downstream call site that forgets ``resolve_api_base`` still
# doesn't 404 against an inherited Azure endpoint.
assert c["api_base"] == "https://openrouter.ai/api/v1"
def test_generate_image_gen_configs_assigns_image_id_offset():
"""Image configs use a different id_offset (-20000) so their negative
IDs don't collide with chat configs (-10000) or vision configs (-30000).
"""
from app.services.openrouter_integration_service import (
_generate_image_gen_configs,
)
raw = [
{
"id": "openai/gpt-image-1",
"architecture": {"output_modalities": ["image"]},
"context_length": 4_000,
"pricing": {"prompt": "0", "completion": "0"},
}
]
# Don't pass image_id_offset → use the module default (-20000).
cfgs = _generate_image_gen_configs(raw, dict(_SETTINGS_BASE))
assert all(c["id"] < -20_000 + 1 for c in cfgs)
assert all(c["id"] > -29_000_000 for c in cfgs)
def test_generate_vision_llm_configs_filters_by_image_input_text_output():
"""Vision LLMs must accept image input AND emit text — pure image-gen
(no text out) and text-only (no image in) models are excluded.
"""
from app.services.openrouter_integration_service import (
_generate_vision_llm_configs,
)
raw = [
# GPT-4o: vision LLM (image in, text out) — must emit.
{
"id": "openai/gpt-4o",
"architecture": {
"input_modalities": ["text", "image"],
"output_modalities": ["text"],
},
"context_length": 128_000,
"pricing": {"prompt": "0.000005", "completion": "0.000015"},
},
# Pure image generator — image *output*, no text out. Must NOT emit.
{
"id": "openai/gpt-image-1",
"architecture": {
"input_modalities": ["text"],
"output_modalities": ["image"],
},
"context_length": 4_000,
"pricing": {"prompt": "0", "completion": "0"},
},
# Pure text model (no image in). Must NOT emit.
{
"id": "anthropic/claude-3-haiku",
"architecture": {
"input_modalities": ["text"],
"output_modalities": ["text"],
},
"context_length": 200_000,
"pricing": {"prompt": "0.000001", "completion": "0.000005"},
},
]
cfgs = _generate_vision_llm_configs(raw, dict(_SETTINGS_BASE))
names = {c["model_name"] for c in cfgs}
assert names == {"openai/gpt-4o"}
cfg = cfgs[0]
assert cfg["billing_tier"] == "premium"
# Pricing carried inline so pricing_registration can register vision
# under ``openrouter/openai/gpt-4o`` even if the chat catalogue cache
# is cleared.
assert cfg["input_cost_per_token"] == pytest.approx(5e-6)
assert cfg["output_cost_per_token"] == pytest.approx(15e-6)
assert cfg[_OPENROUTER_DYNAMIC_MARKER] is True
# Defense-in-depth: emit the OpenRouter base URL at source so a
# downstream call site that forgets ``resolve_api_base`` still
# doesn't inherit an Azure endpoint.
assert cfg["api_base"] == "https://openrouter.ai/api/v1"
def test_generate_vision_llm_configs_drops_chat_only_filters():
"""A small-context vision model that doesn't advertise tool calling is
still a valid vision LLM for "describe this image" prompts. The chat
filters (``supports_tool_calling``, ``has_sufficient_context``) must
NOT be applied to vision emission.
"""
from app.services.openrouter_integration_service import (
_generate_vision_llm_configs,
)
raw = [
{
"id": "tiny/vision-mini",
"architecture": {
"input_modalities": ["text", "image"],
"output_modalities": ["text"],
},
"supported_parameters": [], # no tools
"context_length": 4_000, # well below MIN_CONTEXT_LENGTH
"pricing": {"prompt": "0.0000001", "completion": "0.0000005"},
}
]
cfgs = _generate_vision_llm_configs(raw, dict(_SETTINGS_BASE))
assert len(cfgs) == 1
assert cfgs[0]["model_name"] == "tiny/vision-mini"

View file

@ -0,0 +1,108 @@
"""Tests for deprecated-key warnings and back-compat in
``load_openrouter_integration_settings``.
"""
from __future__ import annotations
from pathlib import Path
import pytest
pytestmark = pytest.mark.unit
def _write_yaml(tmp_path: Path, body: str) -> Path:
cfg_dir = tmp_path / "app" / "config"
cfg_dir.mkdir(parents=True)
cfg_path = cfg_dir / "global_llm_config.yaml"
cfg_path.write_text(body, encoding="utf-8")
return cfg_path
def _patch_base_dir(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
from app import config as config_module
monkeypatch.setattr(config_module, "BASE_DIR", tmp_path)
def test_legacy_billing_tier_emits_warning(monkeypatch, tmp_path, capsys):
_write_yaml(
tmp_path,
"""
openrouter_integration:
enabled: true
api_key: "sk-or-test"
billing_tier: "premium"
""".lstrip(),
)
_patch_base_dir(monkeypatch, tmp_path)
from app.config import load_openrouter_integration_settings
settings = load_openrouter_integration_settings()
captured = capsys.readouterr().out
assert settings is not None
assert "billing_tier is deprecated" in captured
def test_legacy_anonymous_enabled_back_compat(monkeypatch, tmp_path, capsys):
_write_yaml(
tmp_path,
"""
openrouter_integration:
enabled: true
api_key: "sk-or-test"
anonymous_enabled: true
""".lstrip(),
)
_patch_base_dir(monkeypatch, tmp_path)
from app.config import load_openrouter_integration_settings
settings = load_openrouter_integration_settings()
captured = capsys.readouterr().out
assert settings is not None
assert settings["anonymous_enabled_paid"] is True
assert settings["anonymous_enabled_free"] is True
assert "anonymous_enabled is" in captured
assert "deprecated" in captured
def test_new_keys_take_priority_over_legacy_back_compat(monkeypatch, tmp_path, capsys):
"""If both legacy and new keys are present, new keys win (setdefault)."""
_write_yaml(
tmp_path,
"""
openrouter_integration:
enabled: true
api_key: "sk-or-test"
anonymous_enabled: true
anonymous_enabled_paid: false
anonymous_enabled_free: false
""".lstrip(),
)
_patch_base_dir(monkeypatch, tmp_path)
from app.config import load_openrouter_integration_settings
settings = load_openrouter_integration_settings()
capsys.readouterr()
assert settings is not None
assert settings["anonymous_enabled_paid"] is False
assert settings["anonymous_enabled_free"] is False
def test_disabled_integration_returns_none(monkeypatch, tmp_path):
_write_yaml(
tmp_path,
"""
openrouter_integration:
enabled: false
api_key: "sk-or-test"
""".lstrip(),
)
_patch_base_dir(monkeypatch, tmp_path)
from app.config import load_openrouter_integration_settings
assert load_openrouter_integration_settings() is None

View file

@ -0,0 +1,331 @@
"""Unit tests for the OpenRouter ``_enrich_health`` background task."""
from __future__ import annotations
from typing import Any
import pytest
from app.services.openrouter_integration_service import (
OpenRouterIntegrationService,
)
from app.services.quality_score import (
_HEALTH_FAIL_RATIO_FALLBACK,
)
pytestmark = pytest.mark.unit
def _or_cfg(
*,
cid: int,
model_name: str,
tier: str = "premium",
static_score: int = 50,
) -> dict:
return {
"id": cid,
"provider": "OPENROUTER",
"model_name": model_name,
"billing_tier": tier,
"auto_pin_tier": "B" if tier == "premium" else "C",
"quality_score_static": static_score,
"quality_score_health": None,
"quality_score": static_score,
"health_gated": False,
}
class _StubResponse:
def __init__(self, *, payload: dict, status_code: int = 200):
self._payload = payload
self.status_code = status_code
def raise_for_status(self) -> None:
if self.status_code >= 400:
raise RuntimeError(f"HTTP {self.status_code}")
def json(self) -> dict:
return self._payload
class _StubAsyncClient:
"""Minimal drop-in for ``httpx.AsyncClient`` used by ``_fetch_endpoints``."""
def __init__(self, responder):
self._responder = responder
self.requests: list[str] = []
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return False
async def get(self, url: str, headers: dict | None = None) -> _StubResponse:
self.requests.append(url)
return self._responder(url)
def _patch_async_client(monkeypatch, responder) -> _StubAsyncClient:
"""Replace ``httpx.AsyncClient`` for the duration of the test."""
client = _StubAsyncClient(responder)
monkeypatch.setattr(
"app.services.openrouter_integration_service.httpx.AsyncClient",
lambda *_args, **_kwargs: client,
)
return client
def _healthy_payload() -> dict:
return {
"data": {
"endpoints": [
{
"status": 0,
"uptime_last_30m": 0.99,
"uptime_last_1d": 0.995,
"uptime_last_5m": 0.99,
}
]
}
}
def _unhealthy_payload() -> dict:
return {
"data": {
"endpoints": [
{
"status": 0,
"uptime_last_30m": 0.55,
"uptime_last_1d": 0.62,
"uptime_last_5m": 0.50,
}
]
}
}
# ---------------------------------------------------------------------------
# Bounded fan-out + happy path
# ---------------------------------------------------------------------------
async def test_enrich_health_marks_healthy_and_gates_unhealthy(monkeypatch):
cfgs = [
_or_cfg(cid=-1, model_name="anthropic/claude-haiku", static_score=70),
_or_cfg(cid=-2, model_name="venice/dead-model", static_score=60),
]
def responder(url: str) -> _StubResponse:
if "anthropic" in url:
return _StubResponse(payload=_healthy_payload())
return _StubResponse(payload=_unhealthy_payload())
_patch_async_client(monkeypatch, responder)
service = OpenRouterIntegrationService()
service._settings = {"api_key": ""}
await service._enrich_health(cfgs)
healthy = next(c for c in cfgs if c["id"] == -1)
gated = next(c for c in cfgs if c["id"] == -2)
assert healthy["health_gated"] is False
assert healthy["quality_score_health"] is not None
assert healthy["quality_score"] >= healthy["quality_score_static"]
assert gated["health_gated"] is True
assert gated["quality_score"] == gated["quality_score_static"]
async def test_enrich_health_only_touches_or_provider(monkeypatch):
"""YAML cfgs that aren't OPENROUTER must be skipped entirely."""
yaml_cfg = {
"id": -1,
"provider": "AZURE_OPENAI",
"model_name": "gpt-5",
"billing_tier": "premium",
"auto_pin_tier": "A",
"quality_score_static": 80,
"quality_score": 80,
"health_gated": False,
}
or_cfg = _or_cfg(cid=-2, model_name="anthropic/claude-haiku")
requests: list[str] = []
def responder(url: str) -> _StubResponse:
requests.append(url)
return _StubResponse(payload=_healthy_payload())
_patch_async_client(monkeypatch, responder)
service = OpenRouterIntegrationService()
service._settings = {}
await service._enrich_health([yaml_cfg, or_cfg])
assert all("anthropic/claude-haiku" in r for r in requests)
# YAML cfg is untouched.
assert yaml_cfg["quality_score"] == 80
assert yaml_cfg["health_gated"] is False
# ---------------------------------------------------------------------------
# Failure ratio fallback
# ---------------------------------------------------------------------------
async def test_enrich_health_falls_back_to_last_good_when_failure_ratio_high(
monkeypatch,
):
"""If >= 25% of fetches fail, keep last-good cache instead of writing
partial data."""
cfgs = [
_or_cfg(cid=-1, model_name="anthropic/claude-haiku", static_score=70),
_or_cfg(cid=-2, model_name="openai/gpt-5", static_score=80),
_or_cfg(cid=-3, model_name="google/gemini-flash", static_score=65),
_or_cfg(cid=-4, model_name="venice/something", static_score=50),
]
service = OpenRouterIntegrationService()
service._settings = {}
# Pre-seed last-good cache with a known-healthy snapshot.
service._health_cache = {
"anthropic/claude-haiku": {"gated": False, "score": 95.0},
}
def all_fail(_url: str) -> _StubResponse:
return _StubResponse(payload={}, status_code=500)
_patch_async_client(monkeypatch, all_fail)
await service._enrich_health(cfgs)
# Above threshold ⇒ degraded; last-good cache wins for the cached cfg.
cached_hit = next(c for c in cfgs if c["model_name"] == "anthropic/claude-haiku")
assert cached_hit["quality_score_health"] == 95.0
assert cached_hit["health_gated"] is False
# Confirm the threshold constant we're testing against is real.
assert _HEALTH_FAIL_RATIO_FALLBACK <= 1.0
async def test_enrich_health_keeps_static_only_with_no_cache_and_failures(
monkeypatch,
):
"""If a fetch fails and there's no last-good cache, the cfg keeps its
static-only ``quality_score`` and is *not* gated by default."""
cfgs = [
_or_cfg(cid=-1, model_name="anthropic/claude-haiku", static_score=70),
]
def fail(_url: str) -> _StubResponse:
return _StubResponse(payload={}, status_code=500)
_patch_async_client(monkeypatch, fail)
service = OpenRouterIntegrationService()
service._settings = {}
await service._enrich_health(cfgs)
cfg = cfgs[0]
assert cfg["health_gated"] is False
assert cfg["quality_score"] == cfg["quality_score_static"]
assert cfg["quality_score_health"] is None
# ---------------------------------------------------------------------------
# Last-good cache: success populates, next failure reuses
# ---------------------------------------------------------------------------
async def test_enrich_health_populates_cache_on_success_then_reuses_on_failure(
monkeypatch,
):
cfg = _or_cfg(cid=-1, model_name="anthropic/claude-haiku", static_score=70)
service = OpenRouterIntegrationService()
service._settings = {}
def healthy(_url: str) -> _StubResponse:
return _StubResponse(payload=_healthy_payload())
_patch_async_client(monkeypatch, healthy)
await service._enrich_health([cfg])
assert "anthropic/claude-haiku" in service._health_cache
cached_score = service._health_cache["anthropic/claude-haiku"]["score"]
assert cached_score is not None
# Next cycle: enough other healthy cfgs so failure ratio stays below
# the 25% threshold even when this one fails individually.
other_cfgs = [
_or_cfg(cid=-2 - i, model_name=f"healthy/m-{i}", static_score=60)
for i in range(10)
]
cfg["quality_score_health"] = None
cfg["quality_score"] = cfg["quality_score_static"]
def mixed(url: str) -> _StubResponse:
if "anthropic" in url:
return _StubResponse(payload={}, status_code=500)
return _StubResponse(payload=_healthy_payload())
_patch_async_client(monkeypatch, mixed)
await service._enrich_health([cfg, *other_cfgs])
assert cfg["quality_score_health"] == cached_score
assert cfg["health_gated"] is False
# ---------------------------------------------------------------------------
# Bounded fan-out: respects top-N caps
# ---------------------------------------------------------------------------
async def test_enrich_health_bounds_premium_fanout(monkeypatch):
"""Top-N premium cap is honoured even when many cfgs are present."""
from app.services.quality_score import _HEALTH_ENRICH_TOP_N_PREMIUM
cfgs = [
_or_cfg(
cid=-i, model_name=f"openai/m-{i}", tier="premium", static_score=100 - i
)
for i in range(1, _HEALTH_ENRICH_TOP_N_PREMIUM + 20)
]
seen: list[str] = []
def responder(url: str) -> _StubResponse:
seen.append(url)
return _StubResponse(payload=_healthy_payload())
_patch_async_client(monkeypatch, responder)
service = OpenRouterIntegrationService()
service._settings = {}
await service._enrich_health(cfgs)
assert len(seen) == _HEALTH_ENRICH_TOP_N_PREMIUM
async def test_enrich_health_no_or_cfgs_is_noop(monkeypatch):
"""When the catalogue has no OR cfgs at all, no HTTP calls fire."""
yaml_cfg: dict[str, Any] = {
"id": -1,
"provider": "AZURE_OPENAI",
"model_name": "gpt-5",
"billing_tier": "premium",
}
requests: list[str] = []
def responder(url: str) -> _StubResponse:
requests.append(url)
return _StubResponse(payload=_healthy_payload())
_patch_async_client(monkeypatch, responder)
service = OpenRouterIntegrationService()
service._settings = {}
await service._enrich_health([yaml_cfg])
assert requests == []

View file

@ -0,0 +1,447 @@
"""Pricing registration unit tests.
The pricing-registration module is what makes ``response_cost`` populate
correctly for OpenRouter dynamic models and operator-defined Azure
deployments, both of which LiteLLM doesn't natively know about. The tests
exercise:
* The alias generators emit every shape that LiteLLM's cost-callback might
use (``openrouter/X`` and bare ``X``; YAML-defined ``base_model``,
``provider/base_model``, ``provider/model_name``, plus the special
``azure_openai`` ``azure`` normalisation).
* ``register_pricing_from_global_configs`` calls ``litellm.register_model``
with the right alias set and pricing values per provider.
* Configs without a resolvable pair of cost values are skipped, never
registered as zero, since that would override pricing LiteLLM might
already know natively.
"""
from __future__ import annotations
from typing import Any
import pytest
pytestmark = pytest.mark.unit
# ---------------------------------------------------------------------------
# Alias generators
# ---------------------------------------------------------------------------
def test_openrouter_alias_set_includes_prefixed_and_bare():
from app.services.pricing_registration import _alias_set_for_openrouter
aliases = _alias_set_for_openrouter("anthropic/claude-3-5-sonnet")
assert aliases == [
"openrouter/anthropic/claude-3-5-sonnet",
"anthropic/claude-3-5-sonnet",
]
def test_openrouter_alias_set_dedupes():
"""If the model id is already prefixed with ``openrouter/``, the alias
set must not contain duplicates that would re-register the same key
twice.
"""
from app.services.pricing_registration import _alias_set_for_openrouter
aliases = _alias_set_for_openrouter("openrouter/foo")
# The bare and prefixed variants compute to the same string here, so we
# at minimum require uniqueness.
assert len(aliases) == len(set(aliases))
def test_yaml_alias_set_for_azure_openai_normalises_to_azure():
"""``azure_openai`` (our YAML provider slug) must register under
``azure/<name>`` so the LiteLLM Router's deployment-resolution path
(which uses provider ``azure``) finds the pricing too.
"""
from app.services.pricing_registration import _alias_set_for_yaml
aliases = _alias_set_for_yaml(
provider="AZURE_OPENAI",
model_name="gpt-5.4",
base_model="gpt-5.4",
)
assert "gpt-5.4" in aliases
assert "azure_openai/gpt-5.4" in aliases
assert "azure/gpt-5.4" in aliases
def test_yaml_alias_set_distinguishes_model_name_and_base_model():
"""When ``model_name`` differs from ``base_model`` (operator labelled a
deployment), both must appear in the alias set since either may surface
in callbacks depending on the call path.
"""
from app.services.pricing_registration import _alias_set_for_yaml
aliases = _alias_set_for_yaml(
provider="OPENAI",
model_name="my-deployment-label",
base_model="gpt-4o",
)
assert "gpt-4o" in aliases
assert "openai/gpt-4o" in aliases
assert "my-deployment-label" in aliases
assert "openai/my-deployment-label" in aliases
def test_yaml_alias_set_omits_provider_prefix_when_provider_blank():
from app.services.pricing_registration import _alias_set_for_yaml
aliases = _alias_set_for_yaml(
provider="",
model_name="foo",
base_model="bar",
)
assert "bar" in aliases
assert "foo" in aliases
assert all("/" not in a for a in aliases)
# ---------------------------------------------------------------------------
# register_pricing_from_global_configs
# ---------------------------------------------------------------------------
class _RegistrationSpy:
"""Captures the dicts passed to ``litellm.register_model``.
Many calls may go through; we just record them all and let tests assert
against the union.
"""
def __init__(self) -> None:
self.calls: list[dict[str, Any]] = []
def __call__(self, payload: dict[str, Any]) -> None:
self.calls.append(payload)
@property
def all_keys(self) -> set[str]:
keys: set[str] = set()
for payload in self.calls:
keys.update(payload.keys())
return keys
def _patch_register(monkeypatch: pytest.MonkeyPatch) -> _RegistrationSpy:
spy = _RegistrationSpy()
monkeypatch.setattr(
"app.services.pricing_registration.litellm.register_model",
spy,
raising=False,
)
return spy
def _patch_openrouter_pricing(
monkeypatch: pytest.MonkeyPatch, mapping: dict[str, dict[str, str]]
) -> None:
"""Pretend the OpenRouter integration is initialised with ``mapping``."""
class _Stub:
def get_raw_pricing(self) -> dict[str, dict[str, str]]:
return mapping
class _StubService:
@classmethod
def is_initialized(cls) -> bool:
return True
@classmethod
def get_instance(cls) -> _Stub:
return _Stub()
monkeypatch.setattr(
"app.services.openrouter_integration_service.OpenRouterIntegrationService",
_StubService,
raising=False,
)
def test_openrouter_models_register_under_aliases(monkeypatch):
"""An OpenRouter config whose ``model_name`` is in the cached raw
pricing map is registered under both ``openrouter/X`` and bare ``X``.
"""
from app.config import config
from app.services.pricing_registration import register_pricing_from_global_configs
spy = _patch_register(monkeypatch)
_patch_openrouter_pricing(
monkeypatch,
{
"anthropic/claude-3-5-sonnet": {
"prompt": "0.000003",
"completion": "0.000015",
}
},
)
monkeypatch.setattr(
config,
"GLOBAL_LLM_CONFIGS",
[
{
"id": 1,
"provider": "OPENROUTER",
"model_name": "anthropic/claude-3-5-sonnet",
}
],
)
register_pricing_from_global_configs()
assert "openrouter/anthropic/claude-3-5-sonnet" in spy.all_keys
assert "anthropic/claude-3-5-sonnet" in spy.all_keys
# Costs are float-converted from the raw OpenRouter strings.
payload = spy.calls[0]
assert payload["openrouter/anthropic/claude-3-5-sonnet"][
"input_cost_per_token"
] == pytest.approx(3e-6)
assert payload["openrouter/anthropic/claude-3-5-sonnet"][
"output_cost_per_token"
] == pytest.approx(15e-6)
assert (
payload["openrouter/anthropic/claude-3-5-sonnet"]["litellm_provider"]
== "openrouter"
)
def test_yaml_override_registers_under_alias_set(monkeypatch):
"""Operator-declared ``input_cost_per_token`` /
``output_cost_per_token`` on a YAML config registers under every
alias the YAML alias generator produces, including the ``azure/``
normalisation for ``azure_openai`` providers.
"""
from app.config import config
from app.services.pricing_registration import register_pricing_from_global_configs
spy = _patch_register(monkeypatch)
_patch_openrouter_pricing(monkeypatch, {})
monkeypatch.setattr(
config,
"GLOBAL_LLM_CONFIGS",
[
{
"id": 1,
"provider": "AZURE_OPENAI",
"model_name": "gpt-5.4",
"litellm_params": {
"base_model": "gpt-5.4",
"input_cost_per_token": 2e-6,
"output_cost_per_token": 8e-6,
},
}
],
)
register_pricing_from_global_configs()
keys = spy.all_keys
assert "gpt-5.4" in keys
assert "azure_openai/gpt-5.4" in keys
assert "azure/gpt-5.4" in keys
payload = spy.calls[0]
entry = payload["gpt-5.4"]
assert entry["input_cost_per_token"] == pytest.approx(2e-6)
assert entry["output_cost_per_token"] == pytest.approx(8e-6)
assert entry["litellm_provider"] == "azure"
def test_no_override_means_no_registration(monkeypatch):
"""A YAML config that *omits* both pricing fields must NOT be registered
registering as zero would override LiteLLM's native pricing for the
``base_model`` key (e.g. ``gpt-4o``) and silently make every user's
bill drop to $0. Fail-safe is "skip and warn", not "register zero".
"""
from app.config import config
from app.services.pricing_registration import register_pricing_from_global_configs
spy = _patch_register(monkeypatch)
_patch_openrouter_pricing(monkeypatch, {})
monkeypatch.setattr(
config,
"GLOBAL_LLM_CONFIGS",
[
{
"id": 1,
"provider": "OPENAI",
"model_name": "gpt-4o",
"litellm_params": {"base_model": "gpt-4o"},
}
],
)
register_pricing_from_global_configs()
assert spy.calls == []
def test_openrouter_skipped_when_pricing_missing(monkeypatch):
"""If the OpenRouter raw-pricing cache doesn't carry an entry for a
configured model (network blip during refresh, model added later, etc.),
we skip it rather than registering zero pricing.
"""
from app.config import config
from app.services.pricing_registration import register_pricing_from_global_configs
spy = _patch_register(monkeypatch)
_patch_openrouter_pricing(
monkeypatch, {"some/other-model": {"prompt": "1", "completion": "1"}}
)
monkeypatch.setattr(
config,
"GLOBAL_LLM_CONFIGS",
[
{
"id": 1,
"provider": "OPENROUTER",
"model_name": "anthropic/claude-3-5-sonnet",
}
],
)
register_pricing_from_global_configs()
assert spy.calls == []
def test_register_continues_after_individual_failure(monkeypatch, caplog):
"""A single bad ``register_model`` call (e.g. raising LiteLLM error)
must not abort registration of the remaining configs.
"""
from app.config import config
from app.services.pricing_registration import register_pricing_from_global_configs
failing_keys: set[str] = {"anthropic/claude-3-5-sonnet"}
successful_calls: list[dict[str, Any]] = []
def _maybe_fail(payload: dict[str, Any]) -> None:
if any(k in failing_keys for k in payload):
raise RuntimeError("boom")
successful_calls.append(payload)
monkeypatch.setattr(
"app.services.pricing_registration.litellm.register_model",
_maybe_fail,
raising=False,
)
_patch_openrouter_pricing(
monkeypatch,
{
"anthropic/claude-3-5-sonnet": {
"prompt": "0.000003",
"completion": "0.000015",
}
},
)
monkeypatch.setattr(
config,
"GLOBAL_LLM_CONFIGS",
[
{
"id": 1,
"provider": "OPENROUTER",
"model_name": "anthropic/claude-3-5-sonnet",
},
{
"id": 2,
"provider": "OPENAI",
"model_name": "custom-deployment",
"litellm_params": {
"base_model": "custom-deployment",
"input_cost_per_token": 1e-6,
"output_cost_per_token": 2e-6,
},
},
],
)
register_pricing_from_global_configs()
# The good config still registered.
assert any("custom-deployment" in payload for payload in successful_calls)
def test_vision_configs_registered_with_chat_shape(monkeypatch):
"""``register_pricing_from_global_configs`` walks
``GLOBAL_VISION_LLM_CONFIGS`` in addition to the chat configs so vision
calls (during indexing) bill correctly. Vision configs use the same
chat-shape token prices, but image-gen pricing is intentionally NOT
registered here (handled via ``response_cost`` in LiteLLM).
"""
from app.config import config
from app.services.pricing_registration import register_pricing_from_global_configs
spy = _patch_register(monkeypatch)
_patch_openrouter_pricing(
monkeypatch,
{"openai/gpt-4o": {"prompt": "0.000005", "completion": "0.000015"}},
)
# No chat configs — only vision. Proves the vision walk is a separate
# iteration, not piggy-backed on the chat list.
monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [])
monkeypatch.setattr(
config,
"GLOBAL_VISION_LLM_CONFIGS",
[
{
"id": -1,
"provider": "OPENROUTER",
"model_name": "openai/gpt-4o",
"billing_tier": "premium",
"input_cost_per_token": 5e-6,
"output_cost_per_token": 15e-6,
}
],
)
register_pricing_from_global_configs()
assert "openrouter/openai/gpt-4o" in spy.all_keys
payload_value = spy.calls[0]["openrouter/openai/gpt-4o"]
assert payload_value["mode"] == "chat"
assert payload_value["litellm_provider"] == "openrouter"
assert payload_value["input_cost_per_token"] == pytest.approx(5e-6)
assert payload_value["output_cost_per_token"] == pytest.approx(15e-6)
def test_vision_with_inline_pricing_when_or_cache_missing(monkeypatch):
"""If the OpenRouter pricing cache misses a vision model (different
catalogue surface), the vision walk falls back to inline
``input_cost_per_token``/``output_cost_per_token`` on the cfg itself.
"""
from app.config import config
from app.services.pricing_registration import register_pricing_from_global_configs
spy = _patch_register(monkeypatch)
_patch_openrouter_pricing(monkeypatch, {})
monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [])
monkeypatch.setattr(
config,
"GLOBAL_VISION_LLM_CONFIGS",
[
{
"id": -1,
"provider": "OPENROUTER",
"model_name": "google/gemini-2.5-flash",
"billing_tier": "premium",
"input_cost_per_token": 1e-6,
"output_cost_per_token": 4e-6,
}
],
)
register_pricing_from_global_configs()
assert "openrouter/google/gemini-2.5-flash" in spy.all_keys

View file

@ -0,0 +1,107 @@
"""Unit tests for the shared ``api_base`` resolver.
The cascade exists so vision and image-gen call sites can't silently
inherit ``litellm.api_base`` (commonly set by ``AZURE_OPENAI_ENDPOINT``)
when an OpenRouter / Groq / etc. config ships an empty string. See
``provider_api_base`` module docstring for the original repro
(OpenRouter image-gen 404-ing against an Azure endpoint).
"""
from __future__ import annotations
import pytest
from app.services.provider_api_base import (
PROVIDER_DEFAULT_API_BASE,
PROVIDER_KEY_DEFAULT_API_BASE,
resolve_api_base,
)
pytestmark = pytest.mark.unit
def test_config_value_wins_over_defaults():
"""A non-empty config value is always returned verbatim, even when the
provider has a default; the operator gets the last word."""
result = resolve_api_base(
provider="OPENROUTER",
provider_prefix="openrouter",
config_api_base="https://my-openrouter-mirror.example.com/v1",
)
assert result == "https://my-openrouter-mirror.example.com/v1"
def test_provider_key_default_when_config_missing():
"""``DEEPSEEK`` shares the ``openai`` LiteLLM prefix but has its own
base URL; the provider-key map must take precedence over the prefix
map so DeepSeek requests don't go to OpenAI."""
result = resolve_api_base(
provider="DEEPSEEK",
provider_prefix="openai",
config_api_base=None,
)
assert result == PROVIDER_KEY_DEFAULT_API_BASE["DEEPSEEK"]
def test_provider_prefix_default_when_no_key_default():
result = resolve_api_base(
provider="OPENROUTER",
provider_prefix="openrouter",
config_api_base=None,
)
assert result == PROVIDER_DEFAULT_API_BASE["openrouter"]
def test_unknown_provider_returns_none():
"""When neither map matches we return ``None`` so the caller can let
LiteLLM apply its own provider-integration default (Azure deployment
URL, custom-provider URL, etc.)."""
result = resolve_api_base(
provider="SOMETHING_NEW",
provider_prefix="something_new",
config_api_base=None,
)
assert result is None
def test_empty_string_config_treated_as_missing():
"""The original bug: OpenRouter dynamic configs ship ``api_base=""``
and downstream call sites use ``if cfg.get("api_base"):``; empty
strings are falsy in Python, so the cascade has to step in anyway."""
result = resolve_api_base(
provider="OPENROUTER",
provider_prefix="openrouter",
config_api_base="",
)
assert result == PROVIDER_DEFAULT_API_BASE["openrouter"]
def test_whitespace_only_config_treated_as_missing():
"""A config value of ``" "`` is a configuration mistake — treat it
as missing instead of forwarding whitespace to LiteLLM (which would
almost certainly 404)."""
result = resolve_api_base(
provider="OPENROUTER",
provider_prefix="openrouter",
config_api_base=" ",
)
assert result == PROVIDER_DEFAULT_API_BASE["openrouter"]
def test_provider_case_insensitive():
"""Some call sites pass the provider lowercase (DB enum value), others
uppercase (YAML key). Both must resolve."""
upper = resolve_api_base(
provider="DEEPSEEK", provider_prefix="openai", config_api_base=None
)
lower = resolve_api_base(
provider="deepseek", provider_prefix="openai", config_api_base=None
)
assert upper == lower == PROVIDER_KEY_DEFAULT_API_BASE["DEEPSEEK"]
def test_all_inputs_none_returns_none():
assert (
resolve_api_base(provider=None, provider_prefix=None, config_api_base=None)
is None
)
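# A minimal reference sketch of the cascade the tests above describe (an
# assumption, not the production ``resolve_api_base``): a non-blank config
# value first, then the provider-key default, then the provider-prefix
# default, otherwise ``None`` so LiteLLM can apply its own default.
def _resolve_api_base_sketch(
    provider: str | None,
    provider_prefix: str | None,
    config_api_base: str | None,
) -> str | None:
    if config_api_base and config_api_base.strip():
        return config_api_base
    if provider and provider.upper() in PROVIDER_KEY_DEFAULT_API_BASE:
        return PROVIDER_KEY_DEFAULT_API_BASE[provider.upper()]
    if provider_prefix and provider_prefix.lower() in PROVIDER_DEFAULT_API_BASE:
        return PROVIDER_DEFAULT_API_BASE[provider_prefix.lower()]
    return None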

View file

@ -0,0 +1,244 @@
"""Unit tests for the shared chat-image capability resolver.
Two resolvers, two intents:
- ``derive_supports_image_input``: best-effort True for the catalog and
selector. Default-allow on unknown / unmapped models. The streaming
task safety net never sees this value directly.
- ``is_known_text_only_chat_model``: strict opt-out for the safety net.
Returns True only when LiteLLM's model map *explicitly* sets
``supports_vision=False``. Anything else (missing key, exception,
True) returns False so the request flows through to the provider.
"""
from __future__ import annotations
import pytest
from app.services.provider_capabilities import (
derive_supports_image_input,
is_known_text_only_chat_model,
)
pytestmark = pytest.mark.unit
# ---------------------------------------------------------------------------
# derive_supports_image_input — OpenRouter modalities path (authoritative)
# ---------------------------------------------------------------------------
def test_or_modalities_with_image_returns_true():
assert (
derive_supports_image_input(
provider="OPENROUTER",
model_name="openai/gpt-4o",
openrouter_input_modalities=["text", "image"],
)
is True
)
def test_or_modalities_text_only_returns_false():
assert (
derive_supports_image_input(
provider="OPENROUTER",
model_name="deepseek/deepseek-v3.2-exp",
openrouter_input_modalities=["text"],
)
is False
)
def test_or_modalities_empty_list_returns_false():
"""OR explicitly publishing an empty modality list is a definitive
'no inputs at all' signal; treat as False rather than falling back
to LiteLLM."""
assert (
derive_supports_image_input(
provider="OPENROUTER",
model_name="weird/empty-modalities",
openrouter_input_modalities=[],
)
is False
)
def test_or_modalities_none_falls_through_to_litellm():
"""``None`` (missing key) is *not* a definitive signal — fall through
to LiteLLM. Using ``openai/gpt-4o``, which is in LiteLLM's map."""
assert (
derive_supports_image_input(
provider="OPENAI",
model_name="gpt-4o",
openrouter_input_modalities=None,
)
is True
)
# ---------------------------------------------------------------------------
# derive_supports_image_input — LiteLLM model-map path
# ---------------------------------------------------------------------------
def test_litellm_known_vision_model_returns_true():
assert (
derive_supports_image_input(
provider="OPENAI",
model_name="gpt-4o",
)
is True
)
def test_litellm_base_model_wins_over_model_name():
"""Azure-style entries pass model_name=deployment_id and put the
canonical sku in litellm_params.base_model. The resolver must
consult base_model first or the deployment id (which LiteLLM
doesn't know) would shadow the real capability."""
assert (
derive_supports_image_input(
provider="AZURE_OPENAI",
model_name="my-azure-deployment-id",
base_model="gpt-4o",
)
is True
)
def test_litellm_unknown_model_default_allows():
"""Default-allow on unknown — the safety net is the actual block."""
assert (
derive_supports_image_input(
provider="CUSTOM",
model_name="brand-new-model-x9-unmapped",
custom_provider="brand_new_proxy",
)
is True
)
def test_litellm_known_text_only_returns_false():
"""A model that LiteLLM explicitly knows is text-only resolves to
False even via the catalog resolver. ``deepseek-chat`` (the
DeepSeek-V3 chat sku) is in the map without supports_vision and
LiteLLM's `supports_vision` returns False."""
# Sanity: confirm the helper's negative path. We use a small model
# known not to support vision per the map.
result = derive_supports_image_input(
provider="DEEPSEEK",
model_name="deepseek-chat",
)
# We accept either False (LiteLLM said explicit no) or True
# (default-allow if the entry isn't mapped on this version) — the
# invariant is that the resolver never *raises* on a known-text-only
# provider/model. The behaviour-binding assertion lives in
# ``test_is_known_text_only_chat_model_explicit_false`` below.
assert isinstance(result, bool)
# ---------------------------------------------------------------------------
# is_known_text_only_chat_model — strict opt-out semantics
# ---------------------------------------------------------------------------
def test_is_known_text_only_returns_false_for_vision_model():
assert (
is_known_text_only_chat_model(
provider="OPENAI",
model_name="gpt-4o",
)
is False
)
def test_is_known_text_only_returns_false_for_unknown_model():
"""Strict opt-out: missing from the map ≠ text-only. The safety net
must NOT fire for an unmapped model; that's the regression we're
fixing."""
assert (
is_known_text_only_chat_model(
provider="CUSTOM",
model_name="brand-new-model-x9-unmapped",
custom_provider="brand_new_proxy",
)
is False
)
def test_is_known_text_only_returns_false_when_lookup_raises(monkeypatch):
"""LiteLLM's ``get_model_info`` raises freely on parse errors. The
helper swallows the exception and returns False so the safety net
doesn't fire on a transient lookup failure."""
import app.services.provider_capabilities as pc
def _raise(**_kwargs):
raise ValueError("intentional test failure")
monkeypatch.setattr(pc.litellm, "get_model_info", _raise)
assert (
is_known_text_only_chat_model(
provider="OPENAI",
model_name="gpt-4o",
)
is False
)
def test_is_known_text_only_returns_true_on_explicit_false(monkeypatch):
"""Stub LiteLLM's ``get_model_info`` to return an explicit False so
we exercise the opt-out path deterministically. Using a stub keeps
the test stable across LiteLLM map updates."""
import app.services.provider_capabilities as pc
def _info(**_kwargs):
return {"supports_vision": False, "max_input_tokens": 8192}
monkeypatch.setattr(pc.litellm, "get_model_info", _info)
assert (
is_known_text_only_chat_model(
provider="OPENAI",
model_name="any-model",
)
is True
)
def test_is_known_text_only_returns_false_on_supports_vision_true(monkeypatch):
import app.services.provider_capabilities as pc
def _info(**_kwargs):
return {"supports_vision": True}
monkeypatch.setattr(pc.litellm, "get_model_info", _info)
assert (
is_known_text_only_chat_model(
provider="OPENAI",
model_name="any-model",
)
is False
)
def test_is_known_text_only_returns_false_on_missing_key(monkeypatch):
"""A model entry without ``supports_vision`` at all is treated as
'unknown'; strict opt-out means False."""
import app.services.provider_capabilities as pc
def _info(**_kwargs):
return {"max_input_tokens": 8192} # no supports_vision
monkeypatch.setattr(pc.litellm, "get_model_info", _info)
assert (
is_known_text_only_chat_model(
provider="OPENAI",
model_name="any-model",
)
is False
)
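# A minimal sketch of the strict opt-out semantics pinned above (assumed
# logic, not the production helper): only an explicit ``supports_vision=False``
# from LiteLLM's model map returns True; missing keys and lookup failures
# never block the request.
def _is_known_text_only_sketch(model_name: str) -> bool:
    import litellm
    try:
        info = litellm.get_model_info(model=model_name)
    except Exception:
        return False
    return info.get("supports_vision") is False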

View file

@ -0,0 +1,345 @@
"""Unit tests for the Auto (Fastest) quality scoring module."""
from __future__ import annotations
import time
import pytest
from app.services.quality_score import (
_HEALTH_GATE_UPTIME_PCT,
_OPERATOR_TRUST_BONUS,
aggregate_health,
capabilities_signal,
context_signal,
created_recency_signal,
pricing_band,
slug_penalty,
static_score_or,
static_score_yaml,
)
pytestmark = pytest.mark.unit
# ---------------------------------------------------------------------------
# created_recency_signal
# ---------------------------------------------------------------------------
def test_created_recency_signal_recent_model_scores_high():
now = 1_750_000_000 # ~mid-2025
one_month_ago = now - (30 * 86_400)
assert created_recency_signal(one_month_ago, now) == 20
def test_created_recency_signal_old_model_scores_zero():
now = 1_750_000_000
five_years_ago = now - (5 * 365 * 86_400)
assert created_recency_signal(five_years_ago, now) == 0
def test_created_recency_signal_missing_timestamp_is_neutral():
now = 1_750_000_000
assert created_recency_signal(None, now) == 0
assert created_recency_signal(0, now) == 0
def test_created_recency_signal_monotonic_decay():
now = 1_750_000_000
scores = [
created_recency_signal(now - days * 86_400, now)
for days in (30, 120, 300, 500, 700, 1000, 1500)
]
assert scores == sorted(scores, reverse=True)
# ---------------------------------------------------------------------------
# pricing_band
# ---------------------------------------------------------------------------
def test_pricing_band_free_returns_zero():
assert pricing_band("0", "0") == 0
assert pricing_band(0.0, 0.0) == 0
assert pricing_band(None, None) == 0
def test_pricing_band_handles_unparseable():
assert pricing_band("not-a-number", "0") == 0
assert pricing_band({}, []) == 0 # type: ignore[arg-type]
def test_pricing_band_premium_tiers_increase_with_price():
cheap = pricing_band("0.0000003", "0.0000005")
mid = pricing_band("0.000003", "0.000015")
flagship = pricing_band("0.00001", "0.00005")
assert 0 < cheap < mid < flagship
# ---------------------------------------------------------------------------
# context_signal
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
"ctx,expected",
[
(1_500_000, 10),
(1_000_000, 10),
(500_000, 8),
(200_000, 6),
(128_000, 4),
(100_000, 2),
(50_000, 0),
(0, 0),
(None, 0),
],
)
def test_context_signal_bands(ctx, expected):
assert context_signal(ctx) == expected
# ---------------------------------------------------------------------------
# capabilities_signal
# ---------------------------------------------------------------------------
def test_capabilities_signal_caps_at_five():
assert (
capabilities_signal(
["tools", "structured_outputs", "reasoning", "include_reasoning"]
)
<= 5
)
def test_capabilities_signal_tools_only():
assert capabilities_signal(["tools"]) == 2
def test_capabilities_signal_empty():
assert capabilities_signal(None) == 0
assert capabilities_signal([]) == 0
# ---------------------------------------------------------------------------
# slug_penalty
# ---------------------------------------------------------------------------
def test_slug_penalty_demotes_tiny_models():
assert slug_penalty("meta-llama/llama-3.2-1b-instruct") < 0
assert slug_penalty("liquid/lfm-7b") < 0
assert slug_penalty("google/gemma-3n-e4b-it") < 0
def test_slug_penalty_skips_capable_mini_nano_lite_models():
"""Critical Option C+ regression: don't penalise modern frontier
models named ``-nano`` / ``-mini`` / ``-lite`` (gpt-5-mini, etc.)."""
assert slug_penalty("openai/gpt-5-mini") == 0
assert slug_penalty("openai/gpt-5-nano") == 0
assert slug_penalty("google/gemini-2.5-flash-lite") == 0
assert slug_penalty("anthropic/claude-haiku-4.5") == 0
def test_slug_penalty_demotes_legacy_variants():
assert slug_penalty("openai/o1-preview") < 0
assert slug_penalty("foo/bar-base") < 0
assert slug_penalty("foo/bar-distill") < 0
def test_slug_penalty_empty_input():
assert slug_penalty("") == 0
# ---------------------------------------------------------------------------
# static_score_or
# ---------------------------------------------------------------------------
def _or_model(
*,
model_id: str,
created: int | None = None,
prompt: str = "0.000003",
completion: str = "0.000015",
context: int = 200_000,
params: list[str] | None = None,
) -> dict:
return {
"id": model_id,
"created": created,
"pricing": {"prompt": prompt, "completion": completion},
"context_length": context,
"supported_parameters": params if params is not None else ["tools"],
}
def test_static_score_or_frontier_premium_beats_free_tiny():
now = 1_750_000_000
frontier = _or_model(
model_id="openai/gpt-5",
created=now - (60 * 86_400),
prompt="0.000005",
completion="0.000020",
context=400_000,
params=["tools", "structured_outputs", "reasoning"],
)
tiny_free = _or_model(
model_id="meta-llama/llama-3.2-1b-instruct:free",
created=now - (5 * 365 * 86_400),
prompt="0",
completion="0",
context=128_000,
params=["tools"],
)
assert static_score_or(frontier, now_ts=now) > static_score_or(
tiny_free, now_ts=now
)
def test_static_score_or_score_is_clamped_0_to_100():
now = int(time.time())
score = static_score_or(_or_model(model_id="openai/gpt-4o"), now_ts=now)
assert 0 <= score <= 100
def test_static_score_or_unknown_provider_is_neutral_not_zero():
now = int(time.time())
score = static_score_or(
_or_model(model_id="some-new-lab/some-model"),
now_ts=now,
)
assert score > 0
def test_static_score_or_recent_release_beats_year_old_same_provider():
now = 1_750_000_000
fresh = _or_model(model_id="openai/gpt-5", created=now - (60 * 86_400))
old = _or_model(model_id="openai/gpt-4-turbo", created=now - (700 * 86_400))
assert static_score_or(fresh, now_ts=now) > static_score_or(old, now_ts=now)
# ---------------------------------------------------------------------------
# static_score_yaml
# ---------------------------------------------------------------------------
def test_static_score_yaml_includes_operator_bonus():
cfg = {
"provider": "AZURE_OPENAI",
"model_name": "gpt-5",
"litellm_params": {"base_model": "azure/gpt-5"},
}
score = static_score_yaml(cfg)
assert score >= _OPERATOR_TRUST_BONUS
def test_static_score_yaml_unknown_provider_still_carries_bonus():
cfg = {
"provider": "SOME_NEW_PROVIDER",
"model_name": "weird-model",
}
score = static_score_yaml(cfg)
assert score >= _OPERATOR_TRUST_BONUS
def test_static_score_yaml_clamped_0_to_100():
cfg = {
"provider": "AZURE_OPENAI",
"model_name": "gpt-5",
"litellm_params": {"base_model": "azure/gpt-5"},
}
assert 0 <= static_score_yaml(cfg) <= 100
# ---------------------------------------------------------------------------
# aggregate_health
# ---------------------------------------------------------------------------
def test_aggregate_health_gates_when_uptime_below_threshold():
"""Live data showed Venice-routed cfgs at 53-68%; this guards that the
90% gate excludes them."""
venice_endpoints = [
{
"status": 0,
"uptime_last_30m": 0.55,
"uptime_last_1d": 0.60,
"uptime_last_5m": 0.50,
},
{
"status": 0,
"uptime_last_30m": 0.65,
"uptime_last_1d": 0.68,
"uptime_last_5m": 0.62,
},
]
gated, score = aggregate_health(venice_endpoints)
assert gated is True
assert score is None
def test_aggregate_health_passes_for_healthy_provider():
healthy = [
{
"status": 0,
"uptime_last_30m": 0.99,
"uptime_last_1d": 0.995,
"uptime_last_5m": 0.99,
},
]
gated, score = aggregate_health(healthy)
assert gated is False
assert score is not None
assert score >= _HEALTH_GATE_UPTIME_PCT
def test_aggregate_health_picks_best_endpoint_across_multiple():
"""Multi-endpoint aggregation should reward the best non-null uptime."""
mixed = [
{"status": 0, "uptime_last_30m": 0.55},
{"status": 0, "uptime_last_30m": 0.97}, # this one passes the gate
]
gated, score = aggregate_health(mixed)
assert gated is False
assert score is not None
def test_aggregate_health_empty_endpoints_gated():
gated, score = aggregate_health([])
assert gated is True
assert score is None
def test_aggregate_health_no_status_zero_gated():
"""Even with high uptime, no OK status means the cfg is broken upstream."""
endpoints = [
{"status": 1, "uptime_last_30m": 0.99},
{"status": 2, "uptime_last_30m": 0.98},
]
gated, score = aggregate_health(endpoints)
assert gated is True
assert score is None
def test_aggregate_health_all_uptime_null_gated():
endpoints = [
{"status": 0, "uptime_last_30m": None, "uptime_last_1d": None},
]
gated, score = aggregate_health(endpoints)
assert gated is True
assert score is None
def test_aggregate_health_pct_normalisation():
"""OpenRouter returns 0-1 fractions; some endpoints surface 0-100%
percentages. Both should reach the same gate decision."""
fraction_form = [{"status": 0, "uptime_last_30m": 0.95}]
pct_form = [{"status": 0, "uptime_last_30m": 95.0}]
g1, s1 = aggregate_health(fraction_form)
g2, s2 = aggregate_health(pct_form)
assert g1 == g2 == False # noqa: E712
assert s1 is not None and s2 is not None
assert abs(s1 - s2) < 0.5
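# A minimal sketch of the normalisation exercised above (an assumption, not
# the production code): uptime values at or below 1.0 are read as fractions
# and scaled to percent; larger values are assumed to already be percentages.
def _normalise_uptime_pct_sketch(value: float) -> float:
    return value * 100.0 if value <= 1.0 else value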

View file

@ -0,0 +1,157 @@
"""Unit tests for ``QuotaCheckedVisionLLM``.
Validates that:
* Calling ``ainvoke`` routes through ``billable_call`` (premium credit
enforcement) and forwards the inner LLM's response on success.
* The wrapper proxies non-overridden attributes to the inner LLM
(``__getattr__``) so ``invoke`` / ``astream`` / ``with_structured_output``
still work without quota gating (they're not used in indexing today).
* When ``billable_call`` raises ``QuotaInsufficientError`` the wrapper
bubbles it up; the ETL pipeline catches that and falls back to OCR.
"""
from __future__ import annotations
import contextlib
from typing import Any
from uuid import uuid4
import pytest
pytestmark = pytest.mark.unit
class _FakeInnerLLM:
"""Stand-in for ``langchain_litellm.ChatLiteLLM``."""
def __init__(self, response: Any = "OCR'd content") -> None:
self._response = response
self.ainvoke_calls: list[Any] = []
async def ainvoke(self, input: Any, *args: Any, **kwargs: Any) -> Any:
self.ainvoke_calls.append(input)
return self._response
def some_other_method(self, x: int) -> int:
return x * 2
@contextlib.asynccontextmanager
async def _passthrough_billable_call(**_kwargs):
"""Stand-in for billable_call that always allows the call to run."""
class _Acc:
total_cost_micros = 0
total_prompt_tokens = 0
total_completion_tokens = 0
grand_total = 0
calls: list[Any] = []
def per_message_summary(self) -> dict[str, dict[str, int]]:
return {}
yield _Acc()
@pytest.mark.asyncio
async def test_ainvoke_routes_through_billable_call(monkeypatch):
from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM
captured_kwargs: list[dict[str, Any]] = []
@contextlib.asynccontextmanager
async def _spy_billable_call(**kwargs):
captured_kwargs.append(kwargs)
async with _passthrough_billable_call() as acc:
yield acc
monkeypatch.setattr(
"app.services.quota_checked_vision_llm.billable_call",
_spy_billable_call,
raising=False,
)
inner = _FakeInnerLLM(response="A red apple on a white table")
user_id = uuid4()
wrapper = QuotaCheckedVisionLLM(
inner,
user_id=user_id,
search_space_id=99,
billing_tier="premium",
base_model="openai/gpt-4o",
quota_reserve_tokens=4000,
)
result = await wrapper.ainvoke([{"text": "what is this?"}])
assert result == "A red apple on a white table"
assert len(inner.ainvoke_calls) == 1
assert len(captured_kwargs) == 1
bc_kwargs = captured_kwargs[0]
assert bc_kwargs["user_id"] == user_id
assert bc_kwargs["search_space_id"] == 99
assert bc_kwargs["billing_tier"] == "premium"
assert bc_kwargs["base_model"] == "openai/gpt-4o"
assert bc_kwargs["quota_reserve_tokens"] == 4000
assert bc_kwargs["usage_type"] == "vision_extraction"
@pytest.mark.asyncio
async def test_ainvoke_propagates_quota_insufficient_error(monkeypatch):
from app.services.billable_calls import QuotaInsufficientError
from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM
@contextlib.asynccontextmanager
async def _denying_billable_call(**_kwargs):
raise QuotaInsufficientError(
usage_type="vision_extraction",
used_micros=5_000_000,
limit_micros=5_000_000,
remaining_micros=0,
)
yield # unreachable but required for asynccontextmanager type
monkeypatch.setattr(
"app.services.quota_checked_vision_llm.billable_call",
_denying_billable_call,
raising=False,
)
inner = _FakeInnerLLM()
wrapper = QuotaCheckedVisionLLM(
inner,
user_id=uuid4(),
search_space_id=1,
billing_tier="premium",
base_model="openai/gpt-4o",
quota_reserve_tokens=4000,
)
with pytest.raises(QuotaInsufficientError):
await wrapper.ainvoke([{"text": "x"}])
# Inner LLM never ran on a denied reservation.
assert inner.ainvoke_calls == []
@pytest.mark.asyncio
async def test_proxies_non_overridden_attributes_to_inner():
"""``__getattr__`` forwards anything not on the proxy itself, so any
method we didn't explicitly override (``invoke``, ``astream``,
``with_structured_output``, etc.) still works, just without quota
gating, which is fine because the indexer only ever calls ``ainvoke``.
"""
from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM
inner = _FakeInnerLLM()
wrapper = QuotaCheckedVisionLLM(
inner,
user_id=uuid4(),
search_space_id=1,
billing_tier="premium",
base_model="openai/gpt-4o",
quota_reserve_tokens=4000,
)
# ``some_other_method`` is on the inner only.
assert wrapper.some_other_method(7) == 14

View file

@ -0,0 +1,281 @@
"""Unit tests for the chat-catalog ``supports_image_input`` capability flag.
Capability is sourced from two places, in order of preference:
1. ``architecture.input_modalities`` for dynamic OpenRouter chat configs
(authoritative: OpenRouter publishes per-model modalities directly).
2. LiteLLM's authoritative model map (``litellm.supports_vision``) for
YAML / BYOK configs that don't carry an explicit operator override.
The catalog default is *True* (conservative-allow): an unknown / unmapped
model is not pre-judged. The streaming-task safety net
(``is_known_text_only_chat_model``) is the only place a False actually
blocks a request, and it requires LiteLLM to *explicitly* mark the model
as text-only.
"""
from __future__ import annotations
import pytest
from app.services.openrouter_integration_service import (
_OPENROUTER_DYNAMIC_MARKER,
_generate_configs,
_supports_image_input,
)
pytestmark = pytest.mark.unit
_SETTINGS_BASE: dict = {
"api_key": "sk-or-test",
"id_offset": -10_000,
"rpm": 200,
"tpm": 1_000_000,
"free_rpm": 20,
"free_tpm": 100_000,
"anonymous_enabled_paid": False,
"anonymous_enabled_free": True,
"quota_reserve_tokens": 4000,
}
# ---------------------------------------------------------------------------
# _supports_image_input helper (OpenRouter modalities)
# ---------------------------------------------------------------------------
def test_supports_image_input_true_for_multimodal():
assert (
_supports_image_input(
{
"id": "openai/gpt-4o",
"architecture": {
"input_modalities": ["text", "image"],
"output_modalities": ["text"],
},
}
)
is True
)
def test_supports_image_input_false_for_text_only():
"""The exact failure mode the safety net guards against — DeepSeek V3
is a text-in/text-out model and would 404 if forwarded image_url."""
assert (
_supports_image_input(
{
"id": "deepseek/deepseek-v3.2-exp",
"architecture": {
"input_modalities": ["text"],
"output_modalities": ["text"],
},
}
)
is False
)
def test_supports_image_input_false_when_modalities_missing():
"""Defensive: missing architecture is treated as text-only at the
OpenRouter helper level. The wider catalog resolver
(`derive_supports_image_input`) only consults modalities when they
are non-empty; otherwise it falls back to LiteLLM."""
assert _supports_image_input({"id": "weird/model"}) is False
assert _supports_image_input({"id": "weird/model", "architecture": {}}) is False
assert (
_supports_image_input(
{"id": "weird/model", "architecture": {"input_modalities": None}}
)
is False
)
# ---------------------------------------------------------------------------
# _generate_configs threads the flag onto every emitted chat config
# ---------------------------------------------------------------------------
def test_generate_configs_emits_supports_image_input():
raw = [
{
"id": "openai/gpt-4o",
"architecture": {
"input_modalities": ["text", "image"],
"output_modalities": ["text"],
},
"supported_parameters": ["tools"],
"context_length": 200_000,
"pricing": {"prompt": "0.000005", "completion": "0.000015"},
},
{
"id": "deepseek/deepseek-v3.2-exp",
"architecture": {
"input_modalities": ["text"],
"output_modalities": ["text"],
},
"supported_parameters": ["tools"],
"context_length": 200_000,
"pricing": {"prompt": "0.000003", "completion": "0.000015"},
},
]
cfgs = _generate_configs(raw, dict(_SETTINGS_BASE))
by_model = {c["model_name"]: c for c in cfgs}
gpt = by_model["openai/gpt-4o"]
assert gpt["supports_image_input"] is True
assert gpt[_OPENROUTER_DYNAMIC_MARKER] is True
deepseek = by_model["deepseek/deepseek-v3.2-exp"]
assert deepseek["supports_image_input"] is False
assert deepseek[_OPENROUTER_DYNAMIC_MARKER] is True
# ---------------------------------------------------------------------------
# YAML loader: defer to derive_supports_image_input on unannotated entries
# ---------------------------------------------------------------------------
def test_yaml_loader_resolves_unannotated_vision_model_to_true(tmp_path, monkeypatch):
"""The regression case: an Azure GPT-5.x YAML entry without a
``supports_image_input`` override should resolve to True via LiteLLM's
model map (which says ``supports_vision: true``). Previously this
defaulted to False, blocking every image turn for vision-capable
YAML configs."""
yaml_dir = tmp_path / "app" / "config"
yaml_dir.mkdir(parents=True)
(yaml_dir / "global_llm_config.yaml").write_text(
"""
global_llm_configs:
- id: -2
name: Azure GPT-4o
provider: AZURE_OPENAI
model_name: gpt-4o
api_key: sk-test
""",
encoding="utf-8",
)
from app import config as config_module
monkeypatch.setattr(config_module, "BASE_DIR", tmp_path)
configs = config_module.load_global_llm_configs()
assert len(configs) == 1
assert configs[0]["supports_image_input"] is True
def test_yaml_loader_respects_explicit_supports_image_input(tmp_path, monkeypatch):
yaml_dir = tmp_path / "app" / "config"
yaml_dir.mkdir(parents=True)
(yaml_dir / "global_llm_config.yaml").write_text(
"""
global_llm_configs:
- id: -1
name: GPT-4o
provider: OPENAI
model_name: gpt-4o
api_key: sk-test
supports_image_input: false
""",
encoding="utf-8",
)
from app import config as config_module
monkeypatch.setattr(config_module, "BASE_DIR", tmp_path)
configs = config_module.load_global_llm_configs()
assert len(configs) == 1
# Operator override always wins, even against LiteLLM's True.
assert configs[0]["supports_image_input"] is False
def test_yaml_loader_unknown_model_default_allows(tmp_path, monkeypatch):
"""Unknown / unmapped model in YAML: default-allow. The streaming
safety net (which requires an explicit-False from LiteLLM) is the
only place a real block happens, so we don't lock the user out of
a freshly added third-party entry the catalog can't introspect."""
yaml_dir = tmp_path / "app" / "config"
yaml_dir.mkdir(parents=True)
(yaml_dir / "global_llm_config.yaml").write_text(
"""
global_llm_configs:
- id: -1
name: Some Brand New Model
provider: CUSTOM
custom_provider: brand_new_proxy
model_name: brand-new-model-x9
api_key: sk-test
""",
encoding="utf-8",
)
from app import config as config_module
monkeypatch.setattr(config_module, "BASE_DIR", tmp_path)
configs = config_module.load_global_llm_configs()
assert len(configs) == 1
assert configs[0]["supports_image_input"] is True
# ---------------------------------------------------------------------------
# AgentConfig threads the flag through both YAML and Auto / BYOK
# ---------------------------------------------------------------------------
def test_agent_config_from_yaml_explicit_overrides_resolver():
from app.agents.new_chat.llm_config import AgentConfig
cfg_text_only = AgentConfig.from_yaml_config(
{
"id": -1,
"name": "Text Only Override",
"provider": "openai",
"model_name": "gpt-4o", # Capable per LiteLLM, but operator says no.
"api_key": "sk-test",
"supports_image_input": False,
}
)
cfg_explicit_vision = AgentConfig.from_yaml_config(
{
"id": -2,
"name": "GPT-4o",
"provider": "openai",
"model_name": "gpt-4o",
"api_key": "sk-test",
"supports_image_input": True,
}
)
assert cfg_text_only.supports_image_input is False
assert cfg_explicit_vision.supports_image_input is True
def test_agent_config_from_yaml_unannotated_uses_resolver():
"""Without an explicit YAML key, AgentConfig defers to the catalog
resolver for ``gpt-4o`` LiteLLM's map says supports_vision=True."""
from app.agents.new_chat.llm_config import AgentConfig
cfg = AgentConfig.from_yaml_config(
{
"id": -1,
"name": "GPT-4o (no override)",
"provider": "openai",
"model_name": "gpt-4o",
"api_key": "sk-test",
}
)
assert cfg.supports_image_input is True
def test_agent_config_auto_mode_supports_image_input():
"""Auto routes across the pool. We optimistically allow image input
so users can keep their selection on Auto with a vision-capable
deployment somewhere in the pool. The router's own `allowed_fails`
handles non-vision deployments via fallback."""
from app.agents.new_chat.llm_config import AgentConfig
auto = AgentConfig.from_auto_mode()
assert auto.supports_image_input is True

View file

@ -0,0 +1,515 @@
"""Cost-based premium quota unit tests.
Covers the USD-micro behaviour added in migration 140:
* ``TurnTokenAccumulator.total_cost_micros`` sums ``cost_micros`` across all
calls in a turn; it is used as the debit amount when ``agent_config.is_premium``
is true, regardless of which underlying model produced each call. This
preserves the prior "premium turn → all calls in turn count" rule from the
token-based system.
* ``estimate_call_reserve_micros`` scales linearly with model pricing,
clamps to a sane floor when pricing is unknown, and respects the
``QUOTA_MAX_RESERVE_MICROS`` ceiling so a misconfigured "$1000/M" entry
can't lock the whole balance on one call.
"""
from __future__ import annotations
import pytest
pytestmark = pytest.mark.unit
# ---------------------------------------------------------------------------
# TurnTokenAccumulator — premium-turn debit semantics
# ---------------------------------------------------------------------------
def test_total_cost_micros_sums_premium_and_free_calls():
"""A premium turn that also called a free sub-agent debits the union.
The plan deliberately preserved the existing "premium turn → all calls
count" behaviour because per-call premium filtering relied on
``LLMRouterService._premium_model_strings`` which only covers router-pool
deployments. ``total_cost_micros`` therefore must include free-model
calls (whose ``cost_micros`` is typically ``0``) as well as the premium
call's actual provider cost.
"""
from app.services.token_tracking_service import TurnTokenAccumulator
acc = TurnTokenAccumulator()
# Premium model (e.g. claude-opus): non-zero cost.
acc.add(
model="anthropic/claude-3-5-sonnet",
prompt_tokens=1200,
completion_tokens=400,
total_tokens=1600,
cost_micros=12_345,
)
# Free sub-agent (e.g. title-gen on a free model): zero cost.
acc.add(
model="gpt-4o-mini",
prompt_tokens=120,
completion_tokens=20,
total_tokens=140,
cost_micros=0,
)
# A second premium-priced call within the same turn.
acc.add(
model="anthropic/claude-3-5-sonnet",
prompt_tokens=800,
completion_tokens=200,
total_tokens=1000,
cost_micros=7_500,
)
assert acc.total_cost_micros == 12_345 + 0 + 7_500
# Token totals stay correct so the FE display path still works.
assert acc.grand_total == 1600 + 140 + 1000
def test_total_cost_micros_zero_when_no_calls():
"""An empty accumulator must report zero cost (no division-by-zero, no None)."""
from app.services.token_tracking_service import TurnTokenAccumulator
acc = TurnTokenAccumulator()
assert acc.total_cost_micros == 0
assert acc.grand_total == 0
def test_per_message_summary_groups_cost_by_model():
"""``per_message_summary`` must accumulate ``cost_micros`` per model so the
SSE ``model_breakdown`` payload reports actual USD spend per provider.
"""
from app.services.token_tracking_service import TurnTokenAccumulator
acc = TurnTokenAccumulator()
acc.add(
model="claude-3-5-sonnet",
prompt_tokens=100,
completion_tokens=50,
total_tokens=150,
cost_micros=4_000,
)
acc.add(
model="claude-3-5-sonnet",
prompt_tokens=200,
completion_tokens=100,
total_tokens=300,
cost_micros=8_000,
)
acc.add(
model="gpt-4o-mini",
prompt_tokens=50,
completion_tokens=10,
total_tokens=60,
cost_micros=200,
)
summary = acc.per_message_summary()
assert summary["claude-3-5-sonnet"]["cost_micros"] == 12_000
assert summary["claude-3-5-sonnet"]["total_tokens"] == 450
assert summary["gpt-4o-mini"]["cost_micros"] == 200
def test_serialized_calls_includes_cost_micros():
"""``serialized_calls`` is what flows into the SSE ``call_details``
payload; cost_micros must be present on each entry so the FE message-info
dropdown can render per-call USD.
"""
from app.services.token_tracking_service import TurnTokenAccumulator
acc = TurnTokenAccumulator()
acc.add(
model="m",
prompt_tokens=1,
completion_tokens=1,
total_tokens=2,
cost_micros=42,
)
serialized = acc.serialized_calls()
assert serialized == [
{
"model": "m",
"prompt_tokens": 1,
"completion_tokens": 1,
"total_tokens": 2,
"cost_micros": 42,
"call_kind": "chat",
}
]
# ---------------------------------------------------------------------------
# estimate_call_reserve_micros — sizing and clamping
# ---------------------------------------------------------------------------
def test_reserve_returns_floor_when_model_unknown(monkeypatch):
"""If LiteLLM doesn't know the model, ``get_model_info`` raises and the
helper falls back to the 100-micro floor: small enough that a user with
$0.0001 left can still send a tiny request, but non-zero so we still gate
against an empty balance.
"""
import litellm
from app.services import token_quota_service
def _raise(_name):
raise KeyError("unknown")
monkeypatch.setattr(litellm, "get_model_info", _raise, raising=False)
micros = token_quota_service.estimate_call_reserve_micros(
base_model="nonexistent-model",
quota_reserve_tokens=4000,
)
assert micros == token_quota_service._QUOTA_MIN_RESERVE_MICROS
assert micros == 100
def test_reserve_returns_floor_when_pricing_is_zero(monkeypatch):
"""LiteLLM may *return* a model with both cost-per-token fields at 0
(pricing not yet registered). The helper must not multiply 0 × tokens
and end up reserving 0; it must clamp to the floor.
"""
import litellm
from app.services import token_quota_service
monkeypatch.setattr(
litellm,
"get_model_info",
lambda _name: {"input_cost_per_token": 0, "output_cost_per_token": 0},
raising=False,
)
micros = token_quota_service.estimate_call_reserve_micros(
base_model="some-pending-model",
quota_reserve_tokens=4000,
)
assert micros == token_quota_service._QUOTA_MIN_RESERVE_MICROS
def test_reserve_scales_with_model_cost(monkeypatch):
"""Claude-Opus-priced model with 4000 reserve_tokens reserves
~$0.36 = 360_000 micros. Critically this must NOT be clamped down to
some small artificial cap that was the bug the plan called out.
"""
import litellm
from app.config import config
from app.services import token_quota_service
monkeypatch.setattr(
litellm,
"get_model_info",
lambda _name: {
"input_cost_per_token": 15e-6,
"output_cost_per_token": 75e-6,
},
raising=False,
)
monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_MICROS", 1_000_000, raising=False)
micros = token_quota_service.estimate_call_reserve_micros(
base_model="claude-3-opus",
quota_reserve_tokens=4000,
)
# 4000 * (15e-6 + 75e-6) = 4000 * 90e-6 = 0.36 USD = 360_000 micros.
assert micros == 360_000
def test_reserve_clamps_to_max_ceiling(monkeypatch):
"""A misconfigured "$1000 / M" model with 4000 reserve_tokens would
nominally compute to $4 = 4_000_000 micros. The ceiling
``QUOTA_MAX_RESERVE_MICROS`` must clamp that so a bad pricing entry
can't lock the user's whole balance on one call.
"""
import litellm
from app.config import config
from app.services import token_quota_service
monkeypatch.setattr(
litellm,
"get_model_info",
lambda _name: {
"input_cost_per_token": 1e-3,
"output_cost_per_token": 0,
},
raising=False,
)
monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_MICROS", 1_000_000, raising=False)
micros = token_quota_service.estimate_call_reserve_micros(
base_model="oops-misconfigured",
quota_reserve_tokens=4000,
)
assert micros == 1_000_000
def test_reserve_uses_default_when_quota_reserve_tokens_missing(monkeypatch):
"""Per-config ``quota_reserve_tokens`` is optional; when ``None`` or
zero, the helper must fall back to the global ``QUOTA_MAX_RESERVE_PER_CALL``
so anonymous-style configs still reserve the operator-tunable default.
"""
import litellm
from app.config import config
from app.services import token_quota_service
monkeypatch.setattr(
litellm,
"get_model_info",
lambda _name: {
"input_cost_per_token": 1e-6,
"output_cost_per_token": 1e-6,
},
raising=False,
)
monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_PER_CALL", 2000, raising=False)
monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_MICROS", 1_000_000, raising=False)
# 2000 * (1e-6 + 1e-6) = 4e-3 USD = 4000 micros
assert (
token_quota_service.estimate_call_reserve_micros(
base_model="cheap", quota_reserve_tokens=None
)
== 4000
)
assert (
token_quota_service.estimate_call_reserve_micros(
base_model="cheap", quota_reserve_tokens=0
)
== 4000
)
# ---------------------------------------------------------------------------
# TokenTrackingCallback — image vs chat usage shape
# ---------------------------------------------------------------------------
class _FakeImageUsage:
"""Mimics LiteLLM's ``ImageUsage`` (input_tokens / output_tokens shape)."""
def __init__(
self,
input_tokens: int = 0,
output_tokens: int = 0,
total_tokens: int | None = None,
) -> None:
self.input_tokens = input_tokens
self.output_tokens = output_tokens
if total_tokens is not None:
self.total_tokens = total_tokens
class _FakeImageResponse:
"""Mimics LiteLLM's ``ImageResponse`` — same name so the callback's
``type(...).__name__`` probe routes to the image branch.
"""
def __init__(self, usage: _FakeImageUsage, response_cost: float | None = None):
self.usage = usage
if response_cost is not None:
self._hidden_params = {"response_cost": response_cost}
# Re-tag the helper class as ``ImageResponse`` for the type-name probe in
# the callback. We can't simply name the class ``ImageResponse`` because
# the test runner sometimes imports test modules in surprising ways and
# we want to be explicit.
_FakeImageResponse.__name__ = "ImageResponse"
class _FakeChatUsage:
def __init__(self, prompt: int, completion: int):
self.prompt_tokens = prompt
self.completion_tokens = completion
self.total_tokens = prompt + completion
class _FakeChatResponse:
def __init__(self, usage: _FakeChatUsage):
self.usage = usage
@pytest.mark.asyncio
async def test_callback_reads_image_usage_input_output_tokens():
"""``TokenTrackingCallback`` must read ``input_tokens``/``output_tokens``
for ``ImageResponse`` (LiteLLM's ImageUsage shape), NOT
prompt_tokens/completion_tokens which is the chat shape.
"""
from app.services.token_tracking_service import (
TokenTrackingCallback,
scoped_turn,
)
cb = TokenTrackingCallback()
response = _FakeImageResponse(
usage=_FakeImageUsage(input_tokens=42, output_tokens=8, total_tokens=50),
response_cost=0.04, # $0.04 per image
)
async with scoped_turn() as acc:
await cb.async_log_success_event(
kwargs={"model": "openai/gpt-image-1", "response_cost": 0.04},
response_obj=response,
start_time=None,
end_time=None,
)
assert len(acc.calls) == 1
call = acc.calls[0]
assert call.prompt_tokens == 42
assert call.completion_tokens == 8
assert call.total_tokens == 50
# 0.04 USD = 40_000 micros
assert call.cost_micros == 40_000
assert call.call_kind == "image_generation"
@pytest.mark.asyncio
async def test_callback_chat_path_unchanged():
"""Chat responses must still read prompt_tokens/completion_tokens."""
from app.services.token_tracking_service import (
TokenTrackingCallback,
scoped_turn,
)
cb = TokenTrackingCallback()
response = _FakeChatResponse(_FakeChatUsage(prompt=120, completion=30))
async with scoped_turn() as acc:
await cb.async_log_success_event(
kwargs={
"model": "openrouter/anthropic/claude-3-5-sonnet",
"response_cost": 0.0036,
},
response_obj=response,
start_time=None,
end_time=None,
)
assert len(acc.calls) == 1
call = acc.calls[0]
assert call.prompt_tokens == 120
assert call.completion_tokens == 30
assert call.total_tokens == 150
assert call.cost_micros == 3_600
assert call.call_kind == "chat"
@pytest.mark.asyncio
async def test_callback_image_missing_response_cost_falls_back_to_zero(monkeypatch):
"""When OpenRouter omits ``usage.cost`` LiteLLM's
``default_image_cost_calculator`` raises. The defensive image branch in
``_extract_cost_usd`` must NOT call ``cost_per_token`` (which is
chat-shaped and would raise too); it returns 0 with a WARNING log.
"""
import litellm
from app.services.token_tracking_service import (
TokenTrackingCallback,
scoped_turn,
)
# Force completion_cost to raise the same way OpenRouter image-gen fails.
def _boom(*_args, **_kwargs):
raise ValueError("model_cost: missing entry for openrouter image model")
monkeypatch.setattr(litellm, "completion_cost", _boom, raising=False)
# And make sure cost_per_token is NEVER called for the image path —
# if it were, our ``is_image=True`` branch is broken.
cost_per_token_calls: list = []
def _record_cost_per_token(**kwargs):
cost_per_token_calls.append(kwargs)
return (0.0, 0.0)
monkeypatch.setattr(
litellm, "cost_per_token", _record_cost_per_token, raising=False
)
cb = TokenTrackingCallback()
response = _FakeImageResponse(
usage=_FakeImageUsage(input_tokens=7, output_tokens=0)
)
async with scoped_turn() as acc:
await cb.async_log_success_event(
kwargs={"model": "openrouter/google/gemini-2.5-flash-image"},
response_obj=response,
start_time=None,
end_time=None,
)
assert len(acc.calls) == 1
assert acc.calls[0].cost_micros == 0
assert acc.calls[0].call_kind == "image_generation"
# The image branch must short-circuit before cost_per_token.
assert cost_per_token_calls == []
# ---------------------------------------------------------------------------
# scoped_turn — ContextVar reset semantics (issue B)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_scoped_turn_restores_outer_accumulator():
"""``scoped_turn`` must restore the previous ContextVar value on exit
so a per-call wrapper inside an outer chat turn doesn't leak its
accumulator outward (which would cause double-debit at chat-turn exit).
"""
from app.services.token_tracking_service import (
get_current_accumulator,
scoped_turn,
start_turn,
)
outer = start_turn()
assert get_current_accumulator() is outer
async with scoped_turn() as inner:
assert get_current_accumulator() is inner
assert inner is not outer
inner.add(
model="x",
prompt_tokens=1,
completion_tokens=1,
total_tokens=2,
cost_micros=5,
)
# After exit the outer accumulator is restored unchanged.
assert get_current_accumulator() is outer
assert outer.total_cost_micros == 0
assert len(outer.calls) == 0
# The inner accumulator captured the call but didn't bleed into outer.
assert inner.total_cost_micros == 5
@pytest.mark.asyncio
async def test_scoped_turn_resets_to_none_when_no_outer():
"""Running ``scoped_turn`` outside any chat turn (e.g. a background
indexing job) must leave the ContextVar at ``None`` on exit so the
next *unrelated* request starts clean.
"""
from app.services.token_tracking_service import (
_turn_accumulator,
get_current_accumulator,
scoped_turn,
)
# The ContextVar default is None in a fresh, test-isolated context. We
# simulate "no outer" explicitly to be robust against test order.
token = _turn_accumulator.set(None)
try:
assert get_current_accumulator() is None
async with scoped_turn() as acc:
assert get_current_accumulator() is acc
assert get_current_accumulator() is None
finally:
_turn_accumulator.reset(token)

View file

@ -0,0 +1,89 @@
"""Defense-in-depth: vision-LLM resolution must not leak ``api_base``
defaults from ``litellm.api_base`` either.
Vision shares the same shape as image-gen: global YAML / OpenRouter
dynamic configs ship ``api_base=""``, and the pre-fix ``get_vision_llm``
call sites would silently drop the empty string and inherit
``AZURE_OPENAI_ENDPOINT``. ``ChatLiteLLM(...)`` doesn't 404 on
construction, so we test the kwargs we hand to it instead.
"""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
pytestmark = pytest.mark.unit
@pytest.mark.asyncio
async def test_get_vision_llm_global_openrouter_sets_api_base():
"""Global negative-ID branch: an OpenRouter vision config with
``api_base=""`` must end up calling ``SanitizedChatLiteLLM`` with
``api_base="https://openrouter.ai/api/v1"`` never an empty string,
never silently absent."""
from app.services import llm_service
cfg = {
"id": -30_001,
"name": "GPT-4o Vision (OpenRouter)",
"provider": "OPENROUTER",
"model_name": "openai/gpt-4o",
"api_key": "sk-or-test",
"api_base": "",
"api_version": None,
"litellm_params": {},
"billing_tier": "free",
}
search_space = MagicMock()
search_space.id = 1
search_space.user_id = "user-x"
search_space.vision_llm_config_id = cfg["id"]
session = AsyncMock()
scalars = MagicMock()
scalars.first.return_value = search_space
result = MagicMock()
result.scalars.return_value = scalars
session.execute.return_value = result
captured: dict = {}
class FakeSanitized:
def __init__(self, **kwargs):
captured.update(kwargs)
with (
patch(
"app.services.vision_llm_router_service.get_global_vision_llm_config",
return_value=cfg,
),
patch(
"app.agents.new_chat.llm_config.SanitizedChatLiteLLM",
new=FakeSanitized,
),
):
await llm_service.get_vision_llm(session=session, search_space_id=1)
assert captured.get("api_base") == "https://openrouter.ai/api/v1"
assert captured["model"] == "openrouter/openai/gpt-4o"
def test_vision_router_deployment_sets_api_base_when_config_empty():
"""Auto-mode vision router: deployments are fed to ``litellm.Router``,
so the resolver has to apply at deployment construction time too."""
from app.services.vision_llm_router_service import VisionLLMRouterService
deployment = VisionLLMRouterService._config_to_deployment(
{
"model_name": "openai/gpt-4o",
"provider": "OPENROUTER",
"api_key": "sk-or-test",
"api_base": "",
}
)
assert deployment is not None
assert deployment["litellm_params"]["api_base"] == "https://openrouter.ai/api/v1"
assert deployment["litellm_params"]["model"] == "openrouter/openai/gpt-4o"

View file

@ -0,0 +1,526 @@
"""Unit tests for ``AssistantContentBuilder``.
Pins the in-memory ``ContentPart[]`` projection so the JSONB the server
persists matches what the frontend renders live (see
``surfsense_web/lib/chat/streaming-state.ts``). Every test asserts both
the structural shape of ``snapshot()`` and that the snapshot is
``json.dumps``-safe (the streaming finally block writes it directly to
``new_chat_messages.content`` without an explicit serialization round
trip).
"""
from __future__ import annotations
import json
import pytest
from app.tasks.chat.content_builder import AssistantContentBuilder
pytestmark = pytest.mark.unit
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _assert_jsonb_safe(parts: list[dict]) -> None:
"""Sanity check: any snapshot must round-trip through ``json.dumps``."""
serialized = json.dumps(parts)
assert json.loads(serialized) == parts
# ---------------------------------------------------------------------------
# Text turns
# ---------------------------------------------------------------------------
class TestTextOnly:
def test_single_text_block_collapses_consecutive_deltas(self):
b = AssistantContentBuilder()
b.on_text_start("text-1")
b.on_text_delta("text-1", "Hello")
b.on_text_delta("text-1", " ")
b.on_text_delta("text-1", "world")
b.on_text_end("text-1")
snap = b.snapshot()
assert snap == [{"type": "text", "text": "Hello world"}]
assert not b.is_empty()
_assert_jsonb_safe(snap)
def test_empty_text_start_end_pair_leaves_no_part(self):
# Mirrors the FE: a text-start without any deltas should
# not materialise an empty ``{"type":"text","text":""}`` part.
b = AssistantContentBuilder()
b.on_text_start("text-1")
b.on_text_end("text-1")
assert b.snapshot() == []
assert b.is_empty()
def test_text_after_text_end_starts_fresh_part(self):
b = AssistantContentBuilder()
b.on_text_start("text-1")
b.on_text_delta("text-1", "first")
b.on_text_end("text-1")
b.on_text_start("text-2")
b.on_text_delta("text-2", "second")
b.on_text_end("text-2")
snap = b.snapshot()
assert snap == [
{"type": "text", "text": "first"},
{"type": "text", "text": "second"},
]
class TestReasoningThenText:
def test_reasoning_followed_by_text_yields_two_parts_in_order(self):
b = AssistantContentBuilder()
b.on_reasoning_start("r-1")
b.on_reasoning_delta("r-1", "Considering options...")
b.on_reasoning_end("r-1")
b.on_text_start("text-1")
b.on_text_delta("text-1", "The answer is 42.")
b.on_text_end("text-1")
snap = b.snapshot()
assert snap == [
{"type": "reasoning", "text": "Considering options..."},
{"type": "text", "text": "The answer is 42."},
]
_assert_jsonb_safe(snap)
def test_text_delta_after_reasoning_implicitly_closes_reasoning(self):
# Mirrors FE ``appendText``: a text delta arriving while a
# reasoning part is "active" still produces a fresh text
# part, never appends into the reasoning block.
b = AssistantContentBuilder()
b.on_reasoning_start("r-1")
b.on_reasoning_delta("r-1", "thinking")
# No explicit reasoning_end — text delta should close it.
b.on_text_delta("text-1", "answer")
snap = b.snapshot()
assert snap == [
{"type": "reasoning", "text": "thinking"},
{"type": "text", "text": "answer"},
]
# ---------------------------------------------------------------------------
# Tool calls
# ---------------------------------------------------------------------------
class TestToolHeavyTurn:
def test_full_tool_lifecycle_produces_complete_tool_call_part(self):
b = AssistantContentBuilder()
# Some narration before the tool fires.
b.on_text_start("text-1")
b.on_text_delta("text-1", "Searching...")
b.on_text_end("text-1")
b.on_tool_input_start(
ui_id="call_run123",
tool_name="web_search",
langchain_tool_call_id="lc_tool_abc",
)
b.on_tool_input_delta("call_run123", '{"query":')
b.on_tool_input_delta("call_run123", '"surfsense"}')
b.on_tool_input_available(
ui_id="call_run123",
tool_name="web_search",
args={"query": "surfsense"},
langchain_tool_call_id="lc_tool_abc",
)
b.on_tool_output_available(
ui_id="call_run123",
output={"status": "completed", "citations": {}},
langchain_tool_call_id="lc_tool_abc",
)
snap = b.snapshot()
assert snap[0] == {"type": "text", "text": "Searching..."}
tool_part = snap[1]
assert tool_part["type"] == "tool-call"
assert tool_part["toolCallId"] == "call_run123"
assert tool_part["toolName"] == "web_search"
assert tool_part["args"] == {"query": "surfsense"}
# ``argsText`` is the pretty-printed final JSON, not the raw
# streaming buffer (FE ``stream-pipeline.ts:128``).
assert tool_part["argsText"] == json.dumps(
{"query": "surfsense"}, indent=2, ensure_ascii=False
)
assert tool_part["langchainToolCallId"] == "lc_tool_abc"
assert tool_part["result"] == {"status": "completed", "citations": {}}
_assert_jsonb_safe(snap)
def test_tool_input_available_without_prior_start_creates_card(self):
# Legacy / parity_v2-OFF path: tool-input-available may be
# emitted without a prior tool-input-start (no streamed
# tool_call_chunks). The card should still be created.
b = AssistantContentBuilder()
b.on_tool_input_available(
ui_id="call_run42",
tool_name="grep",
args={"pattern": "TODO"},
langchain_tool_call_id="lc_x",
)
b.on_tool_output_available(
ui_id="call_run42",
output={"matches": 3},
langchain_tool_call_id="lc_x",
)
snap = b.snapshot()
assert len(snap) == 1
part = snap[0]
assert part["type"] == "tool-call"
assert part["toolCallId"] == "call_run42"
assert part["args"] == {"pattern": "TODO"}
assert part["langchainToolCallId"] == "lc_x"
assert part["result"] == {"matches": 3}
def test_tool_input_start_idempotent_for_same_ui_id(self):
# parity_v2: tool-input-start can fire from BOTH the chunk
# registration path AND the canonical ``on_tool_start`` path.
# The second call must not create a duplicate part.
b = AssistantContentBuilder()
b.on_tool_input_start("call_x", "ls", "lc_x")
b.on_tool_input_start("call_x", "ls", "lc_x")
snap = b.snapshot()
assert len(snap) == 1
def test_tool_input_delta_without_prior_start_is_silently_dropped(self):
b = AssistantContentBuilder()
b.on_tool_input_delta("call_unknown", '{"orphan": "delta"}')
assert b.snapshot() == []
def test_langchain_tool_call_id_backfills_only_when_absent(self):
b = AssistantContentBuilder()
b.on_tool_input_start("call_x", "ls", "lc_first")
# Late event must NOT clobber an already-set lc id.
b.on_tool_input_start("call_x", "ls", "lc_late")
snap = b.snapshot()
assert snap[0]["langchainToolCallId"] == "lc_first"
def test_args_text_streaming_buffer_reflects_concatenation(self):
b = AssistantContentBuilder()
b.on_tool_input_start("call_x", "save_doc", "lc_y")
b.on_tool_input_delta("call_x", '{"title":')
b.on_tool_input_delta("call_x", '"Hi"}')
# Snapshot mid-stream should see the partial buffer (the FE
# tolerates invalid JSON and renders it as-is).
mid = b.snapshot()
assert mid[0]["argsText"] == '{"title":"Hi"}'
# Then tool-input-available replaces with pretty-printed.
b.on_tool_input_available(
"call_x",
"save_doc",
{"title": "Hi"},
"lc_y",
)
final = b.snapshot()
assert final[0]["argsText"] == json.dumps(
{"title": "Hi"}, indent=2, ensure_ascii=False
)
# ---------------------------------------------------------------------------
# Thinking steps & separators
# ---------------------------------------------------------------------------
class TestThinkingSteps:
def test_first_thinking_step_unshifts_singleton_to_index_zero(self):
b = AssistantContentBuilder()
b.on_text_start("text-1")
b.on_text_delta("text-1", "Hello")
b.on_text_end("text-1")
b.on_thinking_step("step-1", "Analyzing", "in_progress", ["item-a"])
snap = b.snapshot()
# Singleton goes to index 0 (FE ``updateThinkingSteps`` unshift).
assert snap[0]["type"] == "data-thinking-steps"
assert snap[0]["data"]["steps"] == [
{
"id": "step-1",
"title": "Analyzing",
"status": "in_progress",
"items": ["item-a"],
}
]
assert snap[1] == {"type": "text", "text": "Hello"}
def test_subsequent_thinking_steps_mutate_the_singleton_in_place(self):
b = AssistantContentBuilder()
b.on_thinking_step("step-1", "Analyzing", "in_progress", [])
b.on_thinking_step("step-2", "Searching", "in_progress", ["q"])
b.on_thinking_step("step-1", "Analyzing", "completed", ["done"])
snap = b.snapshot()
assert len([p for p in snap if p["type"] == "data-thinking-steps"]) == 1
steps = snap[0]["data"]["steps"]
assert len(steps) == 2
assert steps[0]["id"] == "step-1"
assert steps[0]["status"] == "completed"
assert steps[0]["items"] == ["done"]
assert steps[1]["id"] == "step-2"
def test_thinking_step_with_text_continues_appending_to_text(self):
b = AssistantContentBuilder()
b.on_text_start("text-1")
b.on_text_delta("text-1", "first")
# Thinking step inserts at index 0, bumps text idx from 0 to 1.
b.on_thinking_step("step-1", "Working", "in_progress", [])
b.on_text_delta("text-1", " second")
snap = b.snapshot()
text_parts = [p for p in snap if p["type"] == "text"]
assert text_parts == [{"type": "text", "text": "first second"}]
def test_thinking_step_without_id_is_dropped(self):
b = AssistantContentBuilder()
b.on_thinking_step("", "noop", "in_progress", None)
assert b.snapshot() == []
assert b.is_empty()
class TestStepSeparators:
def test_separator_no_op_before_any_content(self):
b = AssistantContentBuilder()
b.on_step_separator()
assert b.snapshot() == []
def test_separator_after_text_appends_with_step_index_zero(self):
b = AssistantContentBuilder()
b.on_text_start("text-1")
b.on_text_delta("text-1", "first")
b.on_text_end("text-1")
b.on_step_separator()
snap = b.snapshot()
assert snap[-1] == {
"type": "data-step-separator",
"data": {"stepIndex": 0},
}
def test_consecutive_separators_collapse_to_one(self):
b = AssistantContentBuilder()
b.on_text_delta("text-1", "x")
b.on_step_separator()
b.on_step_separator() # No-op: previous part is already a separator.
snap = b.snapshot()
assert sum(1 for p in snap if p["type"] == "data-step-separator") == 1
def test_step_index_increments_across_separators(self):
b = AssistantContentBuilder()
b.on_text_delta("text-1", "a")
b.on_step_separator()
b.on_text_delta("text-2", "b")
b.on_step_separator()
snap = b.snapshot()
seps = [p for p in snap if p["type"] == "data-step-separator"]
assert [s["data"]["stepIndex"] for s in seps] == [0, 1]
# ---------------------------------------------------------------------------
# Interruption handling
# ---------------------------------------------------------------------------
class TestMarkInterrupted:
def test_running_tool_calls_get_state_aborted(self):
b = AssistantContentBuilder()
b.on_tool_input_start("call_a", "ls", "lc_a")
b.on_tool_input_available("call_a", "ls", {"path": "/"}, "lc_a")
# No tool-output-available — simulates client disconnect mid-tool.
b.mark_interrupted()
snap = b.snapshot()
assert snap[0]["state"] == "aborted"
assert "result" not in snap[0]
def test_completed_tool_calls_are_not_marked_aborted(self):
b = AssistantContentBuilder()
b.on_tool_input_start("call_a", "ls", "lc_a")
b.on_tool_input_available("call_a", "ls", {"path": "/"}, "lc_a")
b.on_tool_output_available("call_a", {"files": []}, "lc_a")
b.mark_interrupted()
snap = b.snapshot()
assert "state" not in snap[0]
assert snap[0]["result"] == {"files": []}
def test_open_text_block_keeps_accumulated_content(self):
b = AssistantContentBuilder()
b.on_text_start("text-1")
b.on_text_delta("text-1", "partial")
# No on_text_end — disconnect mid-stream.
b.mark_interrupted()
snap = b.snapshot()
assert snap == [{"type": "text", "text": "partial"}]
# ---------------------------------------------------------------------------
# is_empty / snapshot semantics
# ---------------------------------------------------------------------------
class TestIsEmpty:
def test_fresh_builder_is_empty(self):
assert AssistantContentBuilder().is_empty()
def test_text_part_breaks_emptiness(self):
b = AssistantContentBuilder()
b.on_text_delta("text-1", "x")
assert not b.is_empty()
def test_tool_call_breaks_emptiness(self):
b = AssistantContentBuilder()
b.on_tool_input_start("call_x", "ls", None)
assert not b.is_empty()
def test_thinking_step_alone_does_not_break_emptiness(self):
# Mirrors the "status marker fallback" semantic: a turn that
# only emitted a thinking step before being interrupted should
# still be treated as empty for finalize_assistant_turn's
# status-marker substitution.
b = AssistantContentBuilder()
b.on_thinking_step("step-1", "Working", "in_progress", [])
assert b.is_empty()
def test_step_separator_alone_does_not_break_emptiness(self):
b = AssistantContentBuilder()
# Force a separator (it would normally no-op without content,
# but we simulate the underlying state to verify is_empty is
# not fooled by a stray separator).
b.parts.append({"type": "data-step-separator", "data": {"stepIndex": 0}})
assert b.is_empty()
class TestSnapshotSemantics:
def test_snapshot_is_deep_copied_so_mutations_do_not_leak(self):
b = AssistantContentBuilder()
b.on_tool_input_start("call_x", "ls", "lc_x")
b.on_tool_input_available("call_x", "ls", {"path": "/"}, "lc_x")
snap = b.snapshot()
# Mutate the returned snapshot — original should be untouched.
snap[0]["args"]["mutated"] = True
snap[0]["state"] = "tampered"
again = b.snapshot()
assert "mutated" not in again[0]["args"]
assert "state" not in again[0]
def test_snapshot_round_trips_through_json(self):
b = AssistantContentBuilder()
b.on_thinking_step("step-1", "Analyzing", "in_progress", ["item"])
b.on_text_delta("text-1", "answer")
b.on_tool_input_start("call_x", "ls", "lc_x")
b.on_tool_input_available("call_x", "ls", {"path": "/"}, "lc_x")
b.on_tool_output_available("call_x", {"files": ["a.txt"]}, "lc_x")
b.on_step_separator()
snap = b.snapshot()
encoded = json.dumps(snap)
assert json.loads(encoded) == snap
class TestStats:
"""``stats()`` is the perf-log handle for [PERF] [stream_*]
finalize_payload lines. Pin the schema so an ops dashboard can
rely on these keys being present and meaningful.
"""
def test_fresh_builder_reports_all_zeros(self):
b = AssistantContentBuilder()
s = b.stats()
assert s == {
"parts": 0,
"bytes": 2, # ``[]`` is two bytes
"text": 0,
"reasoning": 0,
"tool_calls": 0,
"tool_calls_completed": 0,
"tool_calls_aborted": 0,
"thinking_step_parts": 0,
"step_separators": 0,
}
def test_counts_each_part_type_independently(self):
b = AssistantContentBuilder()
b.on_text_start("t1")
b.on_text_delta("t1", "hi")
b.on_text_end("t1")
b.on_reasoning_start("r1")
b.on_reasoning_delta("r1", "thinking")
b.on_reasoning_end("r1")
b.on_thinking_step("step-1", "Analyzing", "completed", ["item"])
b.on_step_separator()
b.on_tool_input_start("call_done", "ls", "lc_done")
b.on_tool_input_available("call_done", "ls", {}, "lc_done")
b.on_tool_output_available("call_done", {"ok": True}, "lc_done")
b.on_tool_input_start("call_running", "rm", "lc_running")
b.on_tool_input_available("call_running", "rm", {}, "lc_running")
s = b.stats()
assert s["text"] == 1
assert s["reasoning"] == 1
assert s["tool_calls"] == 2
assert s["tool_calls_completed"] == 1
assert s["tool_calls_aborted"] == 0
assert s["thinking_step_parts"] == 1
assert s["step_separators"] == 1
assert s["parts"] == sum(
[
s["text"],
s["reasoning"],
s["tool_calls"],
s["thinking_step_parts"],
s["step_separators"],
]
)
assert s["bytes"] > 0
def test_mark_interrupted_flips_running_calls_to_aborted_in_stats(self):
b = AssistantContentBuilder()
b.on_tool_input_start("call_done", "ls", "lc_done")
b.on_tool_input_available("call_done", "ls", {}, "lc_done")
b.on_tool_output_available("call_done", {"ok": True}, "lc_done")
b.on_tool_input_start("call_running", "rm", "lc_running")
b.on_tool_input_available("call_running", "rm", {}, "lc_running")
# Pre-interrupt: one completed, one still running (no result).
pre = b.stats()
assert pre["tool_calls_completed"] == 1
assert pre["tool_calls_aborted"] == 0
b.mark_interrupted()
post = b.stats()
assert post["tool_calls_completed"] == 1
assert post["tool_calls_aborted"] == 1
assert post["tool_calls"] == 2
def test_bytes_reflects_jsonb_payload_size(self):
# Each text-delta adds bytes monotonically — useful for catching
# an unbounded delta buffer regression in the perf signal.
b = AssistantContentBuilder()
b.on_text_start("t1")
b.on_text_delta("t1", "x" * 10)
small = b.stats()["bytes"]
b.on_text_delta("t1", "x" * 1000)
large = b.stats()["bytes"]
assert large > small + 900

View file

@ -51,22 +51,34 @@ class _FakeToolMessage:
tool_call_id: str | None = None
@dataclass
class _FakeInterrupt:
value: dict[str, Any]
@dataclass
class _FakeTask:
interrupts: tuple[_FakeInterrupt, ...] = ()
class _FakeAgentState:
"""Stand-in for ``StateSnapshot`` returned by ``aget_state``."""
def __init__(self) -> None:
def __init__(self, tasks: list[Any] | None = None) -> None:
# Empty values keeps the cloud-fallback safety-net branch a no-op,
# and an empty ``tasks`` list keeps the post-stream interrupt
# check a no-op too.
# and empty ``tasks`` keep the post-stream interrupt check a no-op too.
self.values: dict[str, Any] = {}
self.tasks: list[Any] = []
self.tasks: list[Any] = tasks or []
class _FakeAgent:
"""Replays a list of ``astream_events`` events."""
def __init__(self, events: list[dict[str, Any]]) -> None:
def __init__(
self, events: list[dict[str, Any]], state: _FakeAgentState | None = None
) -> None:
self._events = events
self._state = state or _FakeAgentState()
async def astream_events( # type: ignore[no-untyped-def]
self, _input_data: Any, *, config: dict[str, Any], version: str
@ -79,7 +91,7 @@ class _FakeAgent:
# Called once after astream_events drains so the cloud-fallback
# safety net can inspect staged filesystem work. The fake stays
# empty so the safety net is a no-op.
return _FakeAgentState()
return self._state
def _model_stream(
@ -170,11 +182,13 @@ def parity_v2_off(monkeypatch: pytest.MonkeyPatch) -> None:
)
async def _drain(events: list[dict[str, Any]]) -> list[dict[str, Any]]:
async def _drain(
events: list[dict[str, Any]], state: _FakeAgentState | None = None
) -> list[dict[str, Any]]:
"""Run ``_stream_agent_events`` against a fake agent and return the
SSE payloads (parsed JSON) it yielded.
"""
agent = _FakeAgent(events)
agent = _FakeAgent(events, state=state)
service = VercelStreamingService()
result = StreamResult()
config = {"configurable": {"thread_id": "test-thread"}}
@ -525,3 +539,31 @@ async def test_unmatched_fallback_still_attaches_lc_id(
assert len(starts) == 1
assert starts[0]["toolCallId"].startswith("call_run-1")
assert starts[0]["langchainToolCallId"] == "lc-orphan"
@pytest.mark.asyncio
async def test_interrupt_request_uses_task_that_contains_interrupt(
parity_v2_on: None,
) -> None:
interrupt_payload = {
"type": "calendar_event_create",
"action": {
"tool": "create_calendar_event",
"params": {"summary": "mom bday"},
},
"context": {},
}
state = _FakeAgentState(
tasks=[
_FakeTask(interrupts=()),
_FakeTask(interrupts=(_FakeInterrupt(value=interrupt_payload),)),
]
)
payloads = await _drain([], state=state)
interrupts = _of_type(payloads, "data-interrupt-request")
assert len(interrupts) == 1
assert (
interrupts[0]["data"]["action_requests"][0]["name"] == "create_calendar_event"
)

View file

@ -0,0 +1,318 @@
"""Regression tests for ``run_async_celery_task``.
These tests pin down the production bug observed on 2026-05-02 where
the video-presentation Celery task hung at ``[billable_call] finalize``
because the shared ``app.db.engine`` had pooled asyncpg connections
bound to a *previous* task's now-closed event loop. Reusing such a
connection on a fresh loop crashes inside ``pool_pre_ping`` with::
AttributeError: 'NoneType' object has no attribute 'send'
(the proactor is None because the loop is gone) and can hang forever
inside the asyncpg ``Connection._cancel`` cleanup coroutine.
The fix is ``run_async_celery_task``: a small helper that runs every
async celery task body inside a fresh event loop and disposes the
shared engine pool both before (defends against a previous task that
crashed) and after (releases connections we opened on this loop).
Tests here exercise the helper with a stub engine that records
``dispose()`` calls and panics if a coroutine produced by one loop is
awaited on another, mirroring the real asyncpg behaviour.
"""
from __future__ import annotations
import asyncio
import gc
import sys
from collections.abc import Iterator
from contextlib import contextmanager
from unittest.mock import patch
import pytest
pytestmark = pytest.mark.unit
# ---------------------------------------------------------------------------
# Stub engine that emulates the asyncpg-on-stale-loop crash
# ---------------------------------------------------------------------------
class _StaleLoopEngine:
"""Tiny stand-in for ``app.db.engine`` that tracks dispose() calls.
``dispose()`` is async (matches ``AsyncEngine.dispose``) and records
the running event loop id so tests can assert it ran on *each*
fresh loop.
"""
def __init__(self) -> None:
self.dispose_loop_ids: list[int] = []
async def dispose(self) -> None:
loop = asyncio.get_running_loop()
self.dispose_loop_ids.append(id(loop))
@contextmanager
def _patch_shared_engine(stub: _StaleLoopEngine) -> Iterator[None]:
"""Patch ``from app.db import engine as shared_engine`` lookup.
The helper imports lazily inside the function body, so we have to
patch the attribute on the already-loaded ``app.db`` module.
"""
import app.db as app_db
original = getattr(app_db, "engine", None)
app_db.engine = stub # type: ignore[attr-defined]
try:
yield
finally:
if original is None:
with pytest.raises(AttributeError):
_ = app_db.engine
else:
app_db.engine = original # type: ignore[attr-defined]
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
def test_runner_returns_value_and_disposes_engine_around_call() -> None:
"""Happy path: the coroutine result is returned, and the shared
engine is disposed both before and after the task body runs.
"""
from app.tasks.celery_tasks import run_async_celery_task
stub = _StaleLoopEngine()
async def _body() -> str:
# Engine should already have been disposed once before we run.
assert len(stub.dispose_loop_ids) == 1
return "ok"
with _patch_shared_engine(stub):
result = run_async_celery_task(_body)
assert result == "ok"
# Once before the body, once after (in finally).
assert len(stub.dispose_loop_ids) == 2
# Both disposes ran on the SAME (fresh) loop the task body used.
assert stub.dispose_loop_ids[0] == stub.dispose_loop_ids[1]
def test_runner_creates_fresh_loop_per_invocation() -> None:
"""Each call must spin its own loop. Without this guarantee a
previous task's loop would be reused and the asyncpg-stale-loop
crash would never be avoided.
"""
import app.tasks.celery_tasks as celery_tasks_pkg
stub = _StaleLoopEngine()
new_loop_calls = 0
closed_loops: list[bool] = []
real_new_event_loop = asyncio.new_event_loop
def _counting_new_loop() -> asyncio.AbstractEventLoop:
nonlocal new_loop_calls
new_loop_calls += 1
loop = real_new_event_loop()
# Hook close() so we can verify each loop was closed properly
# before the next one was created.
original_close = loop.close
def _tracked_close() -> None:
closed_loops.append(True)
original_close()
loop.close = _tracked_close # type: ignore[method-assign]
return loop
async def _body() -> None:
# Loop is alive and current at body execution time.
running = asyncio.get_running_loop()
assert not running.is_closed()
with (
_patch_shared_engine(stub),
patch.object(asyncio, "new_event_loop", _counting_new_loop),
):
for _ in range(3):
celery_tasks_pkg.run_async_celery_task(_body)
assert new_loop_calls == 3
assert closed_loops == [True, True, True]
# Each invocation disposed twice (before + after).
assert len(stub.dispose_loop_ids) == 6
def test_runner_disposes_engine_even_when_body_raises() -> None:
"""Cleanup MUST run on the failure path too — otherwise stale
connections leak into the next task and cause the original hang.
"""
from app.tasks.celery_tasks import run_async_celery_task
stub = _StaleLoopEngine()
class _BoomError(RuntimeError):
pass
async def _body() -> None:
raise _BoomError("kaboom")
with _patch_shared_engine(stub), pytest.raises(_BoomError):
run_async_celery_task(_body)
assert len(stub.dispose_loop_ids) == 2 # before + after still ran
def test_runner_swallows_dispose_errors() -> None:
"""A flaky engine.dispose() must NEVER take down a celery task.
Production scenario: the very first dispose (before the body runs)
might hit a partially-initialised engine; the helper logs and
moves on. The task body still runs; the result is still returned.
"""
from app.tasks.celery_tasks import run_async_celery_task
class _AngryEngine:
def __init__(self) -> None:
self.calls = 0
async def dispose(self) -> None:
self.calls += 1
raise RuntimeError("dispose() blew up")
stub = _AngryEngine()
async def _body() -> int:
return 42
with _patch_shared_engine(stub):
assert run_async_celery_task(_body) == 42
assert stub.calls == 2 # before + after both attempted
def test_runner_propagates_value_from_async_body() -> None:
"""Sanity: pass-through of any pickleable celery return value."""
from app.tasks.celery_tasks import run_async_celery_task
stub = _StaleLoopEngine()
async def _body() -> dict[str, object]:
return {"status": "ready", "video_presentation_id": 19}
with _patch_shared_engine(stub):
out = run_async_celery_task(_body)
assert out == {"status": "ready", "video_presentation_id": 19}
def test_video_presentation_task_uses_runner_helper() -> None:
"""Defence-in-depth: confirm the celery task module imports
``run_async_celery_task``. If a future refactor inlines a
``loop = asyncio.new_event_loop(); ... loop.close()`` block again,
the original hang will return.
"""
# The module's task body should not contain a manual new_event_loop
# call — that's exactly what the helper exists to centralise.
import inspect
from app.tasks.celery_tasks import video_presentation_tasks
src = inspect.getsource(video_presentation_tasks)
assert "run_async_celery_task" in src, (
"video_presentation_tasks.py must use run_async_celery_task; "
"manual asyncio.new_event_loop() in a celery task hangs on the "
"shared SQLAlchemy pool when reused across tasks."
)
assert "asyncio.new_event_loop" not in src, (
"video_presentation_tasks.py contains a raw asyncio.new_event_loop "
"call — route every async task through run_async_celery_task to "
"avoid the stale-pool hang."
)
def test_podcast_task_uses_runner_helper() -> None:
"""Symmetric assertion for the podcast task — same root cause, same
fix, same regression risk.
"""
import inspect
from app.tasks.celery_tasks import podcast_tasks
src = inspect.getsource(podcast_tasks)
assert "run_async_celery_task" in src
assert "asyncio.new_event_loop" not in src
def test_runner_runs_shutdown_asyncgens_before_close() -> None:
"""If the task body created any async generators that didn't get
fully iterated, we must still call ``loop.shutdown_asyncgens()``
before closing; otherwise we leak event-loop-bound resources
that re-emerge as ``RuntimeError: Event loop is closed`` later.
"""
from app.tasks.celery_tasks import run_async_celery_task
stub = _StaleLoopEngine()
async def _agen():
try:
yield 1
yield 2
finally:
pass
async def _body() -> None:
# Iterate the agen partially, then leave it dangling — exactly
# the situation shutdown_asyncgens() is designed to clean up.
async for v in _agen():
if v == 1:
break
with _patch_shared_engine(stub):
run_async_celery_task(_body)
# By the time the helper returns, garbage collection + shutdown_asyncgens
# should have ensured no live async-gen references remain. We don't
# assert agen.closed directly (it depends on GC ordering); the real
# contract is "no warnings, no event-loop-closed errors". A successful
# second invocation proves the loop was cleaned up properly.
with _patch_shared_engine(stub):
run_async_celery_task(_body)
# Force a GC pass to surface any 'coroutine was never awaited'
# warnings that would indicate the cleanup is broken.
gc.collect()
def test_runner_uses_proactor_loop_on_windows() -> None:
"""On Windows the celery worker preselects a Proactor policy so
subprocess (ffmpeg) calls work. The helper must not silently fall
back to a Selector loop and re-break video/podcast generation.
"""
if not sys.platform.startswith("win"):
pytest.skip("Windows-specific event-loop policy assertion")
from app.tasks.celery_tasks import run_async_celery_task
stub = _StaleLoopEngine()
# Mirror the policy set at the top of every Windows celery task.
asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
observed: list[str] = []
async def _body() -> None:
observed.append(type(asyncio.get_running_loop()).__name__)
with _patch_shared_engine(stub):
run_async_celery_task(_body)
assert observed == ["ProactorEventLoop"]

View file

@ -0,0 +1,388 @@
"""Unit tests for podcast Celery task billing integration.
Validates ``_generate_content_podcast`` correctly wraps
``podcaster_graph.ainvoke`` in a ``billable_call`` envelope, propagates the
search-space owner's billing decision, and degrades cleanly when the
resolver fails or premium credit is exhausted.
Coverage:
* Happy-path free config: the resolver succeeds and ``billable_call`` enters with
``usage_type='podcast_generation'`` and the configured reserve override,
graph runs, podcast row flips to ``READY``.
* Happy-path premium config: same wiring with ``billing_tier='premium'``.
* Quota denial: ``billable_call`` raises ``QuotaInsufficientError``; the
graph is *not* invoked, the podcast row flips to ``FAILED``, and the return
dict carries ``reason='premium_quota_exhausted'``.
* Resolver failure: the resolver raises ``ValueError``; the podcast row flips
to ``FAILED``, and the return dict carries ``reason='billing_resolution_failed'``.
"""
from __future__ import annotations
import contextlib
from types import SimpleNamespace
from typing import Any
from uuid import uuid4
import pytest
pytestmark = pytest.mark.unit
# ---------------------------------------------------------------------------
# Fakes
# ---------------------------------------------------------------------------
class _FakeExecResult:
def __init__(self, obj):
self._obj = obj
def scalars(self):
return self
def first(self):
return self._obj
def filter(self, *_args, **_kwargs):
return self
class _FakeSession:
def __init__(self, podcast):
self._podcast = podcast
self.commit_count = 0
async def execute(self, _stmt):
return _FakeExecResult(self._podcast)
async def commit(self):
self.commit_count += 1
async def __aenter__(self):
return self
async def __aexit__(self, *args):
return None
class _FakeSessionMaker:
def __init__(self, session: _FakeSession):
self._session = session
def __call__(self):
return self._session
def _make_podcast(podcast_id: int = 7, thread_id: int = 99) -> SimpleNamespace:
"""Stand-in for a ``Podcast`` row. Importing ``PodcastStatus`` lazily
inside helpers keeps this fixture cheap."""
return SimpleNamespace(
id=podcast_id,
title="Test Podcast",
thread_id=thread_id,
status=None,
podcast_transcript=None,
file_location=None,
)
@contextlib.asynccontextmanager
async def _ok_billable_call(**kwargs):
"""Stand-in for ``billable_call`` that records its kwargs and yields a
no-op accumulator-shaped object."""
_CALL_LOG.append(kwargs)
yield SimpleNamespace()
_CALL_LOG: list[dict[str, Any]] = []
@contextlib.asynccontextmanager
async def _denying_billable_call(**kwargs):
from app.services.billable_calls import QuotaInsufficientError
_CALL_LOG.append(kwargs)
raise QuotaInsufficientError(
usage_type=kwargs.get("usage_type", "?"),
used_micros=5_000_000,
limit_micros=5_000_000,
remaining_micros=0,
)
yield SimpleNamespace() # pragma: no cover — for grammar only
@contextlib.asynccontextmanager
async def _settlement_failing_billable_call(**kwargs):
from app.services.billable_calls import BillingSettlementError
_CALL_LOG.append(kwargs)
yield SimpleNamespace()
raise BillingSettlementError(
usage_type=kwargs.get("usage_type", "?"),
user_id=kwargs["user_id"],
cause=RuntimeError("finalize failed"),
)
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
@pytest.fixture(autouse=True)
def _reset_call_log():
_CALL_LOG.clear()
yield
_CALL_LOG.clear()
@pytest.mark.asyncio
async def test_billable_call_invoked_with_correct_kwargs_for_free_config(monkeypatch):
"""Happy path: free billing tier still wraps the graph call so the
audit row is recorded. Verifies kwargs threading."""
from app.config import config as app_config
from app.db import PodcastStatus
from app.tasks.celery_tasks import podcast_tasks
podcast = _make_podcast(podcast_id=7, thread_id=99)
session = _FakeSession(podcast)
monkeypatch.setattr(
podcast_tasks,
"get_celery_session_maker",
lambda: _FakeSessionMaker(session),
)
user_id = uuid4()
async def _fake_resolver(sess, search_space_id, *, thread_id=None):
assert search_space_id == 555
assert thread_id == 99
return user_id, "free", "openrouter/some-free-model"
monkeypatch.setattr(
podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver
)
monkeypatch.setattr(podcast_tasks, "billable_call", _ok_billable_call)
async def _fake_graph_invoke(state, config):
return {
"podcast_transcript": [
SimpleNamespace(speaker_id=0, dialog="Hi"),
SimpleNamespace(speaker_id=1, dialog="Hello"),
],
"final_podcast_file_path": "/tmp/podcast.wav",
}
monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke)
result = await podcast_tasks._generate_content_podcast(
podcast_id=7,
source_content="hello world",
search_space_id=555,
user_prompt="make it short",
)
assert result["status"] == "ready"
assert result["podcast_id"] == 7
assert podcast.status == PodcastStatus.READY
assert podcast.file_location == "/tmp/podcast.wav"
assert len(_CALL_LOG) == 1
call = _CALL_LOG[0]
assert call["user_id"] == user_id
assert call["search_space_id"] == 555
assert call["billing_tier"] == "free"
assert call["base_model"] == "openrouter/some-free-model"
assert call["usage_type"] == "podcast_generation"
assert (
call["quota_reserve_micros_override"]
== app_config.QUOTA_DEFAULT_PODCAST_RESERVE_MICROS
)
# Background artifact audit rows intentionally omit the TokenUsage.thread_id
# FK to avoid coupling Celery audit commits to an active chat transaction.
assert "thread_id" not in call
assert call["call_details"] == {
"podcast_id": 7,
"title": "Test Podcast",
"thread_id": 99,
}
assert callable(call["billable_session_factory"])
@pytest.mark.asyncio
async def test_billable_call_invoked_with_premium_tier(monkeypatch):
"""Premium resolution flows through to ``billable_call`` so the
reserve/finalize path triggers."""
from app.tasks.celery_tasks import podcast_tasks
podcast = _make_podcast()
session = _FakeSession(podcast)
monkeypatch.setattr(
podcast_tasks,
"get_celery_session_maker",
lambda: _FakeSessionMaker(session),
)
user_id = uuid4()
async def _fake_resolver(sess, search_space_id, *, thread_id=None):
return user_id, "premium", "gpt-5.4"
monkeypatch.setattr(
podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver
)
monkeypatch.setattr(podcast_tasks, "billable_call", _ok_billable_call)
async def _fake_graph_invoke(state, config):
return {"podcast_transcript": [], "final_podcast_file_path": "x.wav"}
monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke)
await podcast_tasks._generate_content_podcast(
podcast_id=7,
source_content="hi",
search_space_id=555,
user_prompt=None,
)
assert _CALL_LOG[0]["billing_tier"] == "premium"
assert _CALL_LOG[0]["base_model"] == "gpt-5.4"
@pytest.mark.asyncio
async def test_quota_insufficient_marks_podcast_failed_and_skips_graph(monkeypatch):
"""When ``billable_call`` denies the reservation, the graph never
runs and the podcast row flips to FAILED with the documented reason
code."""
from app.db import PodcastStatus
from app.tasks.celery_tasks import podcast_tasks
podcast = _make_podcast(podcast_id=8)
session = _FakeSession(podcast)
monkeypatch.setattr(
podcast_tasks,
"get_celery_session_maker",
lambda: _FakeSessionMaker(session),
)
async def _fake_resolver(sess, search_space_id, *, thread_id=None):
return uuid4(), "premium", "gpt-5.4"
monkeypatch.setattr(
podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver
)
monkeypatch.setattr(podcast_tasks, "billable_call", _denying_billable_call)
graph_invoked = []
async def _fake_graph_invoke(state, config):
graph_invoked.append(True)
return {}
monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke)
result = await podcast_tasks._generate_content_podcast(
podcast_id=8,
source_content="hi",
search_space_id=555,
user_prompt=None,
)
assert result == {
"status": "failed",
"podcast_id": 8,
"reason": "premium_quota_exhausted",
}
assert podcast.status == PodcastStatus.FAILED
assert graph_invoked == [] # Graph never ran on denied reservation.
@pytest.mark.asyncio
async def test_billing_settlement_failure_marks_podcast_failed(monkeypatch):
from app.db import PodcastStatus
from app.tasks.celery_tasks import podcast_tasks
podcast = _make_podcast(podcast_id=10)
session = _FakeSession(podcast)
monkeypatch.setattr(
podcast_tasks,
"get_celery_session_maker",
lambda: _FakeSessionMaker(session),
)
async def _fake_resolver(sess, search_space_id, *, thread_id=None):
return uuid4(), "premium", "gpt-5.4"
monkeypatch.setattr(
podcast_tasks, "_resolve_agent_billing_for_search_space", _fake_resolver
)
monkeypatch.setattr(
podcast_tasks, "billable_call", _settlement_failing_billable_call
)
async def _fake_graph_invoke(state, config):
return {"podcast_transcript": [], "final_podcast_file_path": "x.wav"}
monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke)
result = await podcast_tasks._generate_content_podcast(
podcast_id=10,
source_content="hi",
search_space_id=555,
user_prompt=None,
)
assert result == {
"status": "failed",
"podcast_id": 10,
"reason": "billing_settlement_failed",
}
assert podcast.status == PodcastStatus.FAILED
@pytest.mark.asyncio
async def test_resolver_failure_marks_podcast_failed(monkeypatch):
"""If the resolver raises (e.g. search-space deleted), the task fails
cleanly without invoking the graph."""
from app.db import PodcastStatus
from app.tasks.celery_tasks import podcast_tasks
podcast = _make_podcast(podcast_id=9)
session = _FakeSession(podcast)
monkeypatch.setattr(
podcast_tasks,
"get_celery_session_maker",
lambda: _FakeSessionMaker(session),
)
async def _failing_resolver(sess, search_space_id, *, thread_id=None):
raise ValueError("Search space 555 not found")
monkeypatch.setattr(
podcast_tasks, "_resolve_agent_billing_for_search_space", _failing_resolver
)
graph_invoked = []
async def _fake_graph_invoke(state, config):
graph_invoked.append(True)
return {}
monkeypatch.setattr(podcast_tasks.podcaster_graph, "ainvoke", _fake_graph_invoke)
result = await podcast_tasks._generate_content_podcast(
podcast_id=9,
source_content="hi",
search_space_id=555,
user_prompt=None,
)
assert result == {
"status": "failed",
"podcast_id": 9,
"reason": "billing_resolution_failed",
}
assert podcast.status == PodcastStatus.FAILED
assert graph_invoked == []

View file

@ -0,0 +1,119 @@
"""Predicate-level test for the chat streaming safety net.
The safety net in ``stream_new_chat`` rejects an image turn early with
a friendly ``MODEL_DOES_NOT_SUPPORT_IMAGE_INPUT`` SSE error when the
selected model is *known* to be text-only. The earlier round of this
work used a strict opt-in flag (``supports_image_input`` defaulting to
False on every YAML entry) which blocked vision-capable Azure GPT-5.x
deployments; this is the regression we're fixing.
The new predicate is :func:`is_known_text_only_chat_model`, which
returns True only when LiteLLM's authoritative model map *explicitly*
sets ``supports_vision=False``. Anything else (vision True, missing
key, exception) returns False so the request flows through to the
provider.
We exercise the predicate directly here rather than driving the full
``stream_new_chat`` generator; covering the gate in isolation keeps
the test focused on the regression while the generator's wider behavior
is exercised by the integration suite.
"""
from __future__ import annotations
import pytest
from app.services.provider_capabilities import is_known_text_only_chat_model
pytestmark = pytest.mark.unit
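# For reference, the predicate's decision table is expected to reduce to the
# shape below: only LiteLLM explicitly reporting ``supports_vision=False``
# marks a model text-only; unknown models and lookup failures pass through.
# Hedged sketch with a hypothetical name; the real implementation lives in
# ``app.services.provider_capabilities`` and also maps provider/base-model.


def _sketch_is_known_text_only(model_name: str) -> bool:
    import litellm

    try:
        info = litellm.get_model_info(model=model_name)
    except Exception:
        return False  # unknown or transient failure: let the provider decide
    return info.get("supports_vision") is False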
def test_safety_net_does_not_fire_for_azure_gpt_4o():
"""Regression: ``azure/gpt-4o`` (and the GPT-5.x variants) is
vision-capable per LiteLLM's model map. The previous round's
blanket-False default blocked it; the new predicate must NOT mark
it text-only."""
assert (
is_known_text_only_chat_model(
provider="AZURE_OPENAI",
model_name="my-azure-deployment",
base_model="gpt-4o",
)
is False
)
def test_safety_net_does_not_fire_for_unknown_model():
"""Default-pass on unknown — the safety net only blocks definitive
text-only confirmations. A freshly added third-party model that
LiteLLM doesn't know about must flow through to the provider."""
assert (
is_known_text_only_chat_model(
provider="CUSTOM",
custom_provider="brand_new_proxy",
model_name="brand-new-model-x9",
)
is False
)
def test_safety_net_does_not_fire_when_lookup_raises(monkeypatch):
"""Transient ``litellm.get_model_info`` exception ≠ block. The
helper swallows the error and treats it as 'unknown', returning False."""
import app.services.provider_capabilities as pc
def _raise(**_kwargs):
raise RuntimeError("intentional test failure")
monkeypatch.setattr(pc.litellm, "get_model_info", _raise)
assert (
is_known_text_only_chat_model(
provider="OPENAI",
model_name="gpt-4o",
)
is False
)
def test_safety_net_fires_only_on_explicit_false(monkeypatch):
"""Stub LiteLLM to assert the only path that returns True is the
explicit ``supports_vision=False`` case. Anything else (True,
None, missing key) returns False from the predicate."""
import app.services.provider_capabilities as pc
def _info_explicit_false(**_kwargs):
return {"supports_vision": False, "max_input_tokens": 8192}
monkeypatch.setattr(pc.litellm, "get_model_info", _info_explicit_false)
assert (
is_known_text_only_chat_model(
provider="OPENAI",
model_name="text-only-stub",
)
is True
)
def _info_true(**_kwargs):
return {"supports_vision": True}
monkeypatch.setattr(pc.litellm, "get_model_info", _info_true)
assert (
is_known_text_only_chat_model(
provider="OPENAI",
model_name="vision-stub",
)
is False
)
def _info_missing(**_kwargs):
return {"max_input_tokens": 8192}
monkeypatch.setattr(pc.litellm, "get_model_info", _info_missing)
assert (
is_known_text_only_chat_model(
provider="OPENAI",
model_name="missing-key-stub",
)
is False
)

View file

@ -0,0 +1,398 @@
"""Unit tests for video-presentation Celery task billing integration.
Mirrors ``test_podcast_billing.py`` for the video-presentation task.
Validates the same wrap-graph-in-billable_call pattern and ensures the
larger ``QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS`` reservation is
threaded through.
Coverage:
* Free config: graph runs, ``billable_call`` invoked with the video
reserve override.
* Premium config: same wiring with ``billing_tier='premium'``.
* Quota denial: graph not invoked, row FAILED, reason code surfaced.
* Settlement failure: row FAILED with ``billing_settlement_failed``.
* Resolver failure: row FAILED with ``billing_resolution_failed``.
"""
from __future__ import annotations
import contextlib
from types import SimpleNamespace
from typing import Any
from uuid import uuid4
import pytest
pytestmark = pytest.mark.unit
# ---------------------------------------------------------------------------
# Fakes
# ---------------------------------------------------------------------------
class _FakeExecResult:
def __init__(self, obj):
self._obj = obj
def scalars(self):
return self
def first(self):
return self._obj
def filter(self, *_args, **_kwargs):
return self
class _FakeSession:
def __init__(self, video):
self._video = video
self.commit_count = 0
async def execute(self, _stmt):
return _FakeExecResult(self._video)
async def commit(self):
self.commit_count += 1
async def __aenter__(self):
return self
async def __aexit__(self, *args):
return None
class _FakeSessionMaker:
def __init__(self, session: _FakeSession):
self._session = session
def __call__(self):
return self._session
def _make_video(video_id: int = 11, thread_id: int = 99) -> SimpleNamespace:
return SimpleNamespace(
id=video_id,
title="Test Presentation",
thread_id=thread_id,
status=None,
slides=None,
scene_codes=None,
)
_CALL_LOG: list[dict[str, Any]] = []
@contextlib.asynccontextmanager
async def _ok_billable_call(**kwargs):
_CALL_LOG.append(kwargs)
yield SimpleNamespace()
@contextlib.asynccontextmanager
async def _denying_billable_call(**kwargs):
from app.services.billable_calls import QuotaInsufficientError
_CALL_LOG.append(kwargs)
raise QuotaInsufficientError(
usage_type=kwargs.get("usage_type", "?"),
used_micros=5_000_000,
limit_micros=5_000_000,
remaining_micros=0,
)
yield SimpleNamespace() # pragma: no cover
@contextlib.asynccontextmanager
async def _settlement_failing_billable_call(**kwargs):
from app.services.billable_calls import BillingSettlementError
_CALL_LOG.append(kwargs)
yield SimpleNamespace()
raise BillingSettlementError(
usage_type=kwargs.get("usage_type", "?"),
user_id=kwargs["user_id"],
cause=RuntimeError("finalize failed"),
)
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
@pytest.fixture(autouse=True)
def _reset_call_log():
_CALL_LOG.clear()
yield
_CALL_LOG.clear()
@pytest.mark.asyncio
async def test_billable_call_invoked_with_correct_kwargs_for_free_config(monkeypatch):
from app.config import config as app_config
from app.db import VideoPresentationStatus
from app.tasks.celery_tasks import video_presentation_tasks
video = _make_video(video_id=11, thread_id=99)
session = _FakeSession(video)
monkeypatch.setattr(
video_presentation_tasks,
"get_celery_session_maker",
lambda: _FakeSessionMaker(session),
)
user_id = uuid4()
async def _fake_resolver(sess, search_space_id, *, thread_id=None):
assert search_space_id == 777
assert thread_id == 99
return user_id, "free", "openrouter/some-free-model"
monkeypatch.setattr(
video_presentation_tasks,
"_resolve_agent_billing_for_search_space",
_fake_resolver,
)
monkeypatch.setattr(video_presentation_tasks, "billable_call", _ok_billable_call)
async def _fake_graph_invoke(state, config):
return {"slides": [], "slide_audio_results": [], "slide_scene_codes": []}
monkeypatch.setattr(
video_presentation_tasks.video_presentation_graph,
"ainvoke",
_fake_graph_invoke,
)
result = await video_presentation_tasks._generate_video_presentation(
video_presentation_id=11,
source_content="content",
search_space_id=777,
user_prompt=None,
)
assert result["status"] == "ready"
assert result["video_presentation_id"] == 11
assert video.status == VideoPresentationStatus.READY
assert len(_CALL_LOG) == 1
call = _CALL_LOG[0]
assert call["user_id"] == user_id
assert call["search_space_id"] == 777
assert call["billing_tier"] == "free"
assert call["base_model"] == "openrouter/some-free-model"
assert call["usage_type"] == "video_presentation_generation"
assert (
call["quota_reserve_micros_override"]
== app_config.QUOTA_DEFAULT_VIDEO_PRESENTATION_RESERVE_MICROS
)
# Background artifact audit rows intentionally omit the TokenUsage.thread_id
# FK to avoid coupling Celery audit commits to an active chat transaction.
assert "thread_id" not in call
assert call["call_details"] == {
"video_presentation_id": 11,
"title": "Test Presentation",
"thread_id": 99,
}
assert callable(call["billable_session_factory"])
@pytest.mark.asyncio
async def test_billable_call_invoked_with_premium_tier(monkeypatch):
from app.tasks.celery_tasks import video_presentation_tasks
video = _make_video()
session = _FakeSession(video)
monkeypatch.setattr(
video_presentation_tasks,
"get_celery_session_maker",
lambda: _FakeSessionMaker(session),
)
user_id = uuid4()
async def _fake_resolver(sess, search_space_id, *, thread_id=None):
return user_id, "premium", "gpt-5.4"
monkeypatch.setattr(
video_presentation_tasks,
"_resolve_agent_billing_for_search_space",
_fake_resolver,
)
monkeypatch.setattr(video_presentation_tasks, "billable_call", _ok_billable_call)
async def _fake_graph_invoke(state, config):
return {"slides": [], "slide_audio_results": [], "slide_scene_codes": []}
monkeypatch.setattr(
video_presentation_tasks.video_presentation_graph,
"ainvoke",
_fake_graph_invoke,
)
await video_presentation_tasks._generate_video_presentation(
video_presentation_id=11,
source_content="content",
search_space_id=777,
user_prompt=None,
)
assert _CALL_LOG[0]["billing_tier"] == "premium"
assert _CALL_LOG[0]["base_model"] == "gpt-5.4"
@pytest.mark.asyncio
async def test_quota_insufficient_marks_video_failed_and_skips_graph(monkeypatch):
from app.db import VideoPresentationStatus
from app.tasks.celery_tasks import video_presentation_tasks
video = _make_video(video_id=12)
session = _FakeSession(video)
monkeypatch.setattr(
video_presentation_tasks,
"get_celery_session_maker",
lambda: _FakeSessionMaker(session),
)
async def _fake_resolver(sess, search_space_id, *, thread_id=None):
return uuid4(), "premium", "gpt-5.4"
monkeypatch.setattr(
video_presentation_tasks,
"_resolve_agent_billing_for_search_space",
_fake_resolver,
)
monkeypatch.setattr(
video_presentation_tasks, "billable_call", _denying_billable_call
)
graph_invoked = []
async def _fake_graph_invoke(state, config):
graph_invoked.append(True)
return {}
monkeypatch.setattr(
video_presentation_tasks.video_presentation_graph,
"ainvoke",
_fake_graph_invoke,
)
result = await video_presentation_tasks._generate_video_presentation(
video_presentation_id=12,
source_content="content",
search_space_id=777,
user_prompt=None,
)
assert result == {
"status": "failed",
"video_presentation_id": 12,
"reason": "premium_quota_exhausted",
}
assert video.status == VideoPresentationStatus.FAILED
assert graph_invoked == []
@pytest.mark.asyncio
async def test_billing_settlement_failure_marks_video_failed(monkeypatch):
from app.db import VideoPresentationStatus
from app.tasks.celery_tasks import video_presentation_tasks
video = _make_video(video_id=14)
session = _FakeSession(video)
monkeypatch.setattr(
video_presentation_tasks,
"get_celery_session_maker",
lambda: _FakeSessionMaker(session),
)
async def _fake_resolver(sess, search_space_id, *, thread_id=None):
return uuid4(), "premium", "gpt-5.4"
monkeypatch.setattr(
video_presentation_tasks,
"_resolve_agent_billing_for_search_space",
_fake_resolver,
)
monkeypatch.setattr(
video_presentation_tasks,
"billable_call",
_settlement_failing_billable_call,
)
async def _fake_graph_invoke(state, config):
return {"slides": [], "slide_audio_results": [], "slide_scene_codes": []}
monkeypatch.setattr(
video_presentation_tasks.video_presentation_graph,
"ainvoke",
_fake_graph_invoke,
)
result = await video_presentation_tasks._generate_video_presentation(
video_presentation_id=14,
source_content="content",
search_space_id=777,
user_prompt=None,
)
assert result == {
"status": "failed",
"video_presentation_id": 14,
"reason": "billing_settlement_failed",
}
assert video.status == VideoPresentationStatus.FAILED
@pytest.mark.asyncio
async def test_resolver_failure_marks_video_failed(monkeypatch):
from app.db import VideoPresentationStatus
from app.tasks.celery_tasks import video_presentation_tasks
video = _make_video(video_id=13)
session = _FakeSession(video)
monkeypatch.setattr(
video_presentation_tasks,
"get_celery_session_maker",
lambda: _FakeSessionMaker(session),
)
async def _failing_resolver(sess, search_space_id, *, thread_id=None):
raise ValueError("Search space 777 not found")
monkeypatch.setattr(
video_presentation_tasks,
"_resolve_agent_billing_for_search_space",
_failing_resolver,
)
graph_invoked = []
async def _fake_graph_invoke(state, config):
graph_invoked.append(True)
return {}
monkeypatch.setattr(
video_presentation_tasks.video_presentation_graph,
"ainvoke",
_fake_graph_invoke,
)
result = await video_presentation_tasks._generate_video_presentation(
video_presentation_id=13,
source_content="content",
search_space_id=777,
user_prompt=None,
)
assert result == {
"status": "failed",
"video_presentation_id": 13,
"reason": "billing_resolution_failed",
}
assert video.status == VideoPresentationStatus.FAILED
assert graph_invoked == []

View file

@ -1,9 +1,21 @@
import inspect
import json
import logging
import re
from pathlib import Path
import pytest
import app.tasks.chat.stream_new_chat as stream_new_chat_module
from app.agents.new_chat.errors import BusyError
from app.agents.new_chat.middleware.busy_mutex import request_cancel, reset_cancel
from app.tasks.chat.stream_new_chat import (
StreamResult,
_classify_stream_exception,
_contract_enforcement_active,
_evaluate_file_contract_outcome,
_extract_resolved_file_path,
_log_chat_stream_error,
_tool_output_has_error,
)
@ -17,6 +29,39 @@ def test_tool_output_error_detection():
assert not _tool_output_has_error({"result": "Updated file /notes.md"})
def test_extract_resolved_file_path_prefers_structured_path():
assert (
_extract_resolved_file_path(
tool_name="write_file",
tool_output={"status": "completed", "path": "/docs/note.md"},
tool_input=None,
)
== "/docs/note.md"
)
def test_extract_resolved_file_path_falls_back_to_tool_input():
assert (
_extract_resolved_file_path(
tool_name="edit_file",
tool_output={"status": "completed", "result": "updated"},
tool_input={"file_path": "/docs/edited.md"},
)
== "/docs/edited.md"
)
def test_extract_resolved_file_path_does_not_parse_result_text():
assert (
_extract_resolved_file_path(
tool_name="write_file",
tool_output={"result": "Updated file /docs/from-text.md"},
tool_input=None,
)
is None
)
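# The three extraction tests above pin a strict preference order: structured
# ``path`` from the tool output wins, the tool input's ``file_path`` is the
# fallback, and free-form result text is never parsed. A hedged sketch of
# that order (hypothetical name; the real helper also receives ``tool_name``):


def _sketch_extract_resolved_file_path(tool_output, tool_input):
    if isinstance(tool_output, dict) and tool_output.get("path"):
        return tool_output["path"]
    if isinstance(tool_input, dict) and tool_input.get("file_path"):
        return tool_input["file_path"]
    return None  # never scrape paths out of human-readable result text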
def test_file_write_contract_outcome_reasons():
result = StreamResult(intent_detected="file_write")
passed, reason = _evaluate_file_contract_outcome(result)
@ -45,3 +90,507 @@ def test_contract_enforcement_local_only():
result.filesystem_mode = "cloud"
assert not _contract_enforcement_active(result)
def _extract_chat_stream_payload(record_message: str) -> dict:
prefix = "[chat_stream_error] "
assert record_message.startswith(prefix)
return json.loads(record_message[len(prefix) :])
def test_unified_chat_stream_error_log_schema(caplog):
with caplog.at_level(logging.INFO, logger="app.tasks.chat.stream_new_chat"):
_log_chat_stream_error(
flow="new",
error_kind="server_error",
error_code="SERVER_ERROR",
severity="warn",
is_expected=False,
request_id="req-123",
thread_id=101,
search_space_id=202,
user_id="user-1",
message="Error during chat: boom",
)
record = next(r for r in caplog.records if "[chat_stream_error]" in r.message)
payload = _extract_chat_stream_payload(record.message)
required_keys = {
"event",
"flow",
"error_kind",
"error_code",
"severity",
"is_expected",
"request_id",
"thread_id",
"search_space_id",
"user_id",
"message",
}
assert required_keys.issubset(payload.keys())
assert payload["event"] == "chat_stream_error"
assert payload["flow"] == "new"
assert payload["error_code"] == "SERVER_ERROR"
def test_premium_quota_uses_unified_chat_stream_log_shape(caplog):
with caplog.at_level(logging.INFO, logger="app.tasks.chat.stream_new_chat"):
_log_chat_stream_error(
flow="resume",
error_kind="premium_quota_exhausted",
error_code="PREMIUM_QUOTA_EXHAUSTED",
severity="info",
is_expected=True,
request_id="req-premium",
thread_id=303,
search_space_id=404,
user_id="user-2",
message="Buy more tokens to continue with this model, or switch to a free model",
extra={"auto_fallback": False},
)
record = next(r for r in caplog.records if "[chat_stream_error]" in r.message)
payload = _extract_chat_stream_payload(record.message)
assert payload["event"] == "chat_stream_error"
assert payload["error_kind"] == "premium_quota_exhausted"
assert payload["error_code"] == "PREMIUM_QUOTA_EXHAUSTED"
assert payload["flow"] == "resume"
assert payload["is_expected"] is True
assert payload["auto_fallback"] is False
def test_stream_error_emission_keeps_machine_error_codes():
source = inspect.getsource(stream_new_chat_module)
format_error_calls = re.findall(r"format_error\(", source)
emitted_error_codes = set(re.findall(r'error_code="([A-Z_]+)"', source))
# All stream paths should route through one shared terminal error emitter.
assert len(format_error_calls) == 1
assert {
"PREMIUM_QUOTA_EXHAUSTED",
"SERVER_ERROR",
}.issubset(emitted_error_codes)
assert 'flow: Literal["new", "regenerate"] = "new"' in source
assert "_emit_stream_terminal_error" in source
assert "flow=flow" in source
assert 'flow="resume"' in source
def test_stream_exception_classifies_rate_limited():
exc = Exception(
'{"error":{"type":"rate_limit_error","message":"Rate limited. Please try again later."}}'
)
kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
exc, flow_label="chat"
)
assert kind == "rate_limited"
assert code == "RATE_LIMITED"
assert severity == "warn"
assert is_expected is True
assert "temporarily rate-limited" in user_message
assert extra is None
def test_stream_exception_classifies_openrouter_429_payload():
exc = Exception(
'OpenrouterException - {"error":{"message":"Provider returned error","code":429,'
'"metadata":{"raw":"foo is temporarily rate-limited upstream"}}}'
)
kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
exc, flow_label="chat"
)
assert kind == "rate_limited"
assert code == "RATE_LIMITED"
assert severity == "warn"
assert is_expected is True
assert "temporarily rate-limited" in user_message
assert extra is None
@pytest.mark.asyncio
async def test_preflight_swallows_non_rate_limit_errors_and_re_raises_429(monkeypatch):
"""``_preflight_llm`` is best-effort.
- On rate-limit shaped exceptions (provider 429) it MUST re-raise so the
caller can drive the cooldown/repin branch.
- On any other transient failure it MUST swallow the error so the normal
stream path continues without surfacing preflight noise to the user.
"""
from types import SimpleNamespace
from app.tasks.chat.stream_new_chat import _preflight_llm
class _RateLimitedError(Exception):
"""Class-name carries 'RateLimit' so _is_provider_rate_limited triggers."""
rate_calls: list[dict] = []
other_calls: list[dict] = []
async def _fake_acompletion_429(**kwargs):
rate_calls.append(kwargs)
raise _RateLimitedError("simulated 429")
async def _fake_acompletion_other(**kwargs):
other_calls.append(kwargs)
raise RuntimeError("some unrelated transient failure")
fake_llm = SimpleNamespace(
model="openrouter/google/gemma-4-31b-it:free",
api_key="test",
api_base=None,
)
import litellm # type: ignore[import-not-found]
monkeypatch.setattr(litellm, "acompletion", _fake_acompletion_429)
with pytest.raises(_RateLimitedError):
await _preflight_llm(fake_llm)
assert len(rate_calls) == 1
assert rate_calls[0]["max_tokens"] == 1
assert rate_calls[0]["stream"] is False
monkeypatch.setattr(litellm, "acompletion", _fake_acompletion_other)
# MUST NOT raise: non-rate-limit failures are swallowed.
await _preflight_llm(fake_llm)
assert len(other_calls) == 1
@pytest.mark.asyncio
async def test_preflight_skipped_for_auto_router_model():
"""Router-mode ``model='auto'`` has no single deployment to ping; the
LiteLLM router itself owns per-deployment rate-limit accounting, so the
preflight helper must short-circuit instead of issuing a probe."""
from types import SimpleNamespace
from app.tasks.chat.stream_new_chat import _preflight_llm
fake_llm = SimpleNamespace(model="auto", api_key="x", api_base=None)
# Should return without raising or making any LiteLLM call.
await _preflight_llm(fake_llm)
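# Taken together, the two preflight tests imply roughly this shape: skip the
# probe entirely for router-mode ``model="auto"``, otherwise issue a 1-token
# non-streaming completion, re-raise anything rate-limit shaped, and swallow
# every other failure. Hedged sketch; the real helper defers the 429 decision
# to ``_is_provider_rate_limited`` and threads provider-specific kwargs.


async def _sketch_preflight_llm(llm) -> None:
    import litellm

    if llm.model == "auto":
        return  # the LiteLLM router owns per-deployment rate-limit accounting
    try:
        await litellm.acompletion(
            model=llm.model,
            messages=[{"role": "user", "content": "ping"}],
            api_key=llm.api_key,
            api_base=llm.api_base,
            max_tokens=1,
            stream=False,
        )
    except Exception as exc:
        if "ratelimit" in type(exc).__name__.lower() or "429" in str(exc):
            raise  # let the caller drive the cooldown / re-pin branch
        # Best-effort probe: all other transient failures are swallowed.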
@pytest.mark.asyncio
async def test_settle_speculative_agent_build_swallows_exceptions():
"""``_settle_speculative_agent_build`` MUST always return cleanly so the
caller can safely re-touch the request-scoped session afterwards.
The helper guards the parallel preflight + agent-build path: when the
speculative build is being discarded (429 or non-429 preflight failure)
we await it solely to release any in-flight ``AsyncSession`` usage
the build's outcome is irrelevant. Any exception (including
``CancelledError``) leaking out would skip the caller's recovery flow
and re-introduce the very session-concurrency hazard the helper exists
to prevent.
"""
import asyncio
from app.tasks.chat.stream_new_chat import _settle_speculative_agent_build
async def _raises() -> None:
raise RuntimeError("speculative build crashed")
async def _succeeds() -> str:
return "agent"
async def _slow() -> None:
await asyncio.sleep(0.05)
for coro in (_raises(), _succeeds(), _slow()):
task = asyncio.create_task(coro)
await _settle_speculative_agent_build(task)
assert task.done()
@pytest.mark.asyncio
async def test_settle_speculative_agent_build_handles_already_done_task():
"""Done tasks (success or failure) must still be settled without raising."""
import asyncio
from app.tasks.chat.stream_new_chat import _settle_speculative_agent_build
async def _ok() -> str:
return "ok"
async def _bad() -> None:
raise ValueError("nope")
ok_task = asyncio.create_task(_ok())
bad_task = asyncio.create_task(_bad())
# Drive both to completion before settling.
await asyncio.sleep(0)
await asyncio.sleep(0)
await _settle_speculative_agent_build(ok_task)
await _settle_speculative_agent_build(bad_task)
assert ok_task.result() == "ok"
# ``bad_task`` exception was consumed by the settle helper; calling
# ``.exception()`` after the fact must still return the original error
# (the helper observes it but doesn't clear it).
assert isinstance(bad_task.exception(), ValueError)
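# The settle helper's whole contract is "await the task, observe any outcome,
# never raise". A hedged, minimal sketch (the production helper may also log
# the discarded failure):


async def _sketch_settle_speculative_agent_build(task) -> None:
    try:
        await task
    except BaseException:  # CancelledError included by design
        pass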
def test_stream_exception_classifies_thread_busy():
exc = BusyError(request_id="thread-123")
kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
exc, flow_label="chat"
)
assert kind == "thread_busy"
assert code == "THREAD_BUSY"
assert severity == "warn"
assert is_expected is True
assert "still finishing for this thread" in user_message
assert extra is None
def test_stream_exception_classifies_thread_busy_from_message():
exc = Exception("Thread is busy with another request")
kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
exc, flow_label="chat"
)
assert kind == "thread_busy"
assert code == "THREAD_BUSY"
assert severity == "warn"
assert is_expected is True
assert "still finishing for this thread" in user_message
assert extra is None
def test_stream_exception_classifies_turn_cancelling_when_cancel_requested():
thread_id = "thread-cancelling-1"
reset_cancel(thread_id)
request_cancel(thread_id)
exc = BusyError(request_id=thread_id)
kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
exc, flow_label="chat"
)
assert kind == "thread_busy"
assert code == "TURN_CANCELLING"
assert severity == "info"
assert is_expected is True
assert "stopping" in user_message
assert isinstance(extra, dict)
assert "retry_after_ms" in extra
def test_premium_classification_is_error_code_driven():
classifier_path = (
Path(__file__).resolve().parents[3]
/ "surfsense_web/lib/chat/chat-error-classifier.ts"
)
source = classifier_path.read_text(encoding="utf-8")
assert "PREMIUM_KEYWORDS" not in source
assert "RATE_LIMIT_KEYWORDS" not in source
assert "normalized.includes(" not in source
assert 'if (errorCode === "PREMIUM_QUOTA_EXHAUSTED") {' in source
def test_stream_terminal_error_handler_has_pre_accept_soft_rollback_hook():
page_path = (
Path(__file__).resolve().parents[3]
/ "surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx"
)
source = page_path.read_text(encoding="utf-8")
assert "onPreAcceptFailure?: () => Promise<void>;" in source
assert "if (!accepted) {" in source
assert "await onPreAcceptFailure?.();" in source
assert "await onAcceptedStreamError?.();" in source
assert "setMessages((prev) => prev.filter((m) => m.id !== userMsgId));" in source
assert "setMessageDocumentsMap((prev) => {" in source
def test_toast_only_pre_accept_policy_has_no_inline_failed_marker():
user_message_path = (
Path(__file__).resolve().parents[3]
/ "surfsense_web/components/assistant-ui/user-message.tsx"
)
source = user_message_path.read_text(encoding="utf-8")
assert "Not sent. Edit and retry." not in source
assert "failed_pre_accept" not in source
def test_network_send_failures_use_unified_retry_toast_message():
classifier_path = (
Path(__file__).resolve().parents[3]
/ "surfsense_web/lib/chat/chat-error-classifier.ts"
)
classifier_source = classifier_path.read_text(encoding="utf-8")
request_errors_path = (
Path(__file__).resolve().parents[3]
/ "surfsense_web/lib/chat/chat-request-errors.ts"
)
request_errors_source = request_errors_path.read_text(encoding="utf-8")
assert '"send_failed_pre_accept"' in classifier_source
assert 'errorCode === "SEND_FAILED_PRE_ACCEPT"' in classifier_source
assert 'errorCode === "TURN_CANCELLING"' in classifier_source
assert "if (withCode.code) return withCode.code;" in classifier_source
assert 'userMessage: "Message not sent. Please retry."' in classifier_source
assert 'userMessage: "Connection issue. Please try again."' in classifier_source
assert "const passthroughCodes = new Set([" in request_errors_source
assert '"PREMIUM_QUOTA_EXHAUSTED"' in request_errors_source
assert '"THREAD_BUSY"' in request_errors_source
assert '"TURN_CANCELLING"' in request_errors_source
assert '"AUTH_EXPIRED"' in request_errors_source
assert '"UNAUTHORIZED"' in request_errors_source
assert '"RATE_LIMITED"' in request_errors_source
assert '"NETWORK_ERROR"' in request_errors_source
assert '"STREAM_PARSE_ERROR"' in request_errors_source
assert '"TOOL_EXECUTION_ERROR"' in request_errors_source
assert '"PERSIST_MESSAGE_FAILED"' in request_errors_source
assert '"SERVER_ERROR"' in request_errors_source
assert "passthroughCodes.has(existingCode)" in request_errors_source
assert 'errorCode: "SEND_FAILED_PRE_ACCEPT"' in request_errors_source
assert 'errorCode: "NETWORK_ERROR"' not in request_errors_source
assert "Failed to start chat. Please try again." not in classifier_source
def test_pre_post_accept_abort_contract_exists_for_new_resume_regenerate_flows():
page_path = (
Path(__file__).resolve().parents[3]
/ "surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx"
)
source = page_path.read_text(encoding="utf-8")
# Each flow tracks accepted boundary and passes it into shared terminal handling.
# The acceptance boundary is still meaningful post-refactor: it gates
# local-state cleanup (onPreAcceptFailure path) and lets the shared
# terminal handler distinguish pre-accept aborts from in-stream errors.
assert "let newAccepted = false;" in source
assert "let resumeAccepted = false;" in source
assert "let regenerateAccepted = false;" in source
assert "accepted: newAccepted," in source
assert "accepted: resumeAccepted," in source
assert "accepted: regenerateAccepted," in source
# NOTE: The FE-side persistence guards previously asserted here
# ("if (!resumeAccepted) return;", "if (!regenerateAccepted) return;",
# "if (newAccepted && !userPersisted) {") have been intentionally
# removed by the SSE-based message-id handshake refactor. Persistence
# is now server-authoritative: persist_user_turn / persist_assistant_shell
# run inside stream_new_chat / stream_resume_chat unconditionally and
# the FE consumes data-user-message-id / data-assistant-message-id
# SSE events to learn the canonical primary keys. There is therefore
# no FE call-site to guard, and the shared terminal handler relies
# purely on the `accepted` field above (forwarded to onAbort /
# onAcceptedStreamError) to drive UI cleanup. See
# tests/integration/chat/test_message_id_sse.py for the new
# cross-tier ID coherence guarantees.
# The TURN_CANCELLING / THREAD_BUSY retry plumbing is independent
# of the persistence refactor and must still exist on every
# start-stream fetch.
assert "const fetchWithTurnCancellingRetry = useCallback(" in source
assert "computeFallbackTurnCancellingRetryDelay" in source
assert 'withMeta.errorCode === "TURN_CANCELLING"' in source
assert 'withMeta.errorCode === "THREAD_BUSY"' in source
assert "await fetchWithTurnCancellingRetry(() =>" in source
def test_cancel_active_turn_route_contract_exists():
routes_path = (
Path(__file__).resolve().parents[3]
/ "surfsense_backend/app/routes/new_chat_routes.py"
)
source = routes_path.read_text(encoding="utf-8")
assert '@router.post(\n "/threads/{thread_id}/cancel-active-turn",' in source
assert "response_model=CancelActiveTurnResponse" in source
assert 'status="cancelling",' in source
assert 'error_code="TURN_CANCELLING",' in source
assert "retry_after_ms=retry_after_ms if retry_after_ms > 0 else None," in source
assert "retry_after_at=" in source
assert 'status="idle",' in source
assert 'error_code="NO_ACTIVE_TURN",' in source
def test_turn_status_route_contract_exists():
routes_path = (
Path(__file__).resolve().parents[3]
/ "surfsense_backend/app/routes/new_chat_routes.py"
)
source = routes_path.read_text(encoding="utf-8")
assert '@router.get(\n "/threads/{thread_id}/turn-status",' in source
assert "response_model=TurnStatusResponse" in source
assert "_build_turn_status_payload(thread_id)" in source
assert "Permission.CHATS_READ.value" in source
assert "_raise_if_thread_busy_for_start(" in source
def test_turn_cancelling_retry_policy_contract_exists():
routes_path = (
Path(__file__).resolve().parents[3]
/ "surfsense_backend/app/routes/new_chat_routes.py"
)
source = routes_path.read_text(encoding="utf-8")
assert "TURN_CANCELLING_INITIAL_DELAY_MS = 200" in source
assert "TURN_CANCELLING_BACKOFF_FACTOR = 2" in source
assert "TURN_CANCELLING_MAX_DELAY_MS = 1500" in source
assert "def _compute_turn_cancelling_retry_delay(" in source
assert "retry-after-ms" in source
assert '"Retry-After"' in source
assert '"errorCode": "TURN_CANCELLING"' in source
def test_turn_status_sse_contract_exists():
stream_source = (
Path(__file__).resolve().parents[3]
/ "surfsense_backend/app/tasks/chat/stream_new_chat.py"
).read_text(encoding="utf-8")
state_source = (
Path(__file__).resolve().parents[3]
/ "surfsense_web/lib/chat/streaming-state.ts"
).read_text(encoding="utf-8")
pipeline_source = (
Path(__file__).resolve().parents[3]
/ "surfsense_web/lib/chat/stream-pipeline.ts"
).read_text(encoding="utf-8")
assert '"turn-status"' in stream_source
assert '"status": "busy"' in stream_source
assert '"status": "idle"' in stream_source
assert 'type: "data-turn-status"' in state_source
assert 'case "data-turn-status":' in pipeline_source
assert "end_turn(str(chat_id))" in stream_source
def test_chat_deepagent_forwards_resolved_model_name_to_both_builders():
"""Regression guard: both system-prompt builders in chat_deepagent.py
must receive ``model_name=_resolve_prompt_model_name(...)`` so the
provider-variant dispatch can render the right ``<provider_hints>``
block. Without this the prompt silently falls back to the empty
``"default"`` variant the original bug being fixed.
This test mirrors :func:`test_stream_error_emission_keeps_machine_error_codes`
in style: it inspects module source text + a regex to enforce the
call-site shape, not just the wrapper layer (the wrappers already
forward ``model_name`` correctly, so testing them would not catch
the actual missed plumbing).
"""
import app.agents.new_chat.chat_deepagent as chat_deepagent_module
source = inspect.getsource(chat_deepagent_module)
# Helper itself must be defined.
assert "def _resolve_prompt_model_name(" in source
# Both builder calls must forward the resolved model name. Match
# across newlines + whitespace because the kwargs are split over
# multiple lines.
pattern = re.compile(
r"build_(?:surfsense|configurable)_system_prompt\([^)]*"
r"model_name=_resolve_prompt_model_name\(",
re.DOTALL,
)
matches = pattern.findall(source)
assert len(matches) == 2, (
"Expected both system-prompt builder call sites to forward "
"`model_name=_resolve_prompt_model_name(...)`, found "
f"{len(matches)}"
)