SurfSense/surfsense_backend/tests/unit/services/test_billable_call.py
2026-05-02 14:34:23 -07:00

432 lines
14 KiB
Python

"""Unit tests for the ``billable_call`` async context manager.
Covers the per-call premium-credit lifecycle for image generation and
vision LLM extraction:
* Free configs bypass reserve/finalize but still write an audit row.
* Premium reserve denial raises ``QuotaInsufficientError`` (HTTP 402 in the
route layer).
* Successful premium calls reserve, yield the accumulator, then finalize
with the LiteLLM-reported actual cost — and write an audit row.
* Failed premium calls release the reservation so credit isn't leaked.
* All quota DB ops happen inside their OWN ``shielded_async_session``,
isolating them from the caller's transaction (issue A).
"""
from __future__ import annotations
import contextlib
from typing import Any
from uuid import uuid4
import pytest
pytestmark = pytest.mark.unit
# ---------------------------------------------------------------------------
# Fakes
# ---------------------------------------------------------------------------
class _FakeQuotaResult:
    """Plain value holder standing in for a quota-service result.

    Exposes the four attributes the tests set and read back: ``allowed``,
    ``used``, ``limit`` and ``remaining``. All constructor arguments are
    keyword-only; only ``allowed`` is required.
    """

    def __init__(
        self,
        *,
        allowed: bool,
        used: int = 0,
        limit: int = 5_000_000,
        remaining: int = 5_000_000,
    ) -> None:
        self.allowed, self.used = allowed, used
        self.limit, self.remaining = limit, remaining
class _FakeSession:
    """Bare-bones AsyncSession stand-in that remembers adds and commits."""

    def __init__(self) -> None:
        # Objects handed to ``add`` in call order.
        self.added: list[Any] = []
        # Flipped to True the first time ``commit`` is awaited.
        self.committed = False

    def add(self, obj: Any) -> None:
        """Remember ``obj`` so a test can inspect what was added."""
        self.added.append(obj)

    async def commit(self) -> None:
        """Mark this session as committed; nothing is persisted."""
        self.committed = True

    async def close(self) -> None:
        """Do nothing; present only so callers can await ``close()``."""
        return None
@contextlib.asynccontextmanager
async def _fake_shielded_session():
    """Async context manager that yields a brand-new ``_FakeSession``.

    Every entry creates a fresh session and appends it to the module-level
    ``_SESSIONS_USED`` list before yielding it, so a test can see how many
    sessions were opened and inspect each one afterwards.
    """
    session = _FakeSession()
    _SESSIONS_USED.append(session)
    yield session
# Every fake session handed out by ``_fake_shielded_session``, in creation
# order. ``_patch_isolation_layer`` clears this list each time it is called.
_SESSIONS_USED: list[_FakeSession] = []
def _patch_isolation_layer(monkeypatch, *, reserve_result, finalize_result=None):
    """Swap the quota, session and audit helpers for recording fakes.

    Args:
        monkeypatch: pytest's ``monkeypatch`` fixture.
        reserve_result: Object the fake ``premium_reserve`` returns.
        finalize_result: Object the fake ``premium_finalize`` returns; when
            falsy, a ``_FakeQuotaResult(allowed=True)`` is returned instead.

    Returns:
        A dict with the keys ``"reserve"``, ``"finalize"``, ``"release"`` and
        ``"record"``. Each value is a list that receives one dict per call to
        the corresponding fake.
    """
    _SESSIONS_USED.clear()

    # One call log per patched helper; the fakes below append to these.
    spies: dict[str, list[dict[str, Any]]] = {
        "reserve": [],
        "finalize": [],
        "release": [],
        "record": [],
    }

    async def _fake_reserve(*, db_session, user_id, request_id, reserve_micros):
        entry = {
            "user_id": user_id,
            "reserve_micros": reserve_micros,
            "request_id": request_id,
        }
        spies["reserve"].append(entry)
        return reserve_result

    async def _fake_finalize(
        *, db_session, user_id, request_id, actual_micros, reserved_micros
    ):
        entry = {
            "user_id": user_id,
            "actual_micros": actual_micros,
            "reserved_micros": reserved_micros,
        }
        spies["finalize"].append(entry)
        if finalize_result:
            return finalize_result
        return _FakeQuotaResult(allowed=True)

    async def _fake_release(*, db_session, user_id, reserved_micros):
        entry = {"user_id": user_id, "reserved_micros": reserved_micros}
        spies["release"].append(entry)

    async def _fake_record(session, **kwargs):
        spies["record"].append(kwargs)
        return object()

    # (dotted target, replacement) pairs, all patched with raising=False.
    replacements = (
        ("app.services.billable_calls.TokenQuotaService.premium_reserve", _fake_reserve),
        ("app.services.billable_calls.TokenQuotaService.premium_finalize", _fake_finalize),
        ("app.services.billable_calls.TokenQuotaService.premium_release", _fake_release),
        ("app.services.billable_calls.shielded_async_session", _fake_shielded_session),
        ("app.services.billable_calls.record_token_usage", _fake_record),
    )
    for target, fake in replacements:
        monkeypatch.setattr(target, fake, raising=False)

    return spies
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_free_path_skips_reserve_but_writes_audit_row(monkeypatch):
    """A free-tier call makes no quota calls but records one usage row."""
    from app.services.billable_calls import billable_call

    calls = _patch_isolation_layer(
        monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
    )
    caller_id = uuid4()

    async with billable_call(
        user_id=caller_id,
        search_space_id=42,
        billing_tier="free",
        base_model="openai/gpt-image-1",
        usage_type="image_generation",
    ) as accumulator:
        # In real runs the LiteLLM callback feeds the accumulator; here the
        # captured cost is injected by hand.
        accumulator.add(
            model="openai/gpt-image-1",
            prompt_tokens=0,
            completion_tokens=0,
            total_tokens=0,
            cost_micros=37_000,
            call_kind="image_generation",
        )

    # None of the quota operations may have been invoked.
    for quota_op in ("reserve", "finalize", "release"):
        assert calls[quota_op] == []

    # The free path still audits: exactly one recorded row.
    recorded = calls["record"]
    assert len(recorded) == 1
    assert recorded[0]["usage_type"] == "image_generation"
    assert recorded[0]["cost_micros"] == 37_000
@pytest.mark.asyncio
async def test_premium_reserve_denied_raises_quota_insufficient(monkeypatch):
    """A denied premium reserve raises ``QuotaInsufficientError`` up front.

    The error must carry the quota figures from the reserve result, and no
    finalize, release or audit call may follow.
    """
    from app.services.billable_calls import (
        QuotaInsufficientError,
        billable_call,
    )

    denied = _FakeQuotaResult(
        allowed=False, used=5_000_000, limit=5_000_000, remaining=0
    )
    calls = _patch_isolation_layer(monkeypatch, reserve_result=denied)
    call_kwargs = {
        "user_id": uuid4(),
        "search_space_id": 42,
        "billing_tier": "premium",
        "base_model": "openai/gpt-image-1",
        "quota_reserve_micros_override": 50_000,
        "usage_type": "image_generation",
    }

    with pytest.raises(QuotaInsufficientError) as raised:
        async with billable_call(**call_kwargs):
            pytest.fail("body should not run when reserve is denied")

    error = raised.value
    assert error.usage_type == "image_generation"
    assert error.used_micros == 5_000_000
    assert error.limit_micros == 5_000_000
    assert error.remaining_micros == 0

    # The reserve was attempted once, but a denied reservation never held
    # credit, so there is nothing to finalize or release.
    assert len(calls["reserve"]) == 1
    assert calls["finalize"] == []
    assert calls["release"] == []

    # No work happened, so no audit row either.
    assert calls["record"] == []
@pytest.mark.asyncio
async def test_premium_success_finalizes_with_actual_cost(monkeypatch):
    """A successful premium call reserves, finalizes at actual cost, and audits.

    Also checks that at least three fake sessions were opened during the call
    and that at least one of them was committed.
    """
    from app.services.billable_calls import billable_call

    spies = _patch_isolation_layer(
        monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
    )
    user_id = uuid4()

    async with billable_call(
        user_id=user_id,
        search_space_id=42,
        billing_tier="premium",
        base_model="openai/gpt-image-1",
        quota_reserve_micros_override=50_000,
        usage_type="image_generation",
    ) as acc:
        # LiteLLM callback would normally fill this — simulate $0.04 image.
        acc.add(
            model="openai/gpt-image-1",
            prompt_tokens=0,
            completion_tokens=0,
            total_tokens=0,
            cost_micros=40_000,
            call_kind="image_generation",
        )

    assert len(spies["reserve"]) == 1
    assert spies["reserve"][0]["reserve_micros"] == 50_000
    assert len(spies["finalize"]) == 1
    assert spies["finalize"][0]["actual_micros"] == 40_000
    assert spies["finalize"][0]["reserved_micros"] == 50_000
    assert spies["release"] == []

    # And audit row written with the actual debited cost. Check for presence
    # first so a missing row fails as an assertion, not an IndexError.
    assert spies["record"], "expected an audit row for the premium call"
    assert spies["record"][0]["cost_micros"] == 40_000

    # Each quota op opened its OWN session — proves session isolation.
    assert len(_SESSIONS_USED) >= 3

    # The reserve/finalize fakes installed by ``_patch_isolation_layer`` never
    # call ``commit`` on the fake session, so not every session is committed.
    # Only require that at least one was (presumably the audit-row session —
    # confirm against ``billable_call``).
    assert any(s.committed for s in _SESSIONS_USED)
@pytest.mark.asyncio
async def test_premium_failure_releases_reservation(monkeypatch):
    """An exception in the body of a premium call releases the reservation.

    The exception must propagate unchanged, finalize must not run, and the
    release must carry the originally reserved amount.
    """
    from app.services.billable_calls import billable_call

    class _ProviderError(Exception):
        """Local stand-in for an upstream provider failure."""

    calls = _patch_isolation_layer(
        monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
    )
    caller_id = uuid4()

    with pytest.raises(_ProviderError):
        async with billable_call(
            user_id=caller_id,
            search_space_id=42,
            billing_tier="premium",
            base_model="openai/gpt-image-1",
            quota_reserve_micros_override=50_000,
            usage_type="image_generation",
        ):
            raise _ProviderError("OpenRouter 503")

    assert len(calls["reserve"]) == 1
    assert calls["finalize"] == []

    # On failure the held reservation is handed back exactly once.
    released = calls["release"]
    assert len(released) == 1
    assert released[0]["reserved_micros"] == 50_000
@pytest.mark.asyncio
async def test_premium_uses_estimator_when_no_micros_override(monkeypatch):
    """Without ``quota_reserve_micros_override`` the estimator sets the reserve.

    The call omits the override, so the reserve amount must come from
    ``estimate_call_reserve_micros(base_model, quota_reserve_tokens)``. The
    original author notes that vision LLM calls (token-priced models) take
    this path.
    """
    from app.services.billable_calls import billable_call

    calls = _patch_isolation_layer(
        monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
    )

    estimator_calls: list[dict[str, Any]] = []

    def _fake_estimate(*, base_model, quota_reserve_tokens):
        seen = {"base_model": base_model, "quota_reserve_tokens": quota_reserve_tokens}
        estimator_calls.append(seen)
        return 12_345

    monkeypatch.setattr(
        "app.services.billable_calls.estimate_call_reserve_micros",
        _fake_estimate,
        raising=False,
    )

    caller_id = uuid4()
    async with billable_call(
        user_id=caller_id,
        search_space_id=1,
        billing_tier="premium",
        base_model="openai/gpt-4o",
        quota_reserve_tokens=4000,
        usage_type="vision_extraction",
    ):
        pass

    # The estimator was consulted exactly once with the call's arguments...
    expected_estimator_calls = [
        {"base_model": "openai/gpt-4o", "quota_reserve_tokens": 4000}
    ]
    assert estimator_calls == expected_estimator_calls
    # ...and its return value became the reserved amount.
    assert calls["reserve"][0]["reserve_micros"] == 12_345
# ---------------------------------------------------------------------------
# Podcast / video-presentation usage_type coverage
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_free_podcast_path_audits_with_podcast_usage_type(monkeypatch):
    """Free podcast runs skip the quota calls but still record a usage row.

    The recorded ``TokenUsage`` row must be tagged
    ``usage_type='podcast_generation'`` and carry the thread id, search space
    id and call details through unchanged, giving audit coverage of free-tier
    agent runs.
    """
    from app.services.billable_calls import billable_call

    calls = _patch_isolation_layer(
        monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
    )
    caller_id = uuid4()

    async with billable_call(
        user_id=caller_id,
        search_space_id=42,
        billing_tier="free",
        base_model="openrouter/some-free-model",
        quota_reserve_micros_override=200_000,
        usage_type="podcast_generation",
        thread_id=99,
        call_details={"podcast_id": 7, "title": "Test Podcast"},
    ) as accumulator:
        # One zero-cost chat entry standing in for the transcript LLM usage.
        accumulator.add(
            model="openrouter/some-free-model",
            prompt_tokens=1500,
            completion_tokens=8000,
            total_tokens=9500,
            cost_micros=0,
            call_kind="chat",
        )

    # Free tier: no quota operation of any kind.
    for quota_op in ("reserve", "finalize", "release"):
        assert calls[quota_op] == []

    assert len(calls["record"]) == 1
    audit_row = calls["record"][0]
    assert audit_row["usage_type"] == "podcast_generation"
    assert audit_row["thread_id"] == 99
    assert audit_row["search_space_id"] == 42
    assert audit_row["call_details"] == {"podcast_id": 7, "title": "Test Podcast"}
@pytest.mark.asyncio
async def test_premium_video_denial_raises_quota_insufficient(monkeypatch):
    """A denied premium video-presentation reserve raises before the body runs.

    ``QuotaInsufficientError`` must be raised, the body of the ``async with``
    must never execute, and no audit row may be written (no work happened).
    """
    from app.services.billable_calls import (
        QuotaInsufficientError,
        billable_call,
    )

    denied = _FakeQuotaResult(
        allowed=False, used=4_500_000, limit=5_000_000, remaining=500_000
    )
    calls = _patch_isolation_layer(monkeypatch, reserve_result=denied)
    call_kwargs = {
        "user_id": uuid4(),
        "search_space_id": 42,
        "billing_tier": "premium",
        "base_model": "gpt-5.4",
        "quota_reserve_micros_override": 1_000_000,
        "usage_type": "video_presentation_generation",
        "thread_id": 99,
        "call_details": {"video_presentation_id": 12, "title": "Test Video"},
    }

    with pytest.raises(QuotaInsufficientError) as raised:
        async with billable_call(**call_kwargs):
            pytest.fail("body should not run when reserve is denied")

    error = raised.value
    assert error.usage_type == "video_presentation_generation"
    assert error.remaining_micros == 500_000

    # The reserve was attempted with the override amount; nothing followed it.
    assert calls["reserve"][0]["reserve_micros"] == 1_000_000
    for untouched in ("finalize", "release", "record"):
        assert calls[untouched] == []