feat: unified credits and their cost calculations

This commit is contained in:
DESKTOP-RTLN3BA\$punk 2026-05-02 14:34:23 -07:00
parent 451a98936e
commit ae9d36d77f
61 changed files with 5835 additions and 272 deletions

@@ -0,0 +1,436 @@
"""Unit tests for ``_resolve_agent_billing_for_search_space``.
Validates the resolver used by Celery podcast/video tasks to compute
``(owner_user_id, billing_tier, base_model)`` from a search space and its
agent LLM config. The resolver mirrors chat's billing-resolution pattern at
``stream_new_chat.py:2294-2351`` and is the single integration point that
prevents Auto-mode podcast/video from leaking premium credit.
Coverage:
* Auto mode + ``thread_id`` set, pin resolves to a negative-id premium
global → returns ``("premium", <base_model>)``.
* Auto mode + ``thread_id`` set, pin resolves to a negative-id free
global → returns ``("free", <base_model>)``.
* Auto mode + ``thread_id`` set, pin resolves to a positive-id BYOK config
→ always ``"free"``.
* Auto mode + ``thread_id=None`` → fallback to ``("free", "auto")`` without
hitting the pin service.
* Negative id (no Auto) → uses ``get_global_llm_config``'s
``billing_tier``.
* Positive id (user BYOK) → always ``"free"``.
* Search space not found → raises ``ValueError``.
* ``agent_llm_id`` is None → raises ``ValueError``.
"""
from __future__ import annotations
from dataclasses import dataclass
from types import SimpleNamespace
from uuid import UUID, uuid4
import pytest
pytestmark = pytest.mark.unit
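# ---------------------------------------------------------------------------
# Orientation sketch (not exercised by the suite): the tier decision the
# resolver is expected to make, restated as a pure function over the already
# resolved config id and the matching global config's ``billing_tier``. The
# helper and its name are illustrative only; the real resolver also loads the
# search space / NewLLMConfig rows and derives ``base_model``.
# ---------------------------------------------------------------------------
def _expected_tier_sketch(
    resolved_llm_config_id: int | None, global_billing_tier: str | None
) -> str:
    if resolved_llm_config_id is None or resolved_llm_config_id > 0:
        # Unresolved or user BYOK → never debited.
        return "free"
    # Negative ids are global configs; their billing_tier decides.
    return global_billing_tier or "free"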
# ---------------------------------------------------------------------------
# Fakes
# ---------------------------------------------------------------------------
class _FakeExecResult:
def __init__(self, obj):
self._obj = obj
def scalars(self):
return self
def first(self):
return self._obj
class _FakeSession:
"""Tiny AsyncSession stub.
``responses`` is a list of objects to return from successive
``execute()`` calls (in order). The resolver makes at most two
``execute()`` calls (search-space lookup, then optionally NewLLMConfig
lookup), so two queued responses cover the matrix.
"""
def __init__(self, responses: list):
self._responses = list(responses)
async def execute(self, _stmt):
if not self._responses:
return _FakeExecResult(None)
return _FakeExecResult(self._responses.pop(0))
async def commit(self) -> None:
pass
@dataclass
class _FakePinResolution:
resolved_llm_config_id: int
resolved_tier: str = "premium"
from_existing_pin: bool = False
def _make_search_space(*, agent_llm_id: int | None, user_id: UUID) -> SimpleNamespace:
return SimpleNamespace(
id=42,
agent_llm_id=agent_llm_id,
user_id=user_id,
)
def _make_byok_config(
*, id_: int, base_model: str | None = None, model_name: str = "gpt-byok"
) -> SimpleNamespace:
return SimpleNamespace(
id=id_,
model_name=model_name,
litellm_params={"base_model": base_model} if base_model else {},
)
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_auto_mode_with_thread_id_resolves_to_premium_global(monkeypatch):
"""Auto + thread → pin service resolves to negative-id premium config →
resolver returns ``("premium", <base_model>)``."""
from app.services.billable_calls import _resolve_agent_billing_for_search_space
user_id = uuid4()
session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)])
# Mock the pin service to return a concrete premium config id.
async def _fake_resolve_pin(
sess,
*,
thread_id,
search_space_id,
user_id,
selected_llm_config_id,
force_repin_free=False,
):
assert selected_llm_config_id == 0
assert thread_id == 99
return _FakePinResolution(resolved_llm_config_id=-1, resolved_tier="premium")
# Mock global config lookup to return a premium entry.
def _fake_get_global(cfg_id):
if cfg_id == -1:
return {
"id": -1,
"model_name": "gpt-5.4",
"billing_tier": "premium",
"litellm_params": {"base_model": "gpt-5.4"},
}
return None
# Lazy imports inside the resolver — patch the *target* modules so the
# imported names resolve to our fakes.
import app.services.auto_model_pin_service as pin_module
import app.services.llm_service as llm_module
monkeypatch.setattr(
pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin
)
monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global)
owner, tier, base_model = await _resolve_agent_billing_for_search_space(
session, search_space_id=42, thread_id=99
)
assert owner == user_id
assert tier == "premium"
assert base_model == "gpt-5.4"
@pytest.mark.asyncio
async def test_auto_mode_with_thread_id_resolves_to_free_global(monkeypatch):
"""Auto + thread → pin returns negative-id free config → resolver
returns ``("free", <base_model>)``. Same path the pin service takes for
out-of-credit users (graceful degradation)."""
from app.services.billable_calls import _resolve_agent_billing_for_search_space
user_id = uuid4()
session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)])
async def _fake_resolve_pin(
sess,
*,
thread_id,
search_space_id,
user_id,
selected_llm_config_id,
force_repin_free=False,
):
return _FakePinResolution(resolved_llm_config_id=-3, resolved_tier="free")
def _fake_get_global(cfg_id):
if cfg_id == -3:
return {
"id": -3,
"model_name": "openrouter/free-model",
"billing_tier": "free",
"litellm_params": {"base_model": "openrouter/free-model"},
}
return None
import app.services.auto_model_pin_service as pin_module
import app.services.llm_service as llm_module
monkeypatch.setattr(
pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin
)
monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global)
owner, tier, base_model = await _resolve_agent_billing_for_search_space(
session, search_space_id=42, thread_id=99
)
assert owner == user_id
assert tier == "free"
assert base_model == "openrouter/free-model"
@pytest.mark.asyncio
async def test_auto_mode_with_thread_id_resolves_to_byok_is_free(monkeypatch):
"""Auto + thread → pin returns positive-id BYOK config → resolver
returns ``("free", ...)`` (BYOK is always free per
``AgentConfig.from_new_llm_config``)."""
from app.services.billable_calls import _resolve_agent_billing_for_search_space
user_id = uuid4()
search_space = _make_search_space(agent_llm_id=0, user_id=user_id)
byok_cfg = _make_byok_config(
id_=17, base_model="anthropic/claude-3-haiku", model_name="my-claude"
)
session = _FakeSession([search_space, byok_cfg])
async def _fake_resolve_pin(
sess,
*,
thread_id,
search_space_id,
user_id,
selected_llm_config_id,
force_repin_free=False,
):
return _FakePinResolution(resolved_llm_config_id=17, resolved_tier="free")
import app.services.auto_model_pin_service as pin_module
monkeypatch.setattr(
pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin
)
owner, tier, base_model = await _resolve_agent_billing_for_search_space(
session, search_space_id=42, thread_id=99
)
assert owner == user_id
assert tier == "free"
assert base_model == "anthropic/claude-3-haiku"
@pytest.mark.asyncio
async def test_auto_mode_without_thread_id_falls_back_to_free():
"""Auto + ``thread_id=None`` → ``("free", "auto")`` without invoking
the pin service. Forward-compat fallback for any future direct-API
entrypoint that doesn't have a chat thread."""
from app.services.billable_calls import _resolve_agent_billing_for_search_space
user_id = uuid4()
session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)])
owner, tier, base_model = await _resolve_agent_billing_for_search_space(
session, search_space_id=42, thread_id=None
)
assert owner == user_id
assert tier == "free"
assert base_model == "auto"
@pytest.mark.asyncio
async def test_auto_mode_pin_failure_falls_back_to_free(monkeypatch):
"""If the pin service raises ``ValueError`` (thread missing /
mismatched search space), the resolver should log and return free
rather than killing the whole task."""
from app.services.billable_calls import _resolve_agent_billing_for_search_space
user_id = uuid4()
session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)])
async def _fake_resolve_pin(*args, **kwargs):
raise ValueError("thread missing")
import app.services.auto_model_pin_service as pin_module
monkeypatch.setattr(
pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin
)
owner, tier, base_model = await _resolve_agent_billing_for_search_space(
session, search_space_id=42, thread_id=99
)
assert owner == user_id
assert tier == "free"
assert base_model == "auto"
@pytest.mark.asyncio
async def test_negative_id_premium_global_returns_premium(monkeypatch):
"""Explicit negative agent_llm_id → ``get_global_llm_config`` →
return its ``billing_tier``."""
from app.services.billable_calls import _resolve_agent_billing_for_search_space
user_id = uuid4()
session = _FakeSession([_make_search_space(agent_llm_id=-1, user_id=user_id)])
def _fake_get_global(cfg_id):
return {
"id": cfg_id,
"model_name": "gpt-5.4",
"billing_tier": "premium",
"litellm_params": {"base_model": "gpt-5.4"},
}
import app.services.llm_service as llm_module
monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global)
owner, tier, base_model = await _resolve_agent_billing_for_search_space(
session, search_space_id=42, thread_id=99
)
assert owner == user_id
assert tier == "premium"
assert base_model == "gpt-5.4"
@pytest.mark.asyncio
async def test_negative_id_free_global_returns_free(monkeypatch):
from app.services.billable_calls import _resolve_agent_billing_for_search_space
user_id = uuid4()
session = _FakeSession([_make_search_space(agent_llm_id=-2, user_id=user_id)])
def _fake_get_global(cfg_id):
return {
"id": cfg_id,
"model_name": "openrouter/some-free",
"billing_tier": "free",
"litellm_params": {"base_model": "openrouter/some-free"},
}
import app.services.llm_service as llm_module
monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global)
owner, tier, base_model = await _resolve_agent_billing_for_search_space(
session, search_space_id=42, thread_id=None
)
assert owner == user_id
assert tier == "free"
assert base_model == "openrouter/some-free"
@pytest.mark.asyncio
async def test_negative_id_missing_base_model_falls_back_to_model_name(monkeypatch):
"""When the global config has no ``litellm_params.base_model``, the
resolver falls back to ``model_name``, matching chat's behavior."""
from app.services.billable_calls import _resolve_agent_billing_for_search_space
user_id = uuid4()
session = _FakeSession([_make_search_space(agent_llm_id=-5, user_id=user_id)])
def _fake_get_global(cfg_id):
return {
"id": cfg_id,
"model_name": "fallback-model",
"billing_tier": "premium",
# No litellm_params.
}
import app.services.llm_service as llm_module
monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global)
_, tier, base_model = await _resolve_agent_billing_for_search_space(
session, search_space_id=42
)
assert tier == "premium"
assert base_model == "fallback-model"
@pytest.mark.asyncio
async def test_positive_id_byok_is_always_free():
"""Positive agent_llm_id → user-owned BYOK NewLLMConfig → always free,
regardless of underlying provider tier."""
from app.services.billable_calls import _resolve_agent_billing_for_search_space
user_id = uuid4()
search_space = _make_search_space(agent_llm_id=23, user_id=user_id)
byok_cfg = _make_byok_config(id_=23, base_model="anthropic/claude-3.5-sonnet")
session = _FakeSession([search_space, byok_cfg])
owner, tier, base_model = await _resolve_agent_billing_for_search_space(
session, search_space_id=42
)
assert owner == user_id
assert tier == "free"
assert base_model == "anthropic/claude-3.5-sonnet"
@pytest.mark.asyncio
async def test_positive_id_byok_missing_returns_free_with_empty_base_model():
"""If the BYOK config row is missing/deleted but the search space still
points at it, the resolver still returns free (no debit) with an empty
base_model; billable_call's premium path is skipped, so no harm is done."""
from app.services.billable_calls import _resolve_agent_billing_for_search_space
user_id = uuid4()
session = _FakeSession([_make_search_space(agent_llm_id=99, user_id=user_id)])
owner, tier, base_model = await _resolve_agent_billing_for_search_space(
session, search_space_id=42
)
assert owner == user_id
assert tier == "free"
assert base_model == ""
@pytest.mark.asyncio
async def test_search_space_not_found_raises_value_error():
from app.services.billable_calls import _resolve_agent_billing_for_search_space
session = _FakeSession([None])
with pytest.raises(ValueError, match="Search space"):
await _resolve_agent_billing_for_search_space(session, search_space_id=999)
@pytest.mark.asyncio
async def test_agent_llm_id_none_raises_value_error():
from app.services.billable_calls import _resolve_agent_billing_for_search_space
user_id = uuid4()
session = _FakeSession([_make_search_space(agent_llm_id=None, user_id=user_id)])
with pytest.raises(ValueError, match="agent_llm_id"):
await _resolve_agent_billing_for_search_space(session, search_space_id=42)

@@ -0,0 +1,432 @@
"""Unit tests for the ``billable_call`` async context manager.
Covers the per-call premium-credit lifecycle for image generation and
vision LLM extraction:
* Free configs bypass reserve/finalize but still write an audit row.
* Premium reserve denial raises ``QuotaInsufficientError`` (HTTP 402 in the
route layer).
* Successful premium calls reserve, yield the accumulator, then finalize
with the LiteLLM-reported actual cost and write an audit row.
* Failed premium calls release the reservation so credit isn't leaked.
* All quota DB ops happen inside their OWN ``shielded_async_session``,
isolating them from the caller's transaction (issue A).
"""
from __future__ import annotations
import contextlib
from typing import Any
from uuid import uuid4
import pytest
pytestmark = pytest.mark.unit
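# ---------------------------------------------------------------------------
# Unit convention assumed throughout this module: costs are tracked in USD
# micro-dollars, 1 USD == 1_000_000 micros. The tiny helper below is
# illustrative only (the suite never calls it); it just makes the magic
# numbers readable, e.g. a $0.04 image is 40_000 micros and the
# 5_000_000-micro limit in ``_FakeQuotaResult`` corresponds to $5.00.
# ---------------------------------------------------------------------------
def _usd_to_micros_sketch(usd: float) -> int:
    return round(usd * 1_000_000)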
# ---------------------------------------------------------------------------
# Fakes
# ---------------------------------------------------------------------------
class _FakeQuotaResult:
def __init__(
self,
*,
allowed: bool,
used: int = 0,
limit: int = 5_000_000,
remaining: int = 5_000_000,
) -> None:
self.allowed = allowed
self.used = used
self.limit = limit
self.remaining = remaining
class _FakeSession:
"""Minimal AsyncSession stub — record commits for assertion."""
def __init__(self) -> None:
self.committed = False
self.added: list[Any] = []
def add(self, obj: Any) -> None:
self.added.append(obj)
async def commit(self) -> None:
self.committed = True
async def close(self) -> None:
pass
@contextlib.asynccontextmanager
async def _fake_shielded_session():
s = _FakeSession()
_SESSIONS_USED.append(s)
yield s
_SESSIONS_USED: list[_FakeSession] = []
def _patch_isolation_layer(monkeypatch, *, reserve_result, finalize_result=None):
"""Wire fake reserve/finalize/release/session helpers."""
_SESSIONS_USED.clear()
reserve_calls: list[dict[str, Any]] = []
finalize_calls: list[dict[str, Any]] = []
release_calls: list[dict[str, Any]] = []
async def _fake_reserve(*, db_session, user_id, request_id, reserve_micros):
reserve_calls.append(
{
"user_id": user_id,
"reserve_micros": reserve_micros,
"request_id": request_id,
}
)
return reserve_result
async def _fake_finalize(
*, db_session, user_id, request_id, actual_micros, reserved_micros
):
finalize_calls.append(
{
"user_id": user_id,
"actual_micros": actual_micros,
"reserved_micros": reserved_micros,
}
)
return finalize_result or _FakeQuotaResult(allowed=True)
async def _fake_release(*, db_session, user_id, reserved_micros):
release_calls.append({"user_id": user_id, "reserved_micros": reserved_micros})
record_calls: list[dict[str, Any]] = []
async def _fake_record(session, **kwargs):
record_calls.append(kwargs)
return object()
monkeypatch.setattr(
"app.services.billable_calls.TokenQuotaService.premium_reserve",
_fake_reserve,
raising=False,
)
monkeypatch.setattr(
"app.services.billable_calls.TokenQuotaService.premium_finalize",
_fake_finalize,
raising=False,
)
monkeypatch.setattr(
"app.services.billable_calls.TokenQuotaService.premium_release",
_fake_release,
raising=False,
)
monkeypatch.setattr(
"app.services.billable_calls.shielded_async_session",
_fake_shielded_session,
raising=False,
)
monkeypatch.setattr(
"app.services.billable_calls.record_token_usage",
_fake_record,
raising=False,
)
return {
"reserve": reserve_calls,
"finalize": finalize_calls,
"release": release_calls,
"record": record_calls,
}
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_free_path_skips_reserve_but_writes_audit_row(monkeypatch):
from app.services.billable_calls import billable_call
spies = _patch_isolation_layer(
monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
)
user_id = uuid4()
async with billable_call(
user_id=user_id,
search_space_id=42,
billing_tier="free",
base_model="openai/gpt-image-1",
usage_type="image_generation",
) as acc:
# Simulate a captured cost — the accumulator is fed by the LiteLLM
# callback in real life; here we add it manually.
acc.add(
model="openai/gpt-image-1",
prompt_tokens=0,
completion_tokens=0,
total_tokens=0,
cost_micros=37_000,
call_kind="image_generation",
)
assert spies["reserve"] == []
assert spies["finalize"] == []
assert spies["release"] == []
# Free still audits.
assert len(spies["record"]) == 1
assert spies["record"][0]["usage_type"] == "image_generation"
assert spies["record"][0]["cost_micros"] == 37_000
@pytest.mark.asyncio
async def test_premium_reserve_denied_raises_quota_insufficient(monkeypatch):
from app.services.billable_calls import (
QuotaInsufficientError,
billable_call,
)
spies = _patch_isolation_layer(
monkeypatch,
reserve_result=_FakeQuotaResult(
allowed=False, used=5_000_000, limit=5_000_000, remaining=0
),
)
user_id = uuid4()
with pytest.raises(QuotaInsufficientError) as exc_info:
async with billable_call(
user_id=user_id,
search_space_id=42,
billing_tier="premium",
base_model="openai/gpt-image-1",
quota_reserve_micros_override=50_000,
usage_type="image_generation",
):
pytest.fail("body should not run when reserve is denied")
err = exc_info.value
assert err.usage_type == "image_generation"
assert err.used_micros == 5_000_000
assert err.limit_micros == 5_000_000
assert err.remaining_micros == 0
# Reserve was attempted, but no finalize/release on a denied reserve
# — the reservation never actually held credit.
assert len(spies["reserve"]) == 1
assert spies["finalize"] == []
assert spies["release"] == []
# Denied premium calls do NOT create an audit row (no work happened).
assert spies["record"] == []
@pytest.mark.asyncio
async def test_premium_success_finalizes_with_actual_cost(monkeypatch):
from app.services.billable_calls import billable_call
spies = _patch_isolation_layer(
monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
)
user_id = uuid4()
async with billable_call(
user_id=user_id,
search_space_id=42,
billing_tier="premium",
base_model="openai/gpt-image-1",
quota_reserve_micros_override=50_000,
usage_type="image_generation",
) as acc:
# LiteLLM callback would normally fill this — simulate $0.04 image.
acc.add(
model="openai/gpt-image-1",
prompt_tokens=0,
completion_tokens=0,
total_tokens=0,
cost_micros=40_000,
call_kind="image_generation",
)
assert len(spies["reserve"]) == 1
assert spies["reserve"][0]["reserve_micros"] == 50_000
assert len(spies["finalize"]) == 1
assert spies["finalize"][0]["actual_micros"] == 40_000
assert spies["finalize"][0]["reserved_micros"] == 50_000
assert spies["release"] == []
# And audit row written with the actual debited cost.
assert spies["record"][0]["cost_micros"] == 40_000
# Each quota op opened its OWN session — proves session isolation.
assert len(_SESSIONS_USED) >= 3
# finalize/reserve go through the stubbed TokenQuotaService.* fakes, which
# never call commit on our fake session; only the audit session commits, so
# it's enough to assert that at least one session committed.
assert any(s.committed for s in _SESSIONS_USED)
@pytest.mark.asyncio
async def test_premium_failure_releases_reservation(monkeypatch):
from app.services.billable_calls import billable_call
spies = _patch_isolation_layer(
monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
)
user_id = uuid4()
class _ProviderError(Exception):
pass
with pytest.raises(_ProviderError):
async with billable_call(
user_id=user_id,
search_space_id=42,
billing_tier="premium",
base_model="openai/gpt-image-1",
quota_reserve_micros_override=50_000,
usage_type="image_generation",
):
raise _ProviderError("OpenRouter 503")
assert len(spies["reserve"]) == 1
assert spies["finalize"] == []
# Failure path: release the held reservation.
assert len(spies["release"]) == 1
assert spies["release"][0]["reserved_micros"] == 50_000
@pytest.mark.asyncio
async def test_premium_uses_estimator_when_no_micros_override(monkeypatch):
"""When ``quota_reserve_micros_override`` is None we fall back to
``estimate_call_reserve_micros(base_model, quota_reserve_tokens)``.
Vision LLM calls take this path (token-priced models).
"""
from app.services.billable_calls import billable_call
spies = _patch_isolation_layer(
monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
)
captured_estimator_calls: list[dict[str, Any]] = []
def _fake_estimate(*, base_model, quota_reserve_tokens):
captured_estimator_calls.append(
{"base_model": base_model, "quota_reserve_tokens": quota_reserve_tokens}
)
return 12_345
monkeypatch.setattr(
"app.services.billable_calls.estimate_call_reserve_micros",
_fake_estimate,
raising=False,
)
user_id = uuid4()
async with billable_call(
user_id=user_id,
search_space_id=1,
billing_tier="premium",
base_model="openai/gpt-4o",
quota_reserve_tokens=4000,
usage_type="vision_extraction",
):
pass
assert captured_estimator_calls == [
{"base_model": "openai/gpt-4o", "quota_reserve_tokens": 4000}
]
assert spies["reserve"][0]["reserve_micros"] == 12_345
# ---------------------------------------------------------------------------
# Podcast / video-presentation usage_type coverage
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_free_podcast_path_audits_with_podcast_usage_type(monkeypatch):
"""Free podcast configs must skip reserve/finalize but still emit a
``TokenUsage`` row tagged ``usage_type='podcast_generation'`` so we
have full audit coverage of free-tier agent runs."""
from app.services.billable_calls import billable_call
spies = _patch_isolation_layer(
monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
)
user_id = uuid4()
async with billable_call(
user_id=user_id,
search_space_id=42,
billing_tier="free",
base_model="openrouter/some-free-model",
quota_reserve_micros_override=200_000,
usage_type="podcast_generation",
thread_id=99,
call_details={"podcast_id": 7, "title": "Test Podcast"},
) as acc:
# Two transcript LLM calls aggregated into one accumulator.
acc.add(
model="openrouter/some-free-model",
prompt_tokens=1500,
completion_tokens=8000,
total_tokens=9500,
cost_micros=0,
call_kind="chat",
)
assert spies["reserve"] == []
assert spies["finalize"] == []
assert spies["release"] == []
assert len(spies["record"]) == 1
row = spies["record"][0]
assert row["usage_type"] == "podcast_generation"
assert row["thread_id"] == 99
assert row["search_space_id"] == 42
assert row["call_details"] == {"podcast_id": 7, "title": "Test Podcast"}
@pytest.mark.asyncio
async def test_premium_video_denial_raises_quota_insufficient(monkeypatch):
"""Premium video-presentation runs that hit a denied reservation must
raise ``QuotaInsufficientError`` *before* the graph runs and must not
emit an audit row (no work happened)."""
from app.services.billable_calls import (
QuotaInsufficientError,
billable_call,
)
spies = _patch_isolation_layer(
monkeypatch,
reserve_result=_FakeQuotaResult(
allowed=False, used=4_500_000, limit=5_000_000, remaining=500_000
),
)
user_id = uuid4()
with pytest.raises(QuotaInsufficientError) as exc_info:
async with billable_call(
user_id=user_id,
search_space_id=42,
billing_tier="premium",
base_model="gpt-5.4",
quota_reserve_micros_override=1_000_000,
usage_type="video_presentation_generation",
thread_id=99,
call_details={"video_presentation_id": 12, "title": "Test Video"},
):
pytest.fail("body should not run when reserve is denied")
err = exc_info.value
assert err.usage_type == "video_presentation_generation"
assert err.remaining_micros == 500_000
assert spies["reserve"][0]["reserve_micros"] == 1_000_000
assert spies["finalize"] == []
assert spies["release"] == []
assert spies["record"] == []

@@ -214,3 +214,159 @@ def test_generate_configs_drops_non_text_and_non_tool_models():
assert "openai/gpt-4o" in model_names
assert "openai/dall-e" not in model_names
assert "openai/completion-only" not in model_names
# ---------------------------------------------------------------------------
# _generate_image_gen_configs / _generate_vision_llm_configs
# ---------------------------------------------------------------------------
def test_generate_image_gen_configs_filters_by_image_output():
"""Only models with ``output_modalities`` containing ``image`` are emitted.
Tool-calling and context filters are intentionally NOT applied: image
generation has nothing to do with tool calls and context windows.
"""
from app.services.openrouter_integration_service import (
_generate_image_gen_configs,
)
raw = [
# Pure image-gen model (small context, no tools — should still emit).
{
"id": "openai/gpt-image-1",
"architecture": {"output_modalities": ["image"]},
"context_length": 4_000,
"pricing": {"prompt": "0", "completion": "0"},
},
# Multi-modal: text+image output (should still emit).
{
"id": "google/gemini-2.5-flash-image",
"architecture": {"output_modalities": ["text", "image"]},
"context_length": 1_000_000,
"pricing": {"prompt": "0.000001", "completion": "0.000004"},
},
# Pure text model — must NOT emit.
{
"id": "openai/gpt-4o",
"architecture": {"output_modalities": ["text"]},
"context_length": 128_000,
"pricing": {"prompt": "0.000005", "completion": "0.000015"},
},
]
cfgs = _generate_image_gen_configs(raw, dict(_SETTINGS_BASE))
model_names = {c["model_name"] for c in cfgs}
assert "openai/gpt-image-1" in model_names
assert "google/gemini-2.5-flash-image" in model_names
assert "openai/gpt-4o" not in model_names
# Each config must carry ``billing_tier`` for routing in image_generation_routes.
for c in cfgs:
assert c["billing_tier"] in {"free", "premium"}
assert c["provider"] == "OPENROUTER"
assert c[_OPENROUTER_DYNAMIC_MARKER] is True
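# Hedged sketch of the selection rule exercised above (illustrative only; the
# real generator also builds the full config dict with ids, pricing and the
# dynamic marker):
def _emits_image_config_sketch(raw_model: dict) -> bool:
    output_modalities = raw_model.get("architecture", {}).get("output_modalities", [])
    return "image" in output_modalities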
def test_generate_image_gen_configs_assigns_image_id_offset():
"""Image configs use a different id_offset (-20000) so their negative
IDs don't collide with chat configs (-10000) or vision configs (-30000).
"""
from app.services.openrouter_integration_service import (
_generate_image_gen_configs,
)
raw = [
{
"id": "openai/gpt-image-1",
"architecture": {"output_modalities": ["image"]},
"context_length": 4_000,
"pricing": {"prompt": "0", "completion": "0"},
}
]
# Don't pass image_id_offset → use the module default (-20000).
cfgs = _generate_image_gen_configs(raw, dict(_SETTINGS_BASE))
assert all(c["id"] < -20_000 + 1 for c in cfgs)
assert all(c["id"] > -29_000_000 for c in cfgs)
def test_generate_vision_llm_configs_filters_by_image_input_text_output():
"""Vision LLMs must accept image input AND emit text — pure image-gen
(no text out) and text-only (no image in) models are excluded.
"""
from app.services.openrouter_integration_service import (
_generate_vision_llm_configs,
)
raw = [
# GPT-4o: vision LLM (image in, text out) — must emit.
{
"id": "openai/gpt-4o",
"architecture": {
"input_modalities": ["text", "image"],
"output_modalities": ["text"],
},
"context_length": 128_000,
"pricing": {"prompt": "0.000005", "completion": "0.000015"},
},
# Pure image generator — image *output*, no text out. Must NOT emit.
{
"id": "openai/gpt-image-1",
"architecture": {
"input_modalities": ["text"],
"output_modalities": ["image"],
},
"context_length": 4_000,
"pricing": {"prompt": "0", "completion": "0"},
},
# Pure text model (no image in). Must NOT emit.
{
"id": "anthropic/claude-3-haiku",
"architecture": {
"input_modalities": ["text"],
"output_modalities": ["text"],
},
"context_length": 200_000,
"pricing": {"prompt": "0.000001", "completion": "0.000005"},
},
]
cfgs = _generate_vision_llm_configs(raw, dict(_SETTINGS_BASE))
names = {c["model_name"] for c in cfgs}
assert names == {"openai/gpt-4o"}
cfg = cfgs[0]
assert cfg["billing_tier"] == "premium"
# Pricing carried inline so pricing_registration can register vision
# under ``openrouter/openai/gpt-4o`` even if the chat catalogue cache
# is cleared.
assert cfg["input_cost_per_token"] == pytest.approx(5e-6)
assert cfg["output_cost_per_token"] == pytest.approx(15e-6)
assert cfg[_OPENROUTER_DYNAMIC_MARKER] is True
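# Companion sketch for the vision filter above (again illustrative only): a
# vision LLM must accept image input AND emit text output.
def _is_vision_llm_sketch(raw_model: dict) -> bool:
    arch = raw_model.get("architecture", {})
    input_ok = "image" in arch.get("input_modalities", [])
    output_ok = "text" in arch.get("output_modalities", [])
    return input_ok and output_ok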
def test_generate_vision_llm_configs_drops_chat_only_filters():
"""A small-context vision model that doesn't advertise tool calling is
still a valid vision LLM for "describe this image" prompts. The chat
filters (``supports_tool_calling``, ``has_sufficient_context``) must
NOT be applied to vision emission.
"""
from app.services.openrouter_integration_service import (
_generate_vision_llm_configs,
)
raw = [
{
"id": "tiny/vision-mini",
"architecture": {
"input_modalities": ["text", "image"],
"output_modalities": ["text"],
},
"supported_parameters": [], # no tools
"context_length": 4_000, # well below MIN_CONTEXT_LENGTH
"pricing": {"prompt": "0.0000001", "completion": "0.0000005"},
}
]
cfgs = _generate_vision_llm_configs(raw, dict(_SETTINGS_BASE))
assert len(cfgs) == 1
assert cfgs[0]["model_name"] == "tiny/vision-mini"

@@ -0,0 +1,447 @@
"""Pricing registration unit tests.
The pricing-registration module is what makes ``response_cost`` populate
correctly for OpenRouter dynamic models and operator-defined Azure
deployments, both of which LiteLLM doesn't natively know about. The tests
exercise:
* The alias generators emit every shape that LiteLLM's cost-callback might
use (``openrouter/X`` and bare ``X``; YAML-defined ``base_model``,
``provider/base_model``, ``provider/model_name``, plus the special
``azure_openai`` → ``azure`` normalisation).
* ``register_pricing_from_global_configs`` calls ``litellm.register_model``
with the right alias set and pricing values per provider.
* Configs without a resolvable pair of cost values are skipped, never
registered as zero, since that would override pricing LiteLLM might
already know natively.
"""
from __future__ import annotations
from typing import Any
import pytest
pytestmark = pytest.mark.unit
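# ---------------------------------------------------------------------------
# For orientation (illustrative constant, never used by the suite): the rough
# shape of a ``litellm.register_model`` payload the spy below is expected to
# capture for an OpenRouter config. Values mirror the OpenRouter test case
# further down; the real payload may carry additional keys such as ``mode``.
# ---------------------------------------------------------------------------
_EXAMPLE_OPENROUTER_REGISTRATION = {
    "openrouter/anthropic/claude-3-5-sonnet": {
        "input_cost_per_token": 3e-6,
        "output_cost_per_token": 15e-6,
        "litellm_provider": "openrouter",
    },
    "anthropic/claude-3-5-sonnet": {
        "input_cost_per_token": 3e-6,
        "output_cost_per_token": 15e-6,
        "litellm_provider": "openrouter",
    },
}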
# ---------------------------------------------------------------------------
# Alias generators
# ---------------------------------------------------------------------------
def test_openrouter_alias_set_includes_prefixed_and_bare():
from app.services.pricing_registration import _alias_set_for_openrouter
aliases = _alias_set_for_openrouter("anthropic/claude-3-5-sonnet")
assert aliases == [
"openrouter/anthropic/claude-3-5-sonnet",
"anthropic/claude-3-5-sonnet",
]
def test_openrouter_alias_set_dedupes():
"""If the model id is already prefixed with ``openrouter/``, the alias
set must not contain duplicates that would re-register the same key
twice.
"""
from app.services.pricing_registration import _alias_set_for_openrouter
aliases = _alias_set_for_openrouter("openrouter/foo")
# The bare and prefixed variants compute to the same string here, so we
# at minimum require uniqueness.
assert len(aliases) == len(set(aliases))
def test_yaml_alias_set_for_azure_openai_normalises_to_azure():
"""``azure_openai`` (our YAML provider slug) must register under
``azure/<name>`` so the LiteLLM Router's deployment-resolution path
(which uses provider ``azure``) finds the pricing too.
"""
from app.services.pricing_registration import _alias_set_for_yaml
aliases = _alias_set_for_yaml(
provider="AZURE_OPENAI",
model_name="gpt-5.4",
base_model="gpt-5.4",
)
assert "gpt-5.4" in aliases
assert "azure_openai/gpt-5.4" in aliases
assert "azure/gpt-5.4" in aliases
def test_yaml_alias_set_distinguishes_model_name_and_base_model():
"""When ``model_name`` differs from ``base_model`` (operator labelled a
deployment), both must appear in the alias set since either may surface
in callbacks depending on the call path.
"""
from app.services.pricing_registration import _alias_set_for_yaml
aliases = _alias_set_for_yaml(
provider="OPENAI",
model_name="my-deployment-label",
base_model="gpt-4o",
)
assert "gpt-4o" in aliases
assert "openai/gpt-4o" in aliases
assert "my-deployment-label" in aliases
assert "openai/my-deployment-label" in aliases
def test_yaml_alias_set_omits_provider_prefix_when_provider_blank():
from app.services.pricing_registration import _alias_set_for_yaml
aliases = _alias_set_for_yaml(
provider="",
model_name="foo",
base_model="bar",
)
assert "bar" in aliases
assert "foo" in aliases
assert all("/" not in a for a in aliases)
# ---------------------------------------------------------------------------
# register_pricing_from_global_configs
# ---------------------------------------------------------------------------
class _RegistrationSpy:
"""Captures the dicts passed to ``litellm.register_model``.
Many calls may go through; we just record them all and let tests assert
against the union.
"""
def __init__(self) -> None:
self.calls: list[dict[str, Any]] = []
def __call__(self, payload: dict[str, Any]) -> None:
self.calls.append(payload)
@property
def all_keys(self) -> set[str]:
keys: set[str] = set()
for payload in self.calls:
keys.update(payload.keys())
return keys
def _patch_register(monkeypatch: pytest.MonkeyPatch) -> _RegistrationSpy:
spy = _RegistrationSpy()
monkeypatch.setattr(
"app.services.pricing_registration.litellm.register_model",
spy,
raising=False,
)
return spy
def _patch_openrouter_pricing(
monkeypatch: pytest.MonkeyPatch, mapping: dict[str, dict[str, str]]
) -> None:
"""Pretend the OpenRouter integration is initialised with ``mapping``."""
class _Stub:
def get_raw_pricing(self) -> dict[str, dict[str, str]]:
return mapping
class _StubService:
@classmethod
def is_initialized(cls) -> bool:
return True
@classmethod
def get_instance(cls) -> _Stub:
return _Stub()
monkeypatch.setattr(
"app.services.openrouter_integration_service.OpenRouterIntegrationService",
_StubService,
raising=False,
)
def test_openrouter_models_register_under_aliases(monkeypatch):
"""An OpenRouter config whose ``model_name`` is in the cached raw
pricing map is registered under both ``openrouter/X`` and bare ``X``.
"""
from app.config import config
from app.services.pricing_registration import register_pricing_from_global_configs
spy = _patch_register(monkeypatch)
_patch_openrouter_pricing(
monkeypatch,
{
"anthropic/claude-3-5-sonnet": {
"prompt": "0.000003",
"completion": "0.000015",
}
},
)
monkeypatch.setattr(
config,
"GLOBAL_LLM_CONFIGS",
[
{
"id": 1,
"provider": "OPENROUTER",
"model_name": "anthropic/claude-3-5-sonnet",
}
],
)
register_pricing_from_global_configs()
assert "openrouter/anthropic/claude-3-5-sonnet" in spy.all_keys
assert "anthropic/claude-3-5-sonnet" in spy.all_keys
# Costs are float-converted from the raw OpenRouter strings.
payload = spy.calls[0]
assert payload["openrouter/anthropic/claude-3-5-sonnet"][
"input_cost_per_token"
] == pytest.approx(3e-6)
assert payload["openrouter/anthropic/claude-3-5-sonnet"][
"output_cost_per_token"
] == pytest.approx(15e-6)
assert (
payload["openrouter/anthropic/claude-3-5-sonnet"]["litellm_provider"]
== "openrouter"
)
def test_yaml_override_registers_under_alias_set(monkeypatch):
"""Operator-declared ``input_cost_per_token`` /
``output_cost_per_token`` on a YAML config registers under every
alias the YAML alias generator produces, including the ``azure/``
normalisation for ``azure_openai`` providers.
"""
from app.config import config
from app.services.pricing_registration import register_pricing_from_global_configs
spy = _patch_register(monkeypatch)
_patch_openrouter_pricing(monkeypatch, {})
monkeypatch.setattr(
config,
"GLOBAL_LLM_CONFIGS",
[
{
"id": 1,
"provider": "AZURE_OPENAI",
"model_name": "gpt-5.4",
"litellm_params": {
"base_model": "gpt-5.4",
"input_cost_per_token": 2e-6,
"output_cost_per_token": 8e-6,
},
}
],
)
register_pricing_from_global_configs()
keys = spy.all_keys
assert "gpt-5.4" in keys
assert "azure_openai/gpt-5.4" in keys
assert "azure/gpt-5.4" in keys
payload = spy.calls[0]
entry = payload["gpt-5.4"]
assert entry["input_cost_per_token"] == pytest.approx(2e-6)
assert entry["output_cost_per_token"] == pytest.approx(8e-6)
assert entry["litellm_provider"] == "azure"
def test_no_override_means_no_registration(monkeypatch):
"""A YAML config that *omits* both pricing fields must NOT be registered
registering as zero would override LiteLLM's native pricing for the
``base_model`` key (e.g. ``gpt-4o``) and silently make every user's
bill drop to $0. Fail-safe is "skip and warn", not "register zero".
"""
from app.config import config
from app.services.pricing_registration import register_pricing_from_global_configs
spy = _patch_register(monkeypatch)
_patch_openrouter_pricing(monkeypatch, {})
monkeypatch.setattr(
config,
"GLOBAL_LLM_CONFIGS",
[
{
"id": 1,
"provider": "OPENAI",
"model_name": "gpt-4o",
"litellm_params": {"base_model": "gpt-4o"},
}
],
)
register_pricing_from_global_configs()
assert spy.calls == []
def test_openrouter_skipped_when_pricing_missing(monkeypatch):
"""If the OpenRouter raw-pricing cache doesn't carry an entry for a
configured model (network blip during refresh, model added later, etc.),
we skip it rather than registering zero pricing.
"""
from app.config import config
from app.services.pricing_registration import register_pricing_from_global_configs
spy = _patch_register(monkeypatch)
_patch_openrouter_pricing(
monkeypatch, {"some/other-model": {"prompt": "1", "completion": "1"}}
)
monkeypatch.setattr(
config,
"GLOBAL_LLM_CONFIGS",
[
{
"id": 1,
"provider": "OPENROUTER",
"model_name": "anthropic/claude-3-5-sonnet",
}
],
)
register_pricing_from_global_configs()
assert spy.calls == []
def test_register_continues_after_individual_failure(monkeypatch, caplog):
"""A single bad ``register_model`` call (e.g. raising LiteLLM error)
must not abort registration of the remaining configs.
"""
from app.config import config
from app.services.pricing_registration import register_pricing_from_global_configs
failing_keys: set[str] = {"anthropic/claude-3-5-sonnet"}
successful_calls: list[dict[str, Any]] = []
def _maybe_fail(payload: dict[str, Any]) -> None:
if any(k in failing_keys for k in payload):
raise RuntimeError("boom")
successful_calls.append(payload)
monkeypatch.setattr(
"app.services.pricing_registration.litellm.register_model",
_maybe_fail,
raising=False,
)
_patch_openrouter_pricing(
monkeypatch,
{
"anthropic/claude-3-5-sonnet": {
"prompt": "0.000003",
"completion": "0.000015",
}
},
)
monkeypatch.setattr(
config,
"GLOBAL_LLM_CONFIGS",
[
{
"id": 1,
"provider": "OPENROUTER",
"model_name": "anthropic/claude-3-5-sonnet",
},
{
"id": 2,
"provider": "OPENAI",
"model_name": "custom-deployment",
"litellm_params": {
"base_model": "custom-deployment",
"input_cost_per_token": 1e-6,
"output_cost_per_token": 2e-6,
},
},
],
)
register_pricing_from_global_configs()
# The good config still registered.
assert any("custom-deployment" in payload for payload in successful_calls)
def test_vision_configs_registered_with_chat_shape(monkeypatch):
"""``register_pricing_from_global_configs`` walks
``GLOBAL_VISION_LLM_CONFIGS`` in addition to the chat configs so vision
calls (during indexing) bill correctly. Vision configs use the same
chat-shape token prices, but image-gen pricing is intentionally NOT
registered here (handled via ``response_cost`` in LiteLLM).
"""
from app.config import config
from app.services.pricing_registration import register_pricing_from_global_configs
spy = _patch_register(monkeypatch)
_patch_openrouter_pricing(
monkeypatch,
{"openai/gpt-4o": {"prompt": "0.000005", "completion": "0.000015"}},
)
# No chat configs — only vision. Proves the vision walk is a separate
# iteration, not piggy-backed on the chat list.
monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [])
monkeypatch.setattr(
config,
"GLOBAL_VISION_LLM_CONFIGS",
[
{
"id": -1,
"provider": "OPENROUTER",
"model_name": "openai/gpt-4o",
"billing_tier": "premium",
"input_cost_per_token": 5e-6,
"output_cost_per_token": 15e-6,
}
],
)
register_pricing_from_global_configs()
assert "openrouter/openai/gpt-4o" in spy.all_keys
payload_value = spy.calls[0]["openrouter/openai/gpt-4o"]
assert payload_value["mode"] == "chat"
assert payload_value["litellm_provider"] == "openrouter"
assert payload_value["input_cost_per_token"] == pytest.approx(5e-6)
assert payload_value["output_cost_per_token"] == pytest.approx(15e-6)
def test_vision_with_inline_pricing_when_or_cache_missing(monkeypatch):
"""If the OpenRouter pricing cache misses a vision model (different
catalogue surface), the vision walk falls back to inline
``input_cost_per_token``/``output_cost_per_token`` on the cfg itself.
"""
from app.config import config
from app.services.pricing_registration import register_pricing_from_global_configs
spy = _patch_register(monkeypatch)
_patch_openrouter_pricing(monkeypatch, {})
monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [])
monkeypatch.setattr(
config,
"GLOBAL_VISION_LLM_CONFIGS",
[
{
"id": -1,
"provider": "OPENROUTER",
"model_name": "google/gemini-2.5-flash",
"billing_tier": "premium",
"input_cost_per_token": 1e-6,
"output_cost_per_token": 4e-6,
}
],
)
register_pricing_from_global_configs()
assert "openrouter/google/gemini-2.5-flash" in spy.all_keys

@@ -0,0 +1,157 @@
"""Unit tests for ``QuotaCheckedVisionLLM``.
Validates that:
* Calling ``ainvoke`` routes through ``billable_call`` (premium credit
enforcement) and forwards the inner LLM's response on success.
* The wrapper proxies non-overridden attributes to the inner LLM
(``__getattr__``) so ``invoke`` / ``astream`` / ``with_structured_output``
still work without quota gating (they're not used in indexing today).
* When ``billable_call`` raises ``QuotaInsufficientError`` the wrapper
bubbles it up; the ETL pipeline catches that and falls back to OCR.
"""
from __future__ import annotations
import contextlib
from typing import Any
from uuid import uuid4
import pytest
pytestmark = pytest.mark.unit
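# ---------------------------------------------------------------------------
# Mini-sketch of the delegation pattern described in the module docstring
# (illustrative only, not the real class): override ``ainvoke`` so it can be
# wrapped in ``billable_call``, and forward every other attribute to the
# inner LLM via ``__getattr__``.
# ---------------------------------------------------------------------------
class _ProxyPatternSketch:
    def __init__(self, inner):
        self._inner = inner

    async def ainvoke(self, *args, **kwargs):
        # The real wrapper opens billable_call(...) around this await.
        return await self._inner.ainvoke(*args, **kwargs)

    def __getattr__(self, name):
        # Anything not defined on the proxy (invoke, astream, ...) falls
        # through to the wrapped LLM unchanged.
        return getattr(self._inner, name)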
class _FakeInnerLLM:
"""Stand-in for ``langchain_litellm.ChatLiteLLM``."""
def __init__(self, response: Any = "OCR'd content") -> None:
self._response = response
self.ainvoke_calls: list[Any] = []
async def ainvoke(self, input: Any, *args: Any, **kwargs: Any) -> Any:
self.ainvoke_calls.append(input)
return self._response
def some_other_method(self, x: int) -> int:
return x * 2
@contextlib.asynccontextmanager
async def _passthrough_billable_call(**_kwargs):
"""Stand-in for billable_call that always allows the call to run."""
class _Acc:
total_cost_micros = 0
total_prompt_tokens = 0
total_completion_tokens = 0
grand_total = 0
calls: list[Any] = []
def per_message_summary(self) -> dict[str, dict[str, int]]:
return {}
yield _Acc()
@pytest.mark.asyncio
async def test_ainvoke_routes_through_billable_call(monkeypatch):
from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM
captured_kwargs: list[dict[str, Any]] = []
@contextlib.asynccontextmanager
async def _spy_billable_call(**kwargs):
captured_kwargs.append(kwargs)
async with _passthrough_billable_call() as acc:
yield acc
monkeypatch.setattr(
"app.services.quota_checked_vision_llm.billable_call",
_spy_billable_call,
raising=False,
)
inner = _FakeInnerLLM(response="A red apple on a white table")
user_id = uuid4()
wrapper = QuotaCheckedVisionLLM(
inner,
user_id=user_id,
search_space_id=99,
billing_tier="premium",
base_model="openai/gpt-4o",
quota_reserve_tokens=4000,
)
result = await wrapper.ainvoke([{"text": "what is this?"}])
assert result == "A red apple on a white table"
assert len(inner.ainvoke_calls) == 1
assert len(captured_kwargs) == 1
bc_kwargs = captured_kwargs[0]
assert bc_kwargs["user_id"] == user_id
assert bc_kwargs["search_space_id"] == 99
assert bc_kwargs["billing_tier"] == "premium"
assert bc_kwargs["base_model"] == "openai/gpt-4o"
assert bc_kwargs["quota_reserve_tokens"] == 4000
assert bc_kwargs["usage_type"] == "vision_extraction"
@pytest.mark.asyncio
async def test_ainvoke_propagates_quota_insufficient_error(monkeypatch):
from app.services.billable_calls import QuotaInsufficientError
from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM
@contextlib.asynccontextmanager
async def _denying_billable_call(**_kwargs):
raise QuotaInsufficientError(
usage_type="vision_extraction",
used_micros=5_000_000,
limit_micros=5_000_000,
remaining_micros=0,
)
yield # unreachable but required for asynccontextmanager type
monkeypatch.setattr(
"app.services.quota_checked_vision_llm.billable_call",
_denying_billable_call,
raising=False,
)
inner = _FakeInnerLLM()
wrapper = QuotaCheckedVisionLLM(
inner,
user_id=uuid4(),
search_space_id=1,
billing_tier="premium",
base_model="openai/gpt-4o",
quota_reserve_tokens=4000,
)
with pytest.raises(QuotaInsufficientError):
await wrapper.ainvoke([{"text": "x"}])
# Inner LLM never ran on a denied reservation.
assert inner.ainvoke_calls == []
@pytest.mark.asyncio
async def test_proxies_non_overridden_attributes_to_inner():
"""``__getattr__`` forwards anything not on the proxy itself, so any
method we didn't explicitly override (``invoke``, ``astream``,
``with_structured_output``, etc.) still works, just without quota
gating, which is fine because the indexer only ever calls ``ainvoke``.
"""
from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM
inner = _FakeInnerLLM()
wrapper = QuotaCheckedVisionLLM(
inner,
user_id=uuid4(),
search_space_id=1,
billing_tier="premium",
base_model="openai/gpt-4o",
quota_reserve_tokens=4000,
)
# ``some_other_method`` is on the inner only.
assert wrapper.some_other_method(7) == 14

@@ -0,0 +1,515 @@
"""Cost-based premium quota unit tests.
Covers the USD-micro behaviour added in migration 140:
* ``TurnTokenAccumulator.total_cost_micros`` sums ``cost_micros`` across all
calls in a turn; this sum is used as the debit amount when ``agent_config.is_premium``
is true, regardless of which underlying model produced each call. This
preserves the prior "premium turn → all calls in turn count" rule from the
token-based system.
* ``estimate_call_reserve_micros`` scales linearly with model pricing,
clamps to a sane floor when pricing is unknown, and respects the
``QUOTA_MAX_RESERVE_MICROS`` ceiling so a misconfigured "$1000/M" entry
can't lock the whole balance on one call.
"""
from __future__ import annotations
import pytest
pytestmark = pytest.mark.unit
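# ---------------------------------------------------------------------------
# Worked arithmetic for the reserve-sizing tests below (illustrative
# re-derivation only; the real helper is
# ``token_quota_service.estimate_call_reserve_micros``). Assumes 1 USD ==
# 1_000_000 micros and the floor/ceiling values the tests patch in.
# ---------------------------------------------------------------------------
def _sketch_reserve_micros(
    tokens: int,
    input_cost_per_token: float,
    output_cost_per_token: float,
    *,
    floor_micros: int = 100,
    ceiling_micros: int = 1_000_000,
) -> int:
    raw = round(tokens * (input_cost_per_token + output_cost_per_token) * 1_000_000)
    return max(floor_micros, min(raw, ceiling_micros))

# Claude-Opus pricing: 4000 * (15e-6 + 75e-6) USD = $0.36 = 360_000 micros.
assert _sketch_reserve_micros(4000, 15e-6, 75e-6) == 360_000
# A misconfigured "$1000/M" entry clamps to the ceiling instead of $4.
assert _sketch_reserve_micros(4000, 1e-3, 0.0) == 1_000_000
# Unknown or zero pricing clamps up to the non-zero floor.
assert _sketch_reserve_micros(4000, 0.0, 0.0) == 100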
# ---------------------------------------------------------------------------
# TurnTokenAccumulator — premium-turn debit semantics
# ---------------------------------------------------------------------------
def test_total_cost_micros_sums_premium_and_free_calls():
"""A premium turn that also called a free sub-agent debits the union.
The plan deliberately preserved the existing "premium turn → all calls
count" behaviour because per-call premium filtering relied on
``LLMRouterService._premium_model_strings`` which only covers router-pool
deployments. ``total_cost_micros`` therefore must include free-model
calls (whose ``cost_micros`` is typically ``0``) as well as the premium
call's actual provider cost.
"""
from app.services.token_tracking_service import TurnTokenAccumulator
acc = TurnTokenAccumulator()
# Premium model (e.g. claude-opus): non-zero cost.
acc.add(
model="anthropic/claude-3-5-sonnet",
prompt_tokens=1200,
completion_tokens=400,
total_tokens=1600,
cost_micros=12_345,
)
# Free sub-agent (e.g. title-gen on a free model): zero cost.
acc.add(
model="gpt-4o-mini",
prompt_tokens=120,
completion_tokens=20,
total_tokens=140,
cost_micros=0,
)
# A second premium-priced call within the same turn.
acc.add(
model="anthropic/claude-3-5-sonnet",
prompt_tokens=800,
completion_tokens=200,
total_tokens=1000,
cost_micros=7_500,
)
assert acc.total_cost_micros == 12_345 + 0 + 7_500
# Token totals stay correct so the FE display path still works.
assert acc.grand_total == 1600 + 140 + 1000
def test_total_cost_micros_zero_when_no_calls():
"""An empty accumulator must report zero cost (no division-by-zero, no None)."""
from app.services.token_tracking_service import TurnTokenAccumulator
acc = TurnTokenAccumulator()
assert acc.total_cost_micros == 0
assert acc.grand_total == 0
def test_per_message_summary_groups_cost_by_model():
"""``per_message_summary`` must accumulate ``cost_micros`` per model so the
SSE ``model_breakdown`` payload reports actual USD spend per provider.
"""
from app.services.token_tracking_service import TurnTokenAccumulator
acc = TurnTokenAccumulator()
acc.add(
model="claude-3-5-sonnet",
prompt_tokens=100,
completion_tokens=50,
total_tokens=150,
cost_micros=4_000,
)
acc.add(
model="claude-3-5-sonnet",
prompt_tokens=200,
completion_tokens=100,
total_tokens=300,
cost_micros=8_000,
)
acc.add(
model="gpt-4o-mini",
prompt_tokens=50,
completion_tokens=10,
total_tokens=60,
cost_micros=200,
)
summary = acc.per_message_summary()
assert summary["claude-3-5-sonnet"]["cost_micros"] == 12_000
assert summary["claude-3-5-sonnet"]["total_tokens"] == 450
assert summary["gpt-4o-mini"]["cost_micros"] == 200
def test_serialized_calls_includes_cost_micros():
"""``serialized_calls`` is what flows into the SSE ``call_details``
payload; cost_micros must be present on each entry so the FE message-info
dropdown can render per-call USD.
"""
from app.services.token_tracking_service import TurnTokenAccumulator
acc = TurnTokenAccumulator()
acc.add(
model="m",
prompt_tokens=1,
completion_tokens=1,
total_tokens=2,
cost_micros=42,
)
serialized = acc.serialized_calls()
assert serialized == [
{
"model": "m",
"prompt_tokens": 1,
"completion_tokens": 1,
"total_tokens": 2,
"cost_micros": 42,
"call_kind": "chat",
}
]
# ---------------------------------------------------------------------------
# estimate_call_reserve_micros — sizing and clamping
# ---------------------------------------------------------------------------
def test_reserve_returns_floor_when_model_unknown(monkeypatch):
"""If LiteLLM doesn't know the model, ``get_model_info`` raises and the
helper falls back to the 100-micro floor: small enough that a user with
$0.0001 left can still send a tiny request, but non-zero so we still gate
against an empty balance.
"""
import litellm
from app.services import token_quota_service
def _raise(_name):
raise KeyError("unknown")
monkeypatch.setattr(litellm, "get_model_info", _raise, raising=False)
micros = token_quota_service.estimate_call_reserve_micros(
base_model="nonexistent-model",
quota_reserve_tokens=4000,
)
assert micros == token_quota_service._QUOTA_MIN_RESERVE_MICROS
assert micros == 100
def test_reserve_returns_floor_when_pricing_is_zero(monkeypatch):
"""LiteLLM may *return* a model with both cost-per-token fields at 0
(pricing not yet registered). The helper must not multiply 0 × tokens
and end up reserving 0; it must clamp to the floor.
"""
import litellm
from app.services import token_quota_service
monkeypatch.setattr(
litellm,
"get_model_info",
lambda _name: {"input_cost_per_token": 0, "output_cost_per_token": 0},
raising=False,
)
micros = token_quota_service.estimate_call_reserve_micros(
base_model="some-pending-model",
quota_reserve_tokens=4000,
)
assert micros == token_quota_service._QUOTA_MIN_RESERVE_MICROS
def test_reserve_scales_with_model_cost(monkeypatch):
"""Claude-Opus-priced model with 4000 reserve_tokens reserves
~$0.36 = 360_000 micros. Critically, this must NOT be clamped down to
some small artificial cap; that was the bug the plan called out.
"""
import litellm
from app.config import config
from app.services import token_quota_service
monkeypatch.setattr(
litellm,
"get_model_info",
lambda _name: {
"input_cost_per_token": 15e-6,
"output_cost_per_token": 75e-6,
},
raising=False,
)
monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_MICROS", 1_000_000, raising=False)
micros = token_quota_service.estimate_call_reserve_micros(
base_model="claude-3-opus",
quota_reserve_tokens=4000,
)
# 4000 * (15e-6 + 75e-6) = 4000 * 90e-6 = 0.36 USD = 360_000 micros.
assert micros == 360_000
def test_reserve_clamps_to_max_ceiling(monkeypatch):
"""A misconfigured "$1000 / M" model with 4000 reserve_tokens would
nominally compute to $4 = 4_000_000 micros. The ceiling
``QUOTA_MAX_RESERVE_MICROS`` must clamp that so a bad pricing entry
can't lock the user's whole balance on one call.
"""
import litellm
from app.config import config
from app.services import token_quota_service
monkeypatch.setattr(
litellm,
"get_model_info",
lambda _name: {
"input_cost_per_token": 1e-3,
"output_cost_per_token": 0,
},
raising=False,
)
monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_MICROS", 1_000_000, raising=False)
micros = token_quota_service.estimate_call_reserve_micros(
base_model="oops-misconfigured",
quota_reserve_tokens=4000,
)
assert micros == 1_000_000
def test_reserve_uses_default_when_quota_reserve_tokens_missing(monkeypatch):
"""Per-config ``quota_reserve_tokens`` is optional; when ``None`` or
zero, the helper must fall back to the global ``QUOTA_MAX_RESERVE_PER_CALL``
so anonymous-style configs still reserve the operator-tunable default.
"""
import litellm
from app.config import config
from app.services import token_quota_service
monkeypatch.setattr(
litellm,
"get_model_info",
lambda _name: {
"input_cost_per_token": 1e-6,
"output_cost_per_token": 1e-6,
},
raising=False,
)
monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_PER_CALL", 2000, raising=False)
monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_MICROS", 1_000_000, raising=False)
# 2000 * (1e-6 + 1e-6) = 4e-3 USD = 4000 micros
assert (
token_quota_service.estimate_call_reserve_micros(
base_model="cheap", quota_reserve_tokens=None
)
== 4000
)
assert (
token_quota_service.estimate_call_reserve_micros(
base_model="cheap", quota_reserve_tokens=0
)
== 4000
)
# ---------------------------------------------------------------------------
# TokenTrackingCallback — image vs chat usage shape
# ---------------------------------------------------------------------------
class _FakeImageUsage:
"""Mimics LiteLLM's ``ImageUsage`` (input_tokens / output_tokens shape)."""
def __init__(
self,
input_tokens: int = 0,
output_tokens: int = 0,
total_tokens: int | None = None,
) -> None:
self.input_tokens = input_tokens
self.output_tokens = output_tokens
if total_tokens is not None:
self.total_tokens = total_tokens
class _FakeImageResponse:
"""Mimics LiteLLM's ``ImageResponse`` — same name so the callback's
``type(...).__name__`` probe routes to the image branch.
"""
def __init__(self, usage: _FakeImageUsage, response_cost: float | None = None):
self.usage = usage
if response_cost is not None:
self._hidden_params = {"response_cost": response_cost}
# Re-tag the helper class as ``ImageResponse`` for the type-name probe in
# the callback. We can't simply name the class ``ImageResponse`` because
# the test runner sometimes imports test modules in surprising ways and
# we want to be explicit.
_FakeImageResponse.__name__ = "ImageResponse"
class _FakeChatUsage:
def __init__(self, prompt: int, completion: int):
self.prompt_tokens = prompt
self.completion_tokens = completion
self.total_tokens = prompt + completion
class _FakeChatResponse:
def __init__(self, usage: _FakeChatUsage):
self.usage = usage
@pytest.mark.asyncio
async def test_callback_reads_image_usage_input_output_tokens():
"""``TokenTrackingCallback`` must read ``input_tokens``/``output_tokens``
for ``ImageResponse`` (LiteLLM's ImageUsage shape), NOT
prompt_tokens/completion_tokens which is the chat shape.
"""
from app.services.token_tracking_service import (
TokenTrackingCallback,
scoped_turn,
)
cb = TokenTrackingCallback()
response = _FakeImageResponse(
usage=_FakeImageUsage(input_tokens=42, output_tokens=8, total_tokens=50),
response_cost=0.04, # $0.04 per image
)
async with scoped_turn() as acc:
await cb.async_log_success_event(
kwargs={"model": "openai/gpt-image-1", "response_cost": 0.04},
response_obj=response,
start_time=None,
end_time=None,
)
assert len(acc.calls) == 1
call = acc.calls[0]
assert call.prompt_tokens == 42
assert call.completion_tokens == 8
assert call.total_tokens == 50
# 0.04 USD = 40_000 micros
assert call.cost_micros == 40_000
assert call.call_kind == "image_generation"
@pytest.mark.asyncio
async def test_callback_chat_path_unchanged():
"""Chat responses must still read prompt_tokens/completion_tokens."""
from app.services.token_tracking_service import (
TokenTrackingCallback,
scoped_turn,
)
cb = TokenTrackingCallback()
response = _FakeChatResponse(_FakeChatUsage(prompt=120, completion=30))
async with scoped_turn() as acc:
await cb.async_log_success_event(
kwargs={
"model": "openrouter/anthropic/claude-3-5-sonnet",
"response_cost": 0.0036,
},
response_obj=response,
start_time=None,
end_time=None,
)
assert len(acc.calls) == 1
call = acc.calls[0]
assert call.prompt_tokens == 120
assert call.completion_tokens == 30
assert call.total_tokens == 150
assert call.cost_micros == 3_600
assert call.call_kind == "chat"
@pytest.mark.asyncio
async def test_callback_image_missing_response_cost_falls_back_to_zero(monkeypatch):
"""When OpenRouter omits ``usage.cost`` LiteLLM's
``default_image_cost_calculator`` raises. The defensive image branch in
``_extract_cost_usd`` must NOT call ``cost_per_token`` (which is
chat-shaped and would raise too) it returns 0 with a WARNING log.
"""
import litellm
from app.services.token_tracking_service import (
TokenTrackingCallback,
scoped_turn,
)
# Force completion_cost to raise the same way OpenRouter image-gen fails.
def _boom(*_args, **_kwargs):
raise ValueError("model_cost: missing entry for openrouter image model")
monkeypatch.setattr(litellm, "completion_cost", _boom, raising=False)
# And make sure cost_per_token is NEVER called for the image path —
# if it were, our ``is_image=True`` branch is broken.
cost_per_token_calls: list = []
def _record_cost_per_token(**kwargs):
cost_per_token_calls.append(kwargs)
return (0.0, 0.0)
monkeypatch.setattr(
litellm, "cost_per_token", _record_cost_per_token, raising=False
)
cb = TokenTrackingCallback()
response = _FakeImageResponse(
usage=_FakeImageUsage(input_tokens=7, output_tokens=0)
)
async with scoped_turn() as acc:
await cb.async_log_success_event(
kwargs={"model": "openrouter/google/gemini-2.5-flash-image"},
response_obj=response,
start_time=None,
end_time=None,
)
assert len(acc.calls) == 1
assert acc.calls[0].cost_micros == 0
assert acc.calls[0].call_kind == "image_generation"
# The image branch must short-circuit before cost_per_token.
assert cost_per_token_calls == []
# ---------------------------------------------------------------------------
# scoped_turn — ContextVar reset semantics (issue B)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_scoped_turn_restores_outer_accumulator():
"""``scoped_turn`` must restore the previous ContextVar value on exit
so a per-call wrapper inside an outer chat turn doesn't leak its
accumulator outward (which would cause double-debit at chat-turn exit).
"""
from app.services.token_tracking_service import (
get_current_accumulator,
scoped_turn,
start_turn,
)
outer = start_turn()
assert get_current_accumulator() is outer
async with scoped_turn() as inner:
assert get_current_accumulator() is inner
assert inner is not outer
inner.add(
model="x",
prompt_tokens=1,
completion_tokens=1,
total_tokens=2,
cost_micros=5,
)
# After exit the outer accumulator is restored unchanged.
assert get_current_accumulator() is outer
assert outer.total_cost_micros == 0
assert len(outer.calls) == 0
# The inner accumulator captured the call but didn't bleed into outer.
assert inner.total_cost_micros == 5
@pytest.mark.asyncio
async def test_scoped_turn_resets_to_none_when_no_outer():
"""Running ``scoped_turn`` outside any chat turn (e.g. a background
indexing job) must leave the ContextVar at ``None`` on exit so the
next *unrelated* request starts clean.
"""
from app.services.token_tracking_service import (
_turn_accumulator,
get_current_accumulator,
scoped_turn,
)
# ContextVar default is None for a fresh test isolated context. We
# simulate "no outer" explicitly to be robust against test order.
token = _turn_accumulator.set(None)
try:
assert get_current_accumulator() is None
async with scoped_turn() as acc:
assert get_current_accumulator() is acc
assert get_current_accumulator() is None
finally:
_turn_accumulator.reset(token)