mirror of https://github.com/MODSetter/SurfSense.git (synced 2026-05-04 05:12:38 +02:00)
feat: unified credits and its cost calculations
This commit is contained in:
parent 451a98936e
commit ae9d36d77f
61 changed files with 5835 additions and 272 deletions
@@ -0,0 +1,436 @@
"""Unit tests for ``_resolve_agent_billing_for_search_space``.

Validates the resolver used by Celery podcast/video tasks to compute
``(owner_user_id, billing_tier, base_model)`` from a search space and its
agent LLM config. The resolver mirrors chat's billing-resolution pattern at
``stream_new_chat.py:2294-2351`` and is the single integration point that
prevents Auto-mode podcast/video from leaking premium credit.

Coverage:

* Auto mode + ``thread_id`` set, pin resolves to a negative-id premium
  global → returns ``("premium", <base_model>)``.
* Auto mode + ``thread_id`` set, pin resolves to a negative-id free
  global → returns ``("free", <base_model>)``.
* Auto mode + ``thread_id`` set, pin resolves to a positive-id BYOK config
  → always ``"free"``.
* Auto mode + ``thread_id=None`` → fallback to ``("free", "auto")`` without
  hitting the pin service.
* Negative id (no Auto) → uses ``get_global_llm_config``'s
  ``billing_tier``.
* Positive id (user BYOK) → always ``"free"``.
* Search space not found → raises ``ValueError``.
* ``agent_llm_id`` is None → raises ``ValueError``.
"""

from __future__ import annotations

from dataclasses import dataclass
from types import SimpleNamespace
from uuid import UUID, uuid4

import pytest

pytestmark = pytest.mark.unit


# ---------------------------------------------------------------------------
# Fakes
# ---------------------------------------------------------------------------


class _FakeExecResult:
    def __init__(self, obj):
        self._obj = obj

    def scalars(self):
        return self

    def first(self):
        return self._obj


class _FakeSession:
    """Tiny AsyncSession stub.

    ``responses`` is a list of objects to return from successive
    ``execute()`` calls (in order). The resolver makes at most two
    ``execute()`` calls (search-space lookup, then optionally NewLLMConfig
    lookup), so two queued responses cover the matrix.
    """

    def __init__(self, responses: list):
        self._responses = list(responses)

    async def execute(self, _stmt):
        if not self._responses:
            return _FakeExecResult(None)
        return _FakeExecResult(self._responses.pop(0))

    async def commit(self) -> None:
        pass


@dataclass
class _FakePinResolution:
    resolved_llm_config_id: int
    resolved_tier: str = "premium"
    from_existing_pin: bool = False


def _make_search_space(*, agent_llm_id: int | None, user_id: UUID) -> SimpleNamespace:
    return SimpleNamespace(
        id=42,
        agent_llm_id=agent_llm_id,
        user_id=user_id,
    )


def _make_byok_config(
    *, id_: int, base_model: str | None = None, model_name: str = "gpt-byok"
) -> SimpleNamespace:
    return SimpleNamespace(
        id=id_,
        model_name=model_name,
        litellm_params={"base_model": base_model} if base_model else {},
    )


# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_auto_mode_with_thread_id_resolves_to_premium_global(monkeypatch):
    """Auto + thread → pin service resolves to negative-id premium config →
    resolver returns ``("premium", <base_model>)``."""
    from app.services.billable_calls import _resolve_agent_billing_for_search_space

    user_id = uuid4()
    session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)])

    # Mock the pin service to return a concrete premium config id.
    async def _fake_resolve_pin(
        sess,
        *,
        thread_id,
        search_space_id,
        user_id,
        selected_llm_config_id,
        force_repin_free=False,
    ):
        assert selected_llm_config_id == 0
        assert thread_id == 99
        return _FakePinResolution(resolved_llm_config_id=-1, resolved_tier="premium")

    # Mock global config lookup to return a premium entry.
    def _fake_get_global(cfg_id):
        if cfg_id == -1:
            return {
                "id": -1,
                "model_name": "gpt-5.4",
                "billing_tier": "premium",
                "litellm_params": {"base_model": "gpt-5.4"},
            }
        return None

    # Lazy imports inside the resolver — patch the *target* modules so the
    # imported names resolve to our fakes.
    import app.services.auto_model_pin_service as pin_module
    import app.services.llm_service as llm_module

    monkeypatch.setattr(
        pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin
    )
    monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global)

    owner, tier, base_model = await _resolve_agent_billing_for_search_space(
        session, search_space_id=42, thread_id=99
    )

    assert owner == user_id
    assert tier == "premium"
    assert base_model == "gpt-5.4"


@pytest.mark.asyncio
async def test_auto_mode_with_thread_id_resolves_to_free_global(monkeypatch):
    """Auto + thread → pin returns negative-id free config → resolver
    returns ``("free", <base_model>)``. Same path the pin service takes for
    out-of-credit users (graceful degradation)."""
    from app.services.billable_calls import _resolve_agent_billing_for_search_space

    user_id = uuid4()
    session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)])

    async def _fake_resolve_pin(
        sess,
        *,
        thread_id,
        search_space_id,
        user_id,
        selected_llm_config_id,
        force_repin_free=False,
    ):
        return _FakePinResolution(resolved_llm_config_id=-3, resolved_tier="free")

    def _fake_get_global(cfg_id):
        if cfg_id == -3:
            return {
                "id": -3,
                "model_name": "openrouter/free-model",
                "billing_tier": "free",
                "litellm_params": {"base_model": "openrouter/free-model"},
            }
        return None

    import app.services.auto_model_pin_service as pin_module
    import app.services.llm_service as llm_module

    monkeypatch.setattr(
        pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin
    )
    monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global)

    owner, tier, base_model = await _resolve_agent_billing_for_search_space(
        session, search_space_id=42, thread_id=99
    )

    assert owner == user_id
    assert tier == "free"
    assert base_model == "openrouter/free-model"


@pytest.mark.asyncio
async def test_auto_mode_with_thread_id_resolves_to_byok_is_free(monkeypatch):
    """Auto + thread → pin returns positive-id BYOK config → resolver
    returns ``("free", ...)`` (BYOK is always free per
    ``AgentConfig.from_new_llm_config``)."""
    from app.services.billable_calls import _resolve_agent_billing_for_search_space

    user_id = uuid4()
    search_space = _make_search_space(agent_llm_id=0, user_id=user_id)
    byok_cfg = _make_byok_config(
        id_=17, base_model="anthropic/claude-3-haiku", model_name="my-claude"
    )
    session = _FakeSession([search_space, byok_cfg])

    async def _fake_resolve_pin(
        sess,
        *,
        thread_id,
        search_space_id,
        user_id,
        selected_llm_config_id,
        force_repin_free=False,
    ):
        return _FakePinResolution(resolved_llm_config_id=17, resolved_tier="free")

    import app.services.auto_model_pin_service as pin_module

    monkeypatch.setattr(
        pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin
    )

    owner, tier, base_model = await _resolve_agent_billing_for_search_space(
        session, search_space_id=42, thread_id=99
    )

    assert owner == user_id
    assert tier == "free"
    assert base_model == "anthropic/claude-3-haiku"


@pytest.mark.asyncio
async def test_auto_mode_without_thread_id_falls_back_to_free():
    """Auto + ``thread_id=None`` → ``("free", "auto")`` without invoking
    the pin service. Forward-compat fallback for any future direct-API
    entrypoint that doesn't have a chat thread."""
    from app.services.billable_calls import _resolve_agent_billing_for_search_space

    user_id = uuid4()
    session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)])

    owner, tier, base_model = await _resolve_agent_billing_for_search_space(
        session, search_space_id=42, thread_id=None
    )

    assert owner == user_id
    assert tier == "free"
    assert base_model == "auto"


@pytest.mark.asyncio
async def test_auto_mode_pin_failure_falls_back_to_free(monkeypatch):
    """If the pin service raises ``ValueError`` (thread missing /
    mismatched search space), the resolver should log and return free
    rather than killing the whole task."""
    from app.services.billable_calls import _resolve_agent_billing_for_search_space

    user_id = uuid4()
    session = _FakeSession([_make_search_space(agent_llm_id=0, user_id=user_id)])

    async def _fake_resolve_pin(*args, **kwargs):
        raise ValueError("thread missing")

    import app.services.auto_model_pin_service as pin_module

    monkeypatch.setattr(
        pin_module, "resolve_or_get_pinned_llm_config_id", _fake_resolve_pin
    )

    owner, tier, base_model = await _resolve_agent_billing_for_search_space(
        session, search_space_id=42, thread_id=99
    )

    assert owner == user_id
    assert tier == "free"
    assert base_model == "auto"


@pytest.mark.asyncio
async def test_negative_id_premium_global_returns_premium(monkeypatch):
    """Explicit negative agent_llm_id → ``get_global_llm_config`` →
    return its ``billing_tier``."""
    from app.services.billable_calls import _resolve_agent_billing_for_search_space

    user_id = uuid4()
    session = _FakeSession([_make_search_space(agent_llm_id=-1, user_id=user_id)])

    def _fake_get_global(cfg_id):
        return {
            "id": cfg_id,
            "model_name": "gpt-5.4",
            "billing_tier": "premium",
            "litellm_params": {"base_model": "gpt-5.4"},
        }

    import app.services.llm_service as llm_module

    monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global)

    owner, tier, base_model = await _resolve_agent_billing_for_search_space(
        session, search_space_id=42, thread_id=99
    )

    assert owner == user_id
    assert tier == "premium"
    assert base_model == "gpt-5.4"


@pytest.mark.asyncio
async def test_negative_id_free_global_returns_free(monkeypatch):
    from app.services.billable_calls import _resolve_agent_billing_for_search_space

    user_id = uuid4()
    session = _FakeSession([_make_search_space(agent_llm_id=-2, user_id=user_id)])

    def _fake_get_global(cfg_id):
        return {
            "id": cfg_id,
            "model_name": "openrouter/some-free",
            "billing_tier": "free",
            "litellm_params": {"base_model": "openrouter/some-free"},
        }

    import app.services.llm_service as llm_module

    monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global)

    owner, tier, base_model = await _resolve_agent_billing_for_search_space(
        session, search_space_id=42, thread_id=None
    )

    assert owner == user_id
    assert tier == "free"
    assert base_model == "openrouter/some-free"


@pytest.mark.asyncio
async def test_negative_id_missing_base_model_falls_back_to_model_name(monkeypatch):
    """When the global config has no ``litellm_params.base_model``, the
    resolver falls back to ``model_name`` — matching chat's behavior."""
    from app.services.billable_calls import _resolve_agent_billing_for_search_space

    user_id = uuid4()
    session = _FakeSession([_make_search_space(agent_llm_id=-5, user_id=user_id)])

    def _fake_get_global(cfg_id):
        return {
            "id": cfg_id,
            "model_name": "fallback-model",
            "billing_tier": "premium",
            # No litellm_params.
        }

    import app.services.llm_service as llm_module

    monkeypatch.setattr(llm_module, "get_global_llm_config", _fake_get_global)

    _, tier, base_model = await _resolve_agent_billing_for_search_space(
        session, search_space_id=42
    )

    assert tier == "premium"
    assert base_model == "fallback-model"


@pytest.mark.asyncio
async def test_positive_id_byok_is_always_free():
    """Positive agent_llm_id → user-owned BYOK NewLLMConfig → always free,
    regardless of underlying provider tier."""
    from app.services.billable_calls import _resolve_agent_billing_for_search_space

    user_id = uuid4()
    search_space = _make_search_space(agent_llm_id=23, user_id=user_id)
    byok_cfg = _make_byok_config(id_=23, base_model="anthropic/claude-3.5-sonnet")
    session = _FakeSession([search_space, byok_cfg])

    owner, tier, base_model = await _resolve_agent_billing_for_search_space(
        session, search_space_id=42
    )

    assert owner == user_id
    assert tier == "free"
    assert base_model == "anthropic/claude-3.5-sonnet"


@pytest.mark.asyncio
async def test_positive_id_byok_missing_returns_free_with_empty_base_model():
    """If the BYOK config row is missing/deleted but the search space still
    points at it, the resolver still returns free (no debit) with an empty
    base_model — billable_call's premium path is skipped, no harm done."""
    from app.services.billable_calls import _resolve_agent_billing_for_search_space

    user_id = uuid4()
    session = _FakeSession([_make_search_space(agent_llm_id=99, user_id=user_id)])

    owner, tier, base_model = await _resolve_agent_billing_for_search_space(
        session, search_space_id=42
    )

    assert owner == user_id
    assert tier == "free"
    assert base_model == ""


@pytest.mark.asyncio
async def test_search_space_not_found_raises_value_error():
    from app.services.billable_calls import _resolve_agent_billing_for_search_space

    session = _FakeSession([None])

    with pytest.raises(ValueError, match="Search space"):
        await _resolve_agent_billing_for_search_space(session, search_space_id=999)


@pytest.mark.asyncio
async def test_agent_llm_id_none_raises_value_error():
    from app.services.billable_calls import _resolve_agent_billing_for_search_space

    user_id = uuid4()
    session = _FakeSession([_make_search_space(agent_llm_id=None, user_id=user_id)])

    with pytest.raises(ValueError, match="agent_llm_id"):
        await _resolve_agent_billing_for_search_space(session, search_space_id=42)
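Taken together, these cases pin down the resolver's decision table. As a reading aid, here is a minimal, runnable sketch of that table over plain dicts. The names (resolve_tier, configs, pin) are illustrative only, not the production code; the real resolver also returns the owning user id and performs its lookups through the DB session.

def resolve_tier(
    agent_llm_id: int | None,
    thread_id: int | None,
    *,
    global_configs: dict[int, dict],
    pin: int | None = None,
) -> tuple[str, str]:
    """Toy decision table mirroring the test matrix above (illustrative)."""
    if agent_llm_id is None:
        raise ValueError("agent_llm_id is not set")
    if agent_llm_id == 0:          # Auto mode
        if thread_id is None or pin is None:
            return "free", "auto"  # no thread, or the pin lookup failed
        agent_llm_id = pin         # pin resolved to a concrete config id
    if agent_llm_id < 0:           # operator-defined global config
        cfg = global_configs[agent_llm_id]
        base = cfg.get("litellm_params", {}).get("base_model") or cfg["model_name"]
        return cfg["billing_tier"], base
    return "free", ""              # positive id: user BYOK, never debited


# The premium case from the first test, in miniature:
configs = {-1: {"model_name": "gpt-5.4", "billing_tier": "premium",
                "litellm_params": {"base_model": "gpt-5.4"}}}
assert resolve_tier(0, 99, global_configs=configs, pin=-1) == ("premium", "gpt-5.4")
assert resolve_tier(0, None, global_configs=configs) == ("free", "auto")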
surfsense_backend/tests/unit/services/test_billable_call.py (new file, 432 lines added)
@@ -0,0 +1,432 @@
"""Unit tests for the ``billable_call`` async context manager.

Covers the per-call premium-credit lifecycle for image generation and
vision LLM extraction:

* Free configs bypass reserve/finalize but still write an audit row.
* Premium reserve denial raises ``QuotaInsufficientError`` (HTTP 402 in the
  route layer).
* Successful premium calls reserve, yield the accumulator, then finalize
  with the LiteLLM-reported actual cost — and write an audit row.
* Failed premium calls release the reservation so credit isn't leaked.
* All quota DB ops happen inside their OWN ``shielded_async_session``,
  isolating them from the caller's transaction (issue A).
"""

from __future__ import annotations

import contextlib
from typing import Any
from uuid import uuid4

import pytest

pytestmark = pytest.mark.unit


# ---------------------------------------------------------------------------
# Fakes
# ---------------------------------------------------------------------------


class _FakeQuotaResult:
    def __init__(
        self,
        *,
        allowed: bool,
        used: int = 0,
        limit: int = 5_000_000,
        remaining: int = 5_000_000,
    ) -> None:
        self.allowed = allowed
        self.used = used
        self.limit = limit
        self.remaining = remaining


class _FakeSession:
    """Minimal AsyncSession stub — records commits for assertion."""

    def __init__(self) -> None:
        self.committed = False
        self.added: list[Any] = []

    def add(self, obj: Any) -> None:
        self.added.append(obj)

    async def commit(self) -> None:
        self.committed = True

    async def close(self) -> None:
        pass


_SESSIONS_USED: list[_FakeSession] = []


@contextlib.asynccontextmanager
async def _fake_shielded_session():
    s = _FakeSession()
    _SESSIONS_USED.append(s)
    yield s


def _patch_isolation_layer(monkeypatch, *, reserve_result, finalize_result=None):
    """Wire fake reserve/finalize/release/session helpers."""
    _SESSIONS_USED.clear()
    reserve_calls: list[dict[str, Any]] = []
    finalize_calls: list[dict[str, Any]] = []
    release_calls: list[dict[str, Any]] = []

    async def _fake_reserve(*, db_session, user_id, request_id, reserve_micros):
        reserve_calls.append(
            {
                "user_id": user_id,
                "reserve_micros": reserve_micros,
                "request_id": request_id,
            }
        )
        return reserve_result

    async def _fake_finalize(
        *, db_session, user_id, request_id, actual_micros, reserved_micros
    ):
        finalize_calls.append(
            {
                "user_id": user_id,
                "actual_micros": actual_micros,
                "reserved_micros": reserved_micros,
            }
        )
        return finalize_result or _FakeQuotaResult(allowed=True)

    async def _fake_release(*, db_session, user_id, reserved_micros):
        release_calls.append({"user_id": user_id, "reserved_micros": reserved_micros})

    record_calls: list[dict[str, Any]] = []

    async def _fake_record(session, **kwargs):
        record_calls.append(kwargs)
        return object()

    monkeypatch.setattr(
        "app.services.billable_calls.TokenQuotaService.premium_reserve",
        _fake_reserve,
        raising=False,
    )
    monkeypatch.setattr(
        "app.services.billable_calls.TokenQuotaService.premium_finalize",
        _fake_finalize,
        raising=False,
    )
    monkeypatch.setattr(
        "app.services.billable_calls.TokenQuotaService.premium_release",
        _fake_release,
        raising=False,
    )
    monkeypatch.setattr(
        "app.services.billable_calls.shielded_async_session",
        _fake_shielded_session,
        raising=False,
    )
    monkeypatch.setattr(
        "app.services.billable_calls.record_token_usage",
        _fake_record,
        raising=False,
    )

    return {
        "reserve": reserve_calls,
        "finalize": finalize_calls,
        "release": release_calls,
        "record": record_calls,
    }


# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_free_path_skips_reserve_but_writes_audit_row(monkeypatch):
    from app.services.billable_calls import billable_call

    spies = _patch_isolation_layer(
        monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
    )
    user_id = uuid4()

    async with billable_call(
        user_id=user_id,
        search_space_id=42,
        billing_tier="free",
        base_model="openai/gpt-image-1",
        usage_type="image_generation",
    ) as acc:
        # Simulate a captured cost — the accumulator is fed by the LiteLLM
        # callback in real life; here we add it manually.
        acc.add(
            model="openai/gpt-image-1",
            prompt_tokens=0,
            completion_tokens=0,
            total_tokens=0,
            cost_micros=37_000,
            call_kind="image_generation",
        )

    assert spies["reserve"] == []
    assert spies["finalize"] == []
    assert spies["release"] == []
    # Free still audits.
    assert len(spies["record"]) == 1
    assert spies["record"][0]["usage_type"] == "image_generation"
    assert spies["record"][0]["cost_micros"] == 37_000


@pytest.mark.asyncio
async def test_premium_reserve_denied_raises_quota_insufficient(monkeypatch):
    from app.services.billable_calls import (
        QuotaInsufficientError,
        billable_call,
    )

    spies = _patch_isolation_layer(
        monkeypatch,
        reserve_result=_FakeQuotaResult(
            allowed=False, used=5_000_000, limit=5_000_000, remaining=0
        ),
    )
    user_id = uuid4()

    with pytest.raises(QuotaInsufficientError) as exc_info:
        async with billable_call(
            user_id=user_id,
            search_space_id=42,
            billing_tier="premium",
            base_model="openai/gpt-image-1",
            quota_reserve_micros_override=50_000,
            usage_type="image_generation",
        ):
            pytest.fail("body should not run when reserve is denied")

    err = exc_info.value
    assert err.usage_type == "image_generation"
    assert err.used_micros == 5_000_000
    assert err.limit_micros == 5_000_000
    assert err.remaining_micros == 0
    # Reserve was attempted, but no finalize/release on a denied reserve
    # — the reservation never actually held credit.
    assert len(spies["reserve"]) == 1
    assert spies["finalize"] == []
    assert spies["release"] == []
    # Denied premium calls do NOT create an audit row (no work happened).
    assert spies["record"] == []


@pytest.mark.asyncio
async def test_premium_success_finalizes_with_actual_cost(monkeypatch):
    from app.services.billable_calls import billable_call

    spies = _patch_isolation_layer(
        monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
    )
    user_id = uuid4()

    async with billable_call(
        user_id=user_id,
        search_space_id=42,
        billing_tier="premium",
        base_model="openai/gpt-image-1",
        quota_reserve_micros_override=50_000,
        usage_type="image_generation",
    ) as acc:
        # LiteLLM callback would normally fill this — simulate a $0.04 image.
        acc.add(
            model="openai/gpt-image-1",
            prompt_tokens=0,
            completion_tokens=0,
            total_tokens=0,
            cost_micros=40_000,
            call_kind="image_generation",
        )

    assert len(spies["reserve"]) == 1
    assert spies["reserve"][0]["reserve_micros"] == 50_000
    assert len(spies["finalize"]) == 1
    assert spies["finalize"][0]["actual_micros"] == 40_000
    assert spies["finalize"][0]["reserved_micros"] == 50_000
    assert spies["release"] == []
    # And audit row written with the actual debited cost.
    assert spies["record"][0]["cost_micros"] == 40_000
    # Each quota op opened its OWN session — proves session isolation.
    assert len(_SESSIONS_USED) >= 3
    # reserve/finalize go through the stubbed TokenQuotaService.*, which never
    # commits on our fake sessions; the audit write does, so at least one
    # session must have committed.
    assert any(s.committed for s in _SESSIONS_USED)


@pytest.mark.asyncio
async def test_premium_failure_releases_reservation(monkeypatch):
    from app.services.billable_calls import billable_call

    spies = _patch_isolation_layer(
        monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
    )
    user_id = uuid4()

    class _ProviderError(Exception):
        pass

    with pytest.raises(_ProviderError):
        async with billable_call(
            user_id=user_id,
            search_space_id=42,
            billing_tier="premium",
            base_model="openai/gpt-image-1",
            quota_reserve_micros_override=50_000,
            usage_type="image_generation",
        ):
            raise _ProviderError("OpenRouter 503")

    assert len(spies["reserve"]) == 1
    assert spies["finalize"] == []
    # Failure path: release the held reservation.
    assert len(spies["release"]) == 1
    assert spies["release"][0]["reserved_micros"] == 50_000


@pytest.mark.asyncio
async def test_premium_uses_estimator_when_no_micros_override(monkeypatch):
    """When ``quota_reserve_micros_override`` is None, we fall back to
    ``estimate_call_reserve_micros(base_model, quota_reserve_tokens)``.
    Vision LLM calls take this path (token-priced models).
    """
    from app.services.billable_calls import billable_call

    spies = _patch_isolation_layer(
        monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
    )

    captured_estimator_calls: list[dict[str, Any]] = []

    def _fake_estimate(*, base_model, quota_reserve_tokens):
        captured_estimator_calls.append(
            {"base_model": base_model, "quota_reserve_tokens": quota_reserve_tokens}
        )
        return 12_345

    monkeypatch.setattr(
        "app.services.billable_calls.estimate_call_reserve_micros",
        _fake_estimate,
        raising=False,
    )

    user_id = uuid4()
    async with billable_call(
        user_id=user_id,
        search_space_id=1,
        billing_tier="premium",
        base_model="openai/gpt-4o",
        quota_reserve_tokens=4000,
        usage_type="vision_extraction",
    ):
        pass

    assert captured_estimator_calls == [
        {"base_model": "openai/gpt-4o", "quota_reserve_tokens": 4000}
    ]
    assert spies["reserve"][0]["reserve_micros"] == 12_345


# ---------------------------------------------------------------------------
# Podcast / video-presentation usage_type coverage
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_free_podcast_path_audits_with_podcast_usage_type(monkeypatch):
    """Free podcast configs must skip reserve/finalize but still emit a
    ``TokenUsage`` row tagged ``usage_type='podcast_generation'`` so we
    have full audit coverage of free-tier agent runs."""
    from app.services.billable_calls import billable_call

    spies = _patch_isolation_layer(
        monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
    )
    user_id = uuid4()

    async with billable_call(
        user_id=user_id,
        search_space_id=42,
        billing_tier="free",
        base_model="openrouter/some-free-model",
        quota_reserve_micros_override=200_000,
        usage_type="podcast_generation",
        thread_id=99,
        call_details={"podcast_id": 7, "title": "Test Podcast"},
    ) as acc:
        # Two transcript LLM calls aggregated into one accumulator.
        acc.add(
            model="openrouter/some-free-model",
            prompt_tokens=1500,
            completion_tokens=8000,
            total_tokens=9500,
            cost_micros=0,
            call_kind="chat",
        )

    assert spies["reserve"] == []
    assert spies["finalize"] == []
    assert spies["release"] == []

    assert len(spies["record"]) == 1
    row = spies["record"][0]
    assert row["usage_type"] == "podcast_generation"
    assert row["thread_id"] == 99
    assert row["search_space_id"] == 42
    assert row["call_details"] == {"podcast_id": 7, "title": "Test Podcast"}


@pytest.mark.asyncio
async def test_premium_video_denial_raises_quota_insufficient(monkeypatch):
    """Premium video-presentation runs that hit a denied reservation must
    raise ``QuotaInsufficientError`` *before* the graph runs and must not
    emit an audit row (no work happened)."""
    from app.services.billable_calls import (
        QuotaInsufficientError,
        billable_call,
    )

    spies = _patch_isolation_layer(
        monkeypatch,
        reserve_result=_FakeQuotaResult(
            allowed=False, used=4_500_000, limit=5_000_000, remaining=500_000
        ),
    )
    user_id = uuid4()

    with pytest.raises(QuotaInsufficientError) as exc_info:
        async with billable_call(
            user_id=user_id,
            search_space_id=42,
            billing_tier="premium",
            base_model="gpt-5.4",
            quota_reserve_micros_override=1_000_000,
            usage_type="video_presentation_generation",
            thread_id=99,
            call_details={"video_presentation_id": 12, "title": "Test Video"},
        ):
            pytest.fail("body should not run when reserve is denied")

    err = exc_info.value
    assert err.usage_type == "video_presentation_generation"
    assert err.remaining_micros == 500_000
    assert spies["reserve"][0]["reserve_micros"] == 1_000_000
    assert spies["finalize"] == []
    assert spies["release"] == []
    assert spies["record"] == []
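The lifecycle these tests enforce (reserve, run, then finalize on success or release on failure, with an audit row whenever work happened) is easy to see in a toy version. This is a minimal sketch, not the production billable_call; the ledger dict and mini_billable_call name are illustrative, and the real code uses TokenQuotaService plus shielded sessions.

import contextlib


@contextlib.asynccontextmanager
async def mini_billable_call(*, tier: str, reserve: int, ledger: dict):
    """Toy reserve → run → finalize/release lifecycle (illustrative only)."""
    if tier == "premium":
        if ledger["remaining"] < reserve:
            raise RuntimeError("quota insufficient")  # real code: QuotaInsufficientError
        ledger["remaining"] -= reserve                # hold the reservation
    acc = {"cost_micros": 0}
    try:
        yield acc                                     # caller runs the LLM call(s)
    except Exception:
        if tier == "premium":
            ledger["remaining"] += reserve            # release: nothing was debited
        raise
    if tier == "premium":
        # finalize: settle the hold down to the actual provider-reported cost
        ledger["remaining"] += reserve - acc["cost_micros"]
    ledger.setdefault("audit", []).append(acc["cost_micros"])  # free AND premium audit

# Usage (under pytest-asyncio or asyncio.run):
#   async with mini_billable_call(tier="premium", reserve=50_000, ledger=led) as acc:
#       acc["cost_micros"] = 40_000  # set by the LiteLLM cost callback in real code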
@@ -214,3 +214,159 @@ def test_generate_configs_drops_non_text_and_non_tool_models():
    assert "openai/gpt-4o" in model_names
    assert "openai/dall-e" not in model_names
    assert "openai/completion-only" not in model_names


# ---------------------------------------------------------------------------
# _generate_image_gen_configs / _generate_vision_llm_configs
# ---------------------------------------------------------------------------


def test_generate_image_gen_configs_filters_by_image_output():
    """Only models with ``output_modalities`` containing ``image`` are emitted.
    Tool-calling and context filters are intentionally NOT applied — image
    generation has nothing to do with tool calls and context windows.
    """
    from app.services.openrouter_integration_service import (
        _generate_image_gen_configs,
    )

    raw = [
        # Pure image-gen model (small context, no tools — should still emit).
        {
            "id": "openai/gpt-image-1",
            "architecture": {"output_modalities": ["image"]},
            "context_length": 4_000,
            "pricing": {"prompt": "0", "completion": "0"},
        },
        # Multi-modal: text+image output (should still emit).
        {
            "id": "google/gemini-2.5-flash-image",
            "architecture": {"output_modalities": ["text", "image"]},
            "context_length": 1_000_000,
            "pricing": {"prompt": "0.000001", "completion": "0.000004"},
        },
        # Pure text model — must NOT emit.
        {
            "id": "openai/gpt-4o",
            "architecture": {"output_modalities": ["text"]},
            "context_length": 128_000,
            "pricing": {"prompt": "0.000005", "completion": "0.000015"},
        },
    ]

    cfgs = _generate_image_gen_configs(raw, dict(_SETTINGS_BASE))
    model_names = {c["model_name"] for c in cfgs}
    assert "openai/gpt-image-1" in model_names
    assert "google/gemini-2.5-flash-image" in model_names
    assert "openai/gpt-4o" not in model_names

    # Each config must carry ``billing_tier`` for routing in image_generation_routes.
    for c in cfgs:
        assert c["billing_tier"] in {"free", "premium"}
        assert c["provider"] == "OPENROUTER"
        assert c[_OPENROUTER_DYNAMIC_MARKER] is True


def test_generate_image_gen_configs_assigns_image_id_offset():
    """Image configs use a different id_offset (-20000) so their negative
    IDs don't collide with chat configs (-10000) or vision configs (-30000).
    """
    from app.services.openrouter_integration_service import (
        _generate_image_gen_configs,
    )

    raw = [
        {
            "id": "openai/gpt-image-1",
            "architecture": {"output_modalities": ["image"]},
            "context_length": 4_000,
            "pricing": {"prompt": "0", "completion": "0"},
        }
    ]
    # Don't pass image_id_offset → use the module default (-20000).
    cfgs = _generate_image_gen_configs(raw, dict(_SETTINGS_BASE))
    assert all(c["id"] <= -20_000 for c in cfgs)
    assert all(c["id"] > -29_000_000 for c in cfgs)


def test_generate_vision_llm_configs_filters_by_image_input_text_output():
    """Vision LLMs must accept image input AND emit text — pure image-gen
    (no text out) and text-only (no image in) models are excluded.
    """
    from app.services.openrouter_integration_service import (
        _generate_vision_llm_configs,
    )

    raw = [
        # GPT-4o: vision LLM (image in, text out) — must emit.
        {
            "id": "openai/gpt-4o",
            "architecture": {
                "input_modalities": ["text", "image"],
                "output_modalities": ["text"],
            },
            "context_length": 128_000,
            "pricing": {"prompt": "0.000005", "completion": "0.000015"},
        },
        # Pure image generator — image *output*, no text out. Must NOT emit.
        {
            "id": "openai/gpt-image-1",
            "architecture": {
                "input_modalities": ["text"],
                "output_modalities": ["image"],
            },
            "context_length": 4_000,
            "pricing": {"prompt": "0", "completion": "0"},
        },
        # Pure text model (no image in). Must NOT emit.
        {
            "id": "anthropic/claude-3-haiku",
            "architecture": {
                "input_modalities": ["text"],
                "output_modalities": ["text"],
            },
            "context_length": 200_000,
            "pricing": {"prompt": "0.000001", "completion": "0.000005"},
        },
    ]

    cfgs = _generate_vision_llm_configs(raw, dict(_SETTINGS_BASE))
    names = {c["model_name"] for c in cfgs}
    assert names == {"openai/gpt-4o"}

    cfg = cfgs[0]
    assert cfg["billing_tier"] == "premium"
    # Pricing carried inline so pricing_registration can register vision
    # under ``openrouter/openai/gpt-4o`` even if the chat catalogue cache
    # is cleared.
    assert cfg["input_cost_per_token"] == pytest.approx(5e-6)
    assert cfg["output_cost_per_token"] == pytest.approx(15e-6)
    assert cfg[_OPENROUTER_DYNAMIC_MARKER] is True


def test_generate_vision_llm_configs_drops_chat_only_filters():
    """A small-context vision model that doesn't advertise tool calling is
    still a valid vision LLM for "describe this image" prompts. The chat
    filters (``supports_tool_calling``, ``has_sufficient_context``) must
    NOT be applied to vision emission.
    """
    from app.services.openrouter_integration_service import (
        _generate_vision_llm_configs,
    )

    raw = [
        {
            "id": "tiny/vision-mini",
            "architecture": {
                "input_modalities": ["text", "image"],
                "output_modalities": ["text"],
            },
            "supported_parameters": [],  # no tools
            "context_length": 4_000,  # well below MIN_CONTEXT_LENGTH
            "pricing": {"prompt": "0.0000001", "completion": "0.0000005"},
        }
    ]

    cfgs = _generate_vision_llm_configs(raw, dict(_SETTINGS_BASE))
    assert len(cfgs) == 1
    assert cfgs[0]["model_name"] == "tiny/vision-mini"
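The id-offset test only pins the band each config family lands in, not the exact assignment formula. A plausible toy version of the scheme, under the assumption (not confirmed by this diff) that ids are assigned by descending index from each family's offset:

CHAT_ID_OFFSET = -10_000    # bands named in the docstring above
IMAGE_ID_OFFSET = -20_000
VISION_ID_OFFSET = -30_000


def assign_ids(models: list[str], offset: int) -> dict[str, int]:
    """Hypothetical offset scheme: one descending negative id per model."""
    return {m: offset - i for i, m in enumerate(models)}


ids = assign_ids(["openai/gpt-image-1", "google/gemini-2.5-flash-image"],
                 IMAGE_ID_OFFSET)
assert all(v <= IMAGE_ID_OFFSET for v in ids.values())  # the band the test asserts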
@@ -0,0 +1,447 @@
"""Pricing registration unit tests.

The pricing-registration module is what makes ``response_cost`` populate
correctly for OpenRouter dynamic models and operator-defined Azure
deployments — both of which LiteLLM doesn't natively know about. The tests
exercise:

* The alias generators emit every shape that LiteLLM's cost-callback might
  use (``openrouter/X`` and bare ``X``; YAML-defined ``base_model``,
  ``provider/base_model``, ``provider/model_name``, plus the special
  ``azure_openai`` → ``azure`` normalisation).
* ``register_pricing_from_global_configs`` calls ``litellm.register_model``
  with the right alias set and pricing values per provider.
* Configs without a resolvable pair of cost values are skipped — never
  registered as zero, since that would override pricing LiteLLM might
  already know natively.
"""

from __future__ import annotations

from typing import Any

import pytest

pytestmark = pytest.mark.unit


# ---------------------------------------------------------------------------
# Alias generators
# ---------------------------------------------------------------------------


def test_openrouter_alias_set_includes_prefixed_and_bare():
    from app.services.pricing_registration import _alias_set_for_openrouter

    aliases = _alias_set_for_openrouter("anthropic/claude-3-5-sonnet")
    assert aliases == [
        "openrouter/anthropic/claude-3-5-sonnet",
        "anthropic/claude-3-5-sonnet",
    ]


def test_openrouter_alias_set_dedupes():
    """If the model id is already prefixed with ``openrouter/``, the alias
    set must not contain duplicates that would re-register the same key
    twice.
    """
    from app.services.pricing_registration import _alias_set_for_openrouter

    aliases = _alias_set_for_openrouter("openrouter/foo")
    # The bare and prefixed variants compute to the same string here, so we
    # at minimum require uniqueness.
    assert len(aliases) == len(set(aliases))


def test_yaml_alias_set_for_azure_openai_normalises_to_azure():
    """``azure_openai`` (our YAML provider slug) must register under
    ``azure/<name>`` so the LiteLLM Router's deployment-resolution path
    (which uses provider ``azure``) finds the pricing too.
    """
    from app.services.pricing_registration import _alias_set_for_yaml

    aliases = _alias_set_for_yaml(
        provider="AZURE_OPENAI",
        model_name="gpt-5.4",
        base_model="gpt-5.4",
    )
    assert "gpt-5.4" in aliases
    assert "azure_openai/gpt-5.4" in aliases
    assert "azure/gpt-5.4" in aliases


def test_yaml_alias_set_distinguishes_model_name_and_base_model():
    """When ``model_name`` differs from ``base_model`` (operator labelled a
    deployment), both must appear in the alias set since either may surface
    in callbacks depending on the call path.
    """
    from app.services.pricing_registration import _alias_set_for_yaml

    aliases = _alias_set_for_yaml(
        provider="OPENAI",
        model_name="my-deployment-label",
        base_model="gpt-4o",
    )
    assert "gpt-4o" in aliases
    assert "openai/gpt-4o" in aliases
    assert "my-deployment-label" in aliases
    assert "openai/my-deployment-label" in aliases


def test_yaml_alias_set_omits_provider_prefix_when_provider_blank():
    from app.services.pricing_registration import _alias_set_for_yaml

    aliases = _alias_set_for_yaml(
        provider="",
        model_name="foo",
        base_model="bar",
    )
    assert "bar" in aliases
    assert "foo" in aliases
    assert all("/" not in a for a in aliases)


# ---------------------------------------------------------------------------
# register_pricing_from_global_configs
# ---------------------------------------------------------------------------


class _RegistrationSpy:
    """Captures the dicts passed to ``litellm.register_model``.

    Many calls may go through; we just record them all and let tests assert
    against the union.
    """

    def __init__(self) -> None:
        self.calls: list[dict[str, Any]] = []

    def __call__(self, payload: dict[str, Any]) -> None:
        self.calls.append(payload)

    @property
    def all_keys(self) -> set[str]:
        keys: set[str] = set()
        for payload in self.calls:
            keys.update(payload.keys())
        return keys


def _patch_register(monkeypatch: pytest.MonkeyPatch) -> _RegistrationSpy:
    spy = _RegistrationSpy()
    monkeypatch.setattr(
        "app.services.pricing_registration.litellm.register_model",
        spy,
        raising=False,
    )
    return spy


def _patch_openrouter_pricing(
    monkeypatch: pytest.MonkeyPatch, mapping: dict[str, dict[str, str]]
) -> None:
    """Pretend the OpenRouter integration is initialised with ``mapping``."""

    class _Stub:
        def get_raw_pricing(self) -> dict[str, dict[str, str]]:
            return mapping

    class _StubService:
        @classmethod
        def is_initialized(cls) -> bool:
            return True

        @classmethod
        def get_instance(cls) -> _Stub:
            return _Stub()

    monkeypatch.setattr(
        "app.services.openrouter_integration_service.OpenRouterIntegrationService",
        _StubService,
        raising=False,
    )


def test_openrouter_models_register_under_aliases(monkeypatch):
    """An OpenRouter config whose ``model_name`` is in the cached raw
    pricing map is registered under both ``openrouter/X`` and bare ``X``.
    """
    from app.config import config
    from app.services.pricing_registration import register_pricing_from_global_configs

    spy = _patch_register(monkeypatch)
    _patch_openrouter_pricing(
        monkeypatch,
        {
            "anthropic/claude-3-5-sonnet": {
                "prompt": "0.000003",
                "completion": "0.000015",
            }
        },
    )

    monkeypatch.setattr(
        config,
        "GLOBAL_LLM_CONFIGS",
        [
            {
                "id": 1,
                "provider": "OPENROUTER",
                "model_name": "anthropic/claude-3-5-sonnet",
            }
        ],
    )

    register_pricing_from_global_configs()

    assert "openrouter/anthropic/claude-3-5-sonnet" in spy.all_keys
    assert "anthropic/claude-3-5-sonnet" in spy.all_keys
    # Costs are float-converted from the raw OpenRouter strings.
    payload = spy.calls[0]
    assert payload["openrouter/anthropic/claude-3-5-sonnet"][
        "input_cost_per_token"
    ] == pytest.approx(3e-6)
    assert payload["openrouter/anthropic/claude-3-5-sonnet"][
        "output_cost_per_token"
    ] == pytest.approx(15e-6)
    assert (
        payload["openrouter/anthropic/claude-3-5-sonnet"]["litellm_provider"]
        == "openrouter"
    )


def test_yaml_override_registers_under_alias_set(monkeypatch):
    """Operator-declared ``input_cost_per_token`` /
    ``output_cost_per_token`` on a YAML config registers under every
    alias the YAML alias generator produces — including the ``azure/``
    normalisation for ``azure_openai`` providers.
    """
    from app.config import config
    from app.services.pricing_registration import register_pricing_from_global_configs

    spy = _patch_register(monkeypatch)
    _patch_openrouter_pricing(monkeypatch, {})

    monkeypatch.setattr(
        config,
        "GLOBAL_LLM_CONFIGS",
        [
            {
                "id": 1,
                "provider": "AZURE_OPENAI",
                "model_name": "gpt-5.4",
                "litellm_params": {
                    "base_model": "gpt-5.4",
                    "input_cost_per_token": 2e-6,
                    "output_cost_per_token": 8e-6,
                },
            }
        ],
    )

    register_pricing_from_global_configs()

    keys = spy.all_keys
    assert "gpt-5.4" in keys
    assert "azure_openai/gpt-5.4" in keys
    assert "azure/gpt-5.4" in keys

    payload = spy.calls[0]
    entry = payload["gpt-5.4"]
    assert entry["input_cost_per_token"] == pytest.approx(2e-6)
    assert entry["output_cost_per_token"] == pytest.approx(8e-6)
    assert entry["litellm_provider"] == "azure"


def test_no_override_means_no_registration(monkeypatch):
    """A YAML config that *omits* both pricing fields must NOT be registered
    — registering as zero would override LiteLLM's native pricing for the
    ``base_model`` key (e.g. ``gpt-4o``) and silently make every user's
    bill drop to $0. Fail-safe is "skip and warn", not "register zero".
    """
    from app.config import config
    from app.services.pricing_registration import register_pricing_from_global_configs

    spy = _patch_register(monkeypatch)
    _patch_openrouter_pricing(monkeypatch, {})

    monkeypatch.setattr(
        config,
        "GLOBAL_LLM_CONFIGS",
        [
            {
                "id": 1,
                "provider": "OPENAI",
                "model_name": "gpt-4o",
                "litellm_params": {"base_model": "gpt-4o"},
            }
        ],
    )

    register_pricing_from_global_configs()

    assert spy.calls == []


def test_openrouter_skipped_when_pricing_missing(monkeypatch):
    """If the OpenRouter raw-pricing cache doesn't carry an entry for a
    configured model (network blip during refresh, model added later, etc.),
    we skip it rather than registering zero pricing.
    """
    from app.config import config
    from app.services.pricing_registration import register_pricing_from_global_configs

    spy = _patch_register(monkeypatch)
    _patch_openrouter_pricing(
        monkeypatch, {"some/other-model": {"prompt": "1", "completion": "1"}}
    )

    monkeypatch.setattr(
        config,
        "GLOBAL_LLM_CONFIGS",
        [
            {
                "id": 1,
                "provider": "OPENROUTER",
                "model_name": "anthropic/claude-3-5-sonnet",
            }
        ],
    )

    register_pricing_from_global_configs()

    assert spy.calls == []


def test_register_continues_after_individual_failure(monkeypatch, caplog):
    """A single bad ``register_model`` call (e.g. raising a LiteLLM error)
    must not abort registration of the remaining configs.
    """
    from app.config import config
    from app.services.pricing_registration import register_pricing_from_global_configs

    failing_keys: set[str] = {"anthropic/claude-3-5-sonnet"}
    successful_calls: list[dict[str, Any]] = []

    def _maybe_fail(payload: dict[str, Any]) -> None:
        if any(k in failing_keys for k in payload):
            raise RuntimeError("boom")
        successful_calls.append(payload)

    monkeypatch.setattr(
        "app.services.pricing_registration.litellm.register_model",
        _maybe_fail,
        raising=False,
    )
    _patch_openrouter_pricing(
        monkeypatch,
        {
            "anthropic/claude-3-5-sonnet": {
                "prompt": "0.000003",
                "completion": "0.000015",
            }
        },
    )

    monkeypatch.setattr(
        config,
        "GLOBAL_LLM_CONFIGS",
        [
            {
                "id": 1,
                "provider": "OPENROUTER",
                "model_name": "anthropic/claude-3-5-sonnet",
            },
            {
                "id": 2,
                "provider": "OPENAI",
                "model_name": "custom-deployment",
                "litellm_params": {
                    "base_model": "custom-deployment",
                    "input_cost_per_token": 1e-6,
                    "output_cost_per_token": 2e-6,
                },
            },
        ],
    )

    register_pricing_from_global_configs()

    # The good config still registered.
    assert any("custom-deployment" in payload for payload in successful_calls)


def test_vision_configs_registered_with_chat_shape(monkeypatch):
    """``register_pricing_from_global_configs`` walks
    ``GLOBAL_VISION_LLM_CONFIGS`` in addition to the chat configs so vision
    calls (during indexing) bill correctly. Vision configs use the same
    chat-shape token prices, but image-gen pricing is intentionally NOT
    registered here (handled via ``response_cost`` in LiteLLM).
    """
    from app.config import config
    from app.services.pricing_registration import register_pricing_from_global_configs

    spy = _patch_register(monkeypatch)
    _patch_openrouter_pricing(
        monkeypatch,
        {"openai/gpt-4o": {"prompt": "0.000005", "completion": "0.000015"}},
    )

    # No chat configs — only vision. Proves the vision walk is a separate
    # iteration, not piggy-backed on the chat list.
    monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [])
    monkeypatch.setattr(
        config,
        "GLOBAL_VISION_LLM_CONFIGS",
        [
            {
                "id": -1,
                "provider": "OPENROUTER",
                "model_name": "openai/gpt-4o",
                "billing_tier": "premium",
                "input_cost_per_token": 5e-6,
                "output_cost_per_token": 15e-6,
            }
        ],
    )

    register_pricing_from_global_configs()

    assert "openrouter/openai/gpt-4o" in spy.all_keys
    payload_value = spy.calls[0]["openrouter/openai/gpt-4o"]
    assert payload_value["mode"] == "chat"
    assert payload_value["litellm_provider"] == "openrouter"
    assert payload_value["input_cost_per_token"] == pytest.approx(5e-6)
    assert payload_value["output_cost_per_token"] == pytest.approx(15e-6)


def test_vision_with_inline_pricing_when_or_cache_missing(monkeypatch):
    """If the OpenRouter pricing cache misses a vision model (different
    catalogue surface), the vision walk falls back to inline
    ``input_cost_per_token``/``output_cost_per_token`` on the cfg itself.
    """
    from app.config import config
    from app.services.pricing_registration import register_pricing_from_global_configs

    spy = _patch_register(monkeypatch)
    _patch_openrouter_pricing(monkeypatch, {})

    monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [])
    monkeypatch.setattr(
        config,
        "GLOBAL_VISION_LLM_CONFIGS",
        [
            {
                "id": -1,
                "provider": "OPENROUTER",
                "model_name": "google/gemini-2.5-flash",
                "billing_tier": "premium",
                "input_cost_per_token": 1e-6,
                "output_cost_per_token": 4e-6,
            }
        ],
    )

    register_pricing_from_global_configs()

    assert "openrouter/google/gemini-2.5-flash" in spy.all_keys
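The alias shapes these tests pin down can be reproduced in a few lines. This is a sketch consistent with the assertions above, not the module's actual implementation; the function names mirror the private helpers under test but the bodies are reconstructed.

def alias_set_for_openrouter(model_id: str) -> list[str]:
    """Prefixed-then-bare alias pair, deduplicated (sketch)."""
    prefixed = model_id if model_id.startswith("openrouter/") else f"openrouter/{model_id}"
    bare = model_id.removeprefix("openrouter/")
    return list(dict.fromkeys([prefixed, bare]))  # order-preserving dedupe


def alias_set_for_yaml(*, provider: str, model_name: str, base_model: str) -> list[str]:
    """Bare and provider-prefixed names for both labels (sketch)."""
    slug = provider.lower()
    if slug == "azure_openai":
        slugs = ["azure_openai", "azure"]  # the special-case normalisation
    elif slug:
        slugs = [slug]
    else:
        slugs = []
    names = dict.fromkeys([base_model, model_name])
    out = list(names) + [f"{s}/{n}" for n in names for s in slugs]
    return list(dict.fromkeys(out))


assert alias_set_for_openrouter("anthropic/claude-3-5-sonnet") == [
    "openrouter/anthropic/claude-3-5-sonnet",
    "anthropic/claude-3-5-sonnet",
]
assert "azure/gpt-5.4" in alias_set_for_yaml(
    provider="AZURE_OPENAI", model_name="gpt-5.4", base_model="gpt-5.4"
)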
@ -0,0 +1,157 @@
|
|||
"""Unit tests for ``QuotaCheckedVisionLLM``.
|
||||
|
||||
Validates that:
|
||||
|
||||
* Calling ``ainvoke`` routes through ``billable_call`` (premium credit
|
||||
enforcement) and forwards the inner LLM's response on success.
|
||||
* The wrapper proxies non-overridden attributes to the inner LLM
|
||||
(``__getattr__``) so ``invoke`` / ``astream`` / ``with_structured_output``
|
||||
still work without quota gating (they're not used in indexing today).
|
||||
* When ``billable_call`` raises ``QuotaInsufficientError`` the wrapper
|
||||
bubbles it up — the ETL pipeline catches that and falls back to OCR.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
class _FakeInnerLLM:
|
||||
"""Stand-in for ``langchain_litellm.ChatLiteLLM``."""
|
||||
|
||||
def __init__(self, response: Any = "OCR'd content") -> None:
|
||||
self._response = response
|
||||
self.ainvoke_calls: list[Any] = []
|
||||
|
||||
async def ainvoke(self, input: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
self.ainvoke_calls.append(input)
|
||||
return self._response
|
||||
|
||||
def some_other_method(self, x: int) -> int:
|
||||
return x * 2
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def _passthrough_billable_call(**_kwargs):
|
||||
"""Stand-in for billable_call that always allows the call to run."""
|
||||
|
||||
class _Acc:
|
||||
total_cost_micros = 0
|
||||
total_prompt_tokens = 0
|
||||
total_completion_tokens = 0
|
||||
grand_total = 0
|
||||
calls: list[Any] = []
|
||||
|
||||
def per_message_summary(self) -> dict[str, dict[str, int]]:
|
||||
return {}
|
||||
|
||||
yield _Acc()


@pytest.mark.asyncio
async def test_ainvoke_routes_through_billable_call(monkeypatch):
    from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM

    captured_kwargs: list[dict[str, Any]] = []

    @contextlib.asynccontextmanager
    async def _spy_billable_call(**kwargs):
        captured_kwargs.append(kwargs)
        async with _passthrough_billable_call() as acc:
            yield acc

    monkeypatch.setattr(
        "app.services.quota_checked_vision_llm.billable_call",
        _spy_billable_call,
        raising=False,
    )

    inner = _FakeInnerLLM(response="A red apple on a white table")
    user_id = uuid4()
    wrapper = QuotaCheckedVisionLLM(
        inner,
        user_id=user_id,
        search_space_id=99,
        billing_tier="premium",
        base_model="openai/gpt-4o",
        quota_reserve_tokens=4000,
    )

    result = await wrapper.ainvoke([{"text": "what is this?"}])
    assert result == "A red apple on a white table"
    assert len(inner.ainvoke_calls) == 1
    assert len(captured_kwargs) == 1
    bc_kwargs = captured_kwargs[0]
    assert bc_kwargs["user_id"] == user_id
    assert bc_kwargs["search_space_id"] == 99
    assert bc_kwargs["billing_tier"] == "premium"
    assert bc_kwargs["base_model"] == "openai/gpt-4o"
    assert bc_kwargs["quota_reserve_tokens"] == 4000
    assert bc_kwargs["usage_type"] == "vision_extraction"


@pytest.mark.asyncio
async def test_ainvoke_propagates_quota_insufficient_error(monkeypatch):
    from app.services.billable_calls import QuotaInsufficientError
    from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM

    @contextlib.asynccontextmanager
    async def _denying_billable_call(**_kwargs):
        raise QuotaInsufficientError(
            usage_type="vision_extraction",
            used_micros=5_000_000,
            limit_micros=5_000_000,
            remaining_micros=0,
        )
        yield  # unreachable, but required so the function is an async generator

    monkeypatch.setattr(
        "app.services.quota_checked_vision_llm.billable_call",
        _denying_billable_call,
        raising=False,
    )

    inner = _FakeInnerLLM()
    wrapper = QuotaCheckedVisionLLM(
        inner,
        user_id=uuid4(),
        search_space_id=1,
        billing_tier="premium",
        base_model="openai/gpt-4o",
        quota_reserve_tokens=4000,
    )

    with pytest.raises(QuotaInsufficientError):
        await wrapper.ainvoke([{"text": "x"}])

    # Inner LLM never ran on a denied reservation.
    assert inner.ainvoke_calls == []


@pytest.mark.asyncio
async def test_proxies_non_overridden_attributes_to_inner():
    """``__getattr__`` forwards anything not on the proxy itself, so any
    method we didn't explicitly override (``invoke``, ``astream``,
    ``with_structured_output``, etc.) still works — just without quota
    gating, which is fine because the indexer only ever calls ``ainvoke``.
    """
    from app.services.quota_checked_vision_llm import QuotaCheckedVisionLLM

    inner = _FakeInnerLLM()
    wrapper = QuotaCheckedVisionLLM(
        inner,
        user_id=uuid4(),
        search_space_id=1,
        billing_tier="premium",
        base_model="openai/gpt-4o",
        quota_reserve_tokens=4000,
    )

    # ``some_other_method`` exists on the inner object only.
    assert wrapper.some_other_method(7) == 14
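

# A minimal sketch of the delegation pattern the test above pins down
# (illustrative shape only, not the real QuotaCheckedVisionLLM, which
# layers quota gating on top of its overridden ``ainvoke``):
class _ProxySketch:
    def __init__(self, inner: Any) -> None:
        self._inner = inner

    def __getattr__(self, name: str) -> Any:
        # Only consulted when normal attribute lookup fails, so anything
        # defined on the proxy itself never reaches the inner object
        # through this path.
        return getattr(self._inner, name)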
@@ -0,0 +1,515 @@
"""Cost-based premium quota unit tests.
|
||||
|
||||
Covers the USD-micro behaviour added in migration 140:
|
||||
|
||||
* ``TurnTokenAccumulator.total_cost_micros`` sums ``cost_micros`` across all
|
||||
calls in a turn — used as the debit amount when ``agent_config.is_premium``
|
||||
is true, regardless of which underlying model produced each call. This
|
||||
preserves the prior "premium turn → all calls in turn count" rule from the
|
||||
token-based system.
|
||||
* ``estimate_call_reserve_micros`` scales linearly with model pricing,
|
||||
clamps to a sane floor when pricing is unknown, and respects the
|
||||
``QUOTA_MAX_RESERVE_MICROS`` ceiling so a misconfigured "$1000/M" entry
|
||||
can't lock the whole balance on one call.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
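
# Units note: "micros" are USD millionths throughout these tests, so
# $1.00 == 1_000_000 micros and $0.04 == 40_000. The helper below is
# illustrative only; the service does its own conversion internally.
def _usd_to_micros(usd: float) -> int:
    return int(round(usd * 1_000_000))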


# ---------------------------------------------------------------------------
# TurnTokenAccumulator — premium-turn debit semantics
# ---------------------------------------------------------------------------


def test_total_cost_micros_sums_premium_and_free_calls():
    """A premium turn that also called a free sub-agent debits the union.

    The plan deliberately preserved the existing "premium turn → all calls
    count" behaviour because per-call premium filtering relied on
    ``LLMRouterService._premium_model_strings``, which only covers router-pool
    deployments. ``total_cost_micros`` therefore must include free-model
    calls (whose ``cost_micros`` is typically ``0``) as well as the premium
    call's actual provider cost.
    """
    from app.services.token_tracking_service import TurnTokenAccumulator

    acc = TurnTokenAccumulator()
    # Premium model (e.g. claude-3-5-sonnet): non-zero cost.
    acc.add(
        model="anthropic/claude-3-5-sonnet",
        prompt_tokens=1200,
        completion_tokens=400,
        total_tokens=1600,
        cost_micros=12_345,
    )
    # Free sub-agent (e.g. title-gen on a free model): zero cost.
    acc.add(
        model="gpt-4o-mini",
        prompt_tokens=120,
        completion_tokens=20,
        total_tokens=140,
        cost_micros=0,
    )
    # A second premium-priced call within the same turn.
    acc.add(
        model="anthropic/claude-3-5-sonnet",
        prompt_tokens=800,
        completion_tokens=200,
        total_tokens=1000,
        cost_micros=7_500,
    )

    assert acc.total_cost_micros == 12_345 + 0 + 7_500
    # Token totals stay correct so the FE display path still works.
    assert acc.grand_total == 1600 + 140 + 1000


def test_total_cost_micros_zero_when_no_calls():
    """An empty accumulator must report zero cost (no division-by-zero, no None)."""
    from app.services.token_tracking_service import TurnTokenAccumulator

    acc = TurnTokenAccumulator()
    assert acc.total_cost_micros == 0
    assert acc.grand_total == 0


def test_per_message_summary_groups_cost_by_model():
    """``per_message_summary`` must accumulate ``cost_micros`` per model so the
    SSE ``model_breakdown`` payload reports actual USD spend per provider.
    """
    from app.services.token_tracking_service import TurnTokenAccumulator

    acc = TurnTokenAccumulator()
    acc.add(
        model="claude-3-5-sonnet",
        prompt_tokens=100,
        completion_tokens=50,
        total_tokens=150,
        cost_micros=4_000,
    )
    acc.add(
        model="claude-3-5-sonnet",
        prompt_tokens=200,
        completion_tokens=100,
        total_tokens=300,
        cost_micros=8_000,
    )
    acc.add(
        model="gpt-4o-mini",
        prompt_tokens=50,
        completion_tokens=10,
        total_tokens=60,
        cost_micros=200,
    )

    summary = acc.per_message_summary()
    assert summary["claude-3-5-sonnet"]["cost_micros"] == 12_000
    assert summary["claude-3-5-sonnet"]["total_tokens"] == 450
    assert summary["gpt-4o-mini"]["cost_micros"] == 200


def test_serialized_calls_includes_cost_micros():
    """``serialized_calls`` is what flows into the SSE ``call_details``
    payload; ``cost_micros`` must be present on each entry so the FE
    message-info dropdown can render per-call USD.
    """
    from app.services.token_tracking_service import TurnTokenAccumulator

    acc = TurnTokenAccumulator()
    acc.add(
        model="m",
        prompt_tokens=1,
        completion_tokens=1,
        total_tokens=2,
        cost_micros=42,
    )
    serialized = acc.serialized_calls()
    assert serialized == [
        {
            "model": "m",
            "prompt_tokens": 1,
            "completion_tokens": 1,
            "total_tokens": 2,
            "cost_micros": 42,
            "call_kind": "chat",
        }
    ]


# ---------------------------------------------------------------------------
# estimate_call_reserve_micros — sizing and clamping
# ---------------------------------------------------------------------------


def test_reserve_returns_floor_when_model_unknown(monkeypatch):
    """If LiteLLM doesn't know the model, ``get_model_info`` raises and the
    helper falls back to the 100-micro floor — small enough that a user with
    $0.0001 left can still send a tiny request, but non-zero so we still gate
    against an empty balance.
    """
    import litellm

    from app.services import token_quota_service

    def _raise(_name):
        raise KeyError("unknown")

    monkeypatch.setattr(litellm, "get_model_info", _raise, raising=False)

    micros = token_quota_service.estimate_call_reserve_micros(
        base_model="nonexistent-model",
        quota_reserve_tokens=4000,
    )
    assert micros == token_quota_service._QUOTA_MIN_RESERVE_MICROS
    assert micros == 100


def test_reserve_returns_floor_when_pricing_is_zero(monkeypatch):
    """LiteLLM may *return* a model with both cost-per-token fields at 0
    (pricing not yet registered). The helper must not multiply zero cost by
    tokens and end up reserving 0 — it must clamp to the floor.
    """
    import litellm

    from app.services import token_quota_service

    monkeypatch.setattr(
        litellm,
        "get_model_info",
        lambda _name: {"input_cost_per_token": 0, "output_cost_per_token": 0},
        raising=False,
    )

    micros = token_quota_service.estimate_call_reserve_micros(
        base_model="some-pending-model",
        quota_reserve_tokens=4000,
    )
    assert micros == token_quota_service._QUOTA_MIN_RESERVE_MICROS


def test_reserve_scales_with_model_cost(monkeypatch):
    """A Claude-Opus-priced model with 4000 reserve_tokens reserves
    $0.36 = 360_000 micros. Critically, this must NOT be clamped down to
    some small artificial cap — that was the bug the plan called out.
    """
    import litellm

    from app.config import config
    from app.services import token_quota_service

    monkeypatch.setattr(
        litellm,
        "get_model_info",
        lambda _name: {
            "input_cost_per_token": 15e-6,
            "output_cost_per_token": 75e-6,
        },
        raising=False,
    )
    monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_MICROS", 1_000_000, raising=False)

    micros = token_quota_service.estimate_call_reserve_micros(
        base_model="claude-3-opus",
        quota_reserve_tokens=4000,
    )
    # 4000 * (15e-6 + 75e-6) = 4000 * 90e-6 = 0.36 USD = 360_000 micros.
    assert micros == 360_000


def test_reserve_clamps_to_max_ceiling(monkeypatch):
    """A misconfigured "$1000/M" model with 4000 reserve_tokens would
    nominally compute to $4 = 4_000_000 micros. The ceiling
    ``QUOTA_MAX_RESERVE_MICROS`` must clamp that so a bad pricing entry
    can't lock the user's whole balance on one call.
    """
    import litellm

    from app.config import config
    from app.services import token_quota_service

    monkeypatch.setattr(
        litellm,
        "get_model_info",
        lambda _name: {
            "input_cost_per_token": 1e-3,
            "output_cost_per_token": 0,
        },
        raising=False,
    )
    monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_MICROS", 1_000_000, raising=False)

    micros = token_quota_service.estimate_call_reserve_micros(
        base_model="oops-misconfigured",
        quota_reserve_tokens=4000,
    )
    assert micros == 1_000_000
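

# Reference arithmetic for the floor/ceiling tests above (a sketch under
# assumptions; the real clamp lives in ``estimate_call_reserve_micros`` and
# is not reproduced verbatim here):
def _sketch_reserve_micros(
    tokens: int, in_cost: float, out_cost: float, floor: int, ceiling: int
) -> int:
    # reserve = per-call token budget times per-token USD, converted to
    # micros, then clamped between the floor and the ceiling.
    usd = tokens * (in_cost + out_cost)
    return min(max(int(usd * 1_000_000), floor), ceiling)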


def test_reserve_uses_default_when_quota_reserve_tokens_missing(monkeypatch):
    """Per-config ``quota_reserve_tokens`` is optional; when ``None`` or
    zero, the helper must fall back to the global ``QUOTA_MAX_RESERVE_PER_CALL``
    so anonymous-style configs still reserve the operator-tunable default.
    """
    import litellm

    from app.config import config
    from app.services import token_quota_service

    monkeypatch.setattr(
        litellm,
        "get_model_info",
        lambda _name: {
            "input_cost_per_token": 1e-6,
            "output_cost_per_token": 1e-6,
        },
        raising=False,
    )
    monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_PER_CALL", 2000, raising=False)
    monkeypatch.setattr(config, "QUOTA_MAX_RESERVE_MICROS", 1_000_000, raising=False)

    # 2000 * (1e-6 + 1e-6) = 4e-3 USD = 4000 micros
    assert (
        token_quota_service.estimate_call_reserve_micros(
            base_model="cheap", quota_reserve_tokens=None
        )
        == 4000
    )
    assert (
        token_quota_service.estimate_call_reserve_micros(
            base_model="cheap", quota_reserve_tokens=0
        )
        == 4000
    )


# ---------------------------------------------------------------------------
# TokenTrackingCallback — image vs chat usage shape
# ---------------------------------------------------------------------------


class _FakeImageUsage:
    """Mimics LiteLLM's ``ImageUsage`` (input_tokens / output_tokens shape)."""

    def __init__(
        self,
        input_tokens: int = 0,
        output_tokens: int = 0,
        total_tokens: int | None = None,
    ) -> None:
        self.input_tokens = input_tokens
        self.output_tokens = output_tokens
        if total_tokens is not None:
            self.total_tokens = total_tokens


class _FakeImageResponse:
    """Mimics LiteLLM's ``ImageResponse`` — same name so the callback's
    ``type(...).__name__`` probe routes to the image branch.
    """

    def __init__(self, usage: _FakeImageUsage, response_cost: float | None = None):
        self.usage = usage
        if response_cost is not None:
            self._hidden_params = {"response_cost": response_cost}


# Re-tag the helper class as ``ImageResponse`` so the callback's type-name
# probe takes the image branch. Keeping the descriptive ``_FakeImageResponse``
# name at the definition site (and re-tagging here) makes the fake's role
# explicit to readers.
_FakeImageResponse.__name__ = "ImageResponse"


class _FakeChatUsage:
    def __init__(self, prompt: int, completion: int):
        self.prompt_tokens = prompt
        self.completion_tokens = completion
        self.total_tokens = prompt + completion


class _FakeChatResponse:
    def __init__(self, usage: _FakeChatUsage):
        self.usage = usage


@pytest.mark.asyncio
async def test_callback_reads_image_usage_input_output_tokens():
    """``TokenTrackingCallback`` must read ``input_tokens``/``output_tokens``
    for ``ImageResponse`` (LiteLLM's ImageUsage shape), NOT
    ``prompt_tokens``/``completion_tokens``, which is the chat shape.
    """
    from app.services.token_tracking_service import (
        TokenTrackingCallback,
        scoped_turn,
    )

    cb = TokenTrackingCallback()
    response = _FakeImageResponse(
        usage=_FakeImageUsage(input_tokens=42, output_tokens=8, total_tokens=50),
        response_cost=0.04,  # $0.04 per image
    )

    async with scoped_turn() as acc:
        await cb.async_log_success_event(
            kwargs={"model": "openai/gpt-image-1", "response_cost": 0.04},
            response_obj=response,
            start_time=None,
            end_time=None,
        )
    assert len(acc.calls) == 1
    call = acc.calls[0]
    assert call.prompt_tokens == 42
    assert call.completion_tokens == 8
    assert call.total_tokens == 50
    # 0.04 USD = 40_000 micros
    assert call.cost_micros == 40_000
    assert call.call_kind == "image_generation"


@pytest.mark.asyncio
async def test_callback_chat_path_unchanged():
    """Chat responses must still read ``prompt_tokens``/``completion_tokens``."""
    from app.services.token_tracking_service import (
        TokenTrackingCallback,
        scoped_turn,
    )

    cb = TokenTrackingCallback()
    response = _FakeChatResponse(_FakeChatUsage(prompt=120, completion=30))

    async with scoped_turn() as acc:
        await cb.async_log_success_event(
            kwargs={
                "model": "openrouter/anthropic/claude-3-5-sonnet",
                "response_cost": 0.0036,
            },
            response_obj=response,
            start_time=None,
            end_time=None,
        )
    assert len(acc.calls) == 1
    call = acc.calls[0]
    assert call.prompt_tokens == 120
    assert call.completion_tokens == 30
    assert call.total_tokens == 150
    assert call.cost_micros == 3_600
    assert call.call_kind == "chat"
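

# Side-by-side view of the two usage shapes the tests above pin down, as an
# illustrative normalizer (assumed shape only; the real callback's probe and
# field handling live in token_tracking_service):
def _sketch_extract_tokens(usage) -> tuple[int, int]:
    if hasattr(usage, "input_tokens"):
        # Image shape (LiteLLM ImageUsage): input_tokens / output_tokens.
        return usage.input_tokens, usage.output_tokens
    # Chat shape: prompt_tokens / completion_tokens.
    return usage.prompt_tokens, usage.completion_tokens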


@pytest.mark.asyncio
async def test_callback_image_missing_response_cost_falls_back_to_zero(monkeypatch):
    """When OpenRouter omits ``usage.cost``, LiteLLM's
    ``default_image_cost_calculator`` raises. The defensive image branch in
    ``_extract_cost_usd`` must NOT call ``cost_per_token`` (which is
    chat-shaped and would raise too) — it returns 0 with a WARNING log.
    """
    import litellm

    from app.services.token_tracking_service import (
        TokenTrackingCallback,
        scoped_turn,
    )

    # Force completion_cost to raise the same way OpenRouter image-gen fails.
    def _boom(*_args, **_kwargs):
        raise ValueError("model_cost: missing entry for openrouter image model")

    monkeypatch.setattr(litellm, "completion_cost", _boom, raising=False)

    # And make sure cost_per_token is NEVER called for the image path —
    # if it were, our ``is_image=True`` branch is broken.
    cost_per_token_calls: list = []

    def _record_cost_per_token(**kwargs):
        cost_per_token_calls.append(kwargs)
        return (0.0, 0.0)

    monkeypatch.setattr(
        litellm, "cost_per_token", _record_cost_per_token, raising=False
    )

    cb = TokenTrackingCallback()
    response = _FakeImageResponse(
        usage=_FakeImageUsage(input_tokens=7, output_tokens=0)
    )

    async with scoped_turn() as acc:
        await cb.async_log_success_event(
            kwargs={"model": "openrouter/google/gemini-2.5-flash-image"},
            response_obj=response,
            start_time=None,
            end_time=None,
        )

    assert len(acc.calls) == 1
    assert acc.calls[0].cost_micros == 0
    assert acc.calls[0].call_kind == "image_generation"
    # The image branch must short-circuit before cost_per_token.
    assert cost_per_token_calls == []


# ---------------------------------------------------------------------------
# scoped_turn — ContextVar reset semantics (issue B)
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_scoped_turn_restores_outer_accumulator():
    """``scoped_turn`` must restore the previous ContextVar value on exit
    so a per-call wrapper inside an outer chat turn doesn't leak its
    accumulator outward (which would cause a double debit at chat-turn exit).
    """
    from app.services.token_tracking_service import (
        get_current_accumulator,
        scoped_turn,
        start_turn,
    )

    outer = start_turn()
    assert get_current_accumulator() is outer

    async with scoped_turn() as inner:
        assert get_current_accumulator() is inner
        assert inner is not outer
        inner.add(
            model="x",
            prompt_tokens=1,
            completion_tokens=1,
            total_tokens=2,
            cost_micros=5,
        )

    # After exit the outer accumulator is restored unchanged.
    assert get_current_accumulator() is outer
    assert outer.total_cost_micros == 0
    assert len(outer.calls) == 0
    # The inner accumulator captured the call but didn't bleed into outer.
    assert inner.total_cost_micros == 5
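

# The reset semantics under test, in sketch form (an assumption about the
# implementation style, not a copy of it; scoped_turn presumably wraps a
# ContextVar set/reset pair roughly like this):
#
#     token = _turn_accumulator.set(TurnTokenAccumulator())
#     try:
#         yield _turn_accumulator.get()
#     finally:
#         _turn_accumulator.reset(token)  # restores the outer value, or None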


@pytest.mark.asyncio
async def test_scoped_turn_resets_to_none_when_no_outer():
    """Running ``scoped_turn`` outside any chat turn (e.g. a background
    indexing job) must leave the ContextVar at ``None`` on exit so the
    next *unrelated* request starts clean.
    """
    from app.services.token_tracking_service import (
        _turn_accumulator,
        get_current_accumulator,
        scoped_turn,
    )

    # The ContextVar default is None in a fresh, isolated test context. We
    # simulate "no outer" explicitly to be robust against test ordering.
    token = _turn_accumulator.set(None)
    try:
        assert get_current_accumulator() is None
        async with scoped_turn() as acc:
            assert get_current_accumulator() is acc
        assert get_current_accumulator() is None
    finally:
        _turn_accumulator.reset(token)