mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-03 21:02:40 +02:00
432 lines
14 KiB
Python
432 lines
14 KiB
Python
"""Unit tests for the ``billable_call`` async context manager.

Covers the per-call premium-credit lifecycle for image generation, vision
LLM extraction, podcast generation and video-presentation generation:

* Free configs bypass reserve/finalize but still write an audit row.
* Premium reserve denial raises ``QuotaInsufficientError`` (HTTP 402 in the
  route layer) before the body runs, and writes no audit row.
* Successful premium calls reserve, yield the accumulator, then finalize
  with the LiteLLM-reported actual cost — and write an audit row.
* Failed premium calls release the reservation so credit isn't leaked.
* Without ``quota_reserve_micros_override`` the reservation size comes from
  ``estimate_call_reserve_micros``.
* All quota DB ops happen inside their OWN ``shielded_async_session``,
  isolating them from the caller's transaction (issue A).

The quota service, the session factory and the audit writer used by
``app.services.billable_calls`` are replaced with in-memory fakes by
``_patch_isolation_layer``, so no real database is touched.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import contextlib
|
|
from typing import Any
|
|
from uuid import uuid4
|
|
|
|
import pytest
|
|
|
|
# Apply the ``unit`` marker to every test in this module (select with
# ``pytest -m unit``).
pytestmark = [pytest.mark.unit]
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Fakes
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class _FakeQuotaResult:
    """Canned quota-check result handed back by the fake premium helpers.

    Holds four plain attributes: ``allowed`` plus the ``used`` / ``limit`` /
    ``remaining`` balances. The balances default to an untouched
    5_000_000 allowance.
    """

    def __init__(
        self,
        *,
        allowed: bool,
        used: int = 0,
        limit: int = 5_000_000,
        remaining: int = 5_000_000,
    ) -> None:
        # Keyword-only construction keeps call sites self-describing.
        self.allowed, self.used, self.limit, self.remaining = (
            allowed,
            used,
            limit,
            remaining,
        )
|
|
|
|
|
|
class _FakeSession:
    """In-memory stand-in for an ``AsyncSession``.

    Remembers every object passed to ``add`` and flips ``committed`` once
    ``commit`` is awaited, so tests can assert on both.
    """

    def __init__(self) -> None:
        self.added: list[Any] = []
        self.committed = False

    def add(self, obj: Any) -> None:
        # Record instead of persisting; extend the same list in place.
        self.added += [obj]

    async def commit(self) -> None:
        self.committed = True

    async def close(self) -> None:
        # Nothing to release for an in-memory fake.
        return None
|
|
|
|
|
|
# Every fake session handed out by ``_fake_shielded_session``, in creation
# order. ``_patch_isolation_layer`` resets it.
_SESSIONS_USED: list[_FakeSession] = []


@contextlib.asynccontextmanager
async def _fake_shielded_session():
    """Drop-in for ``shielded_async_session`` that yields a brand-new fake.

    A fresh ``_FakeSession`` is created on every entry and appended to
    ``_SESSIONS_USED``, so tests can count how many independent sessions the
    code under test opened.
    """
    session = _FakeSession()
    _SESSIONS_USED.append(session)
    yield session
|
|
|
|
|
|
def _patch_isolation_layer(monkeypatch, *, reserve_result, finalize_result=None):
    """Replace the DB-touching collaborators of ``billable_calls`` with spies.

    Patches, on ``app.services.billable_calls``:

    * ``TokenQuotaService.premium_reserve``  -> returns ``reserve_result``;
    * ``TokenQuotaService.premium_finalize`` -> returns ``finalize_result``,
      or an allowed ``_FakeQuotaResult`` when ``finalize_result`` is None;
    * ``TokenQuotaService.premium_release``  -> returns None;
    * ``shielded_async_session``            -> ``_fake_shielded_session``;
    * ``record_token_usage``                -> returns a dummy ``object()``.

    Also clears ``_SESSIONS_USED`` so session counts are per test.

    Args:
        monkeypatch: pytest ``monkeypatch`` fixture used for every patch.
        reserve_result: Object returned by every reserve call.
        finalize_result: Object returned by every finalize call; ``None``
            means "return an allowed ``_FakeQuotaResult``".

    Returns:
        Dict of call logs keyed by ``"reserve"``, ``"finalize"``,
        ``"release"`` and ``"record"``. The three quota logs hold one dict
        per call with the keyword arguments the fake received, including
        the ``db_session`` it was handed so tests can check session
        isolation. ``"record"`` holds the raw kwargs passed to
        ``record_token_usage``.
    """
    _SESSIONS_USED.clear()
    reserve_calls: list[dict[str, Any]] = []
    finalize_calls: list[dict[str, Any]] = []
    release_calls: list[dict[str, Any]] = []
    record_calls: list[dict[str, Any]] = []

    async def _fake_reserve(*, db_session, user_id, request_id, reserve_micros):
        reserve_calls.append(
            {
                "user_id": user_id,
                "reserve_micros": reserve_micros,
                "request_id": request_id,
                "db_session": db_session,
            }
        )
        return reserve_result

    async def _fake_finalize(
        *, db_session, user_id, request_id, actual_micros, reserved_micros
    ):
        finalize_calls.append(
            {
                "user_id": user_id,
                "actual_micros": actual_micros,
                "reserved_micros": reserved_micros,
                # Same key as the reserve log, so a test can pair them up.
                "request_id": request_id,
                "db_session": db_session,
            }
        )
        # Explicit ``is None`` (not ``or``): a caller-supplied result must be
        # returned as-is even if it happens to be falsy.
        if finalize_result is None:
            return _FakeQuotaResult(allowed=True)
        return finalize_result

    async def _fake_release(*, db_session, user_id, reserved_micros):
        release_calls.append(
            {
                "user_id": user_id,
                "reserved_micros": reserved_micros,
                "db_session": db_session,
            }
        )

    async def _fake_record(session, **kwargs):
        record_calls.append(kwargs)
        return object()

    module = "app.services.billable_calls"
    replacements = {
        f"{module}.TokenQuotaService.premium_reserve": _fake_reserve,
        f"{module}.TokenQuotaService.premium_finalize": _fake_finalize,
        f"{module}.TokenQuotaService.premium_release": _fake_release,
        f"{module}.shielded_async_session": _fake_shielded_session,
        f"{module}.record_token_usage": _fake_record,
    }
    for dotted_path, replacement in replacements.items():
        # raising=False: create the attribute if it does not exist.
        # NOTE(review): this also hides a rename in billable_calls (the patch
        # would land on an unused name); consider raising=True.
        monkeypatch.setattr(dotted_path, replacement, raising=False)

    return {
        "reserve": reserve_calls,
        "finalize": finalize_calls,
        "release": release_calls,
        "record": record_calls,
    }
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_free_path_skips_reserve_but_writes_audit_row(monkeypatch):
    """Free tier: no reserve/finalize/release, yet one audit row with the cost."""
    from app.services.billable_calls import billable_call

    spies = _patch_isolation_layer(
        monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
    )
    image_model = "openai/gpt-image-1"

    async with billable_call(
        user_id=uuid4(),
        search_space_id=42,
        billing_tier="free",
        base_model=image_model,
        usage_type="image_generation",
    ) as accumulator:
        # In production the LiteLLM callback feeds the accumulator; here one
        # captured image-generation cost is injected by hand.
        accumulator.add(
            model=image_model,
            prompt_tokens=0,
            completion_tokens=0,
            total_tokens=0,
            cost_micros=37_000,
            call_kind="image_generation",
        )

    # No quota traffic at all on the free tier ...
    assert spies["reserve"] == spies["finalize"] == spies["release"] == []
    # ... but exactly one audit row is still written.
    assert len(spies["record"]) == 1
    audit_row = spies["record"][0]
    assert audit_row["usage_type"] == "image_generation"
    assert audit_row["cost_micros"] == 37_000
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_premium_reserve_denied_raises_quota_insufficient(monkeypatch):
    """A denied premium reservation raises before the body and leaves no trace."""
    from app.services.billable_calls import (
        QuotaInsufficientError,
        billable_call,
    )

    exhausted = _FakeQuotaResult(
        allowed=False, used=5_000_000, limit=5_000_000, remaining=0
    )
    spies = _patch_isolation_layer(monkeypatch, reserve_result=exhausted)

    with pytest.raises(QuotaInsufficientError) as exc_info:
        async with billable_call(
            user_id=uuid4(),
            search_space_id=42,
            billing_tier="premium",
            base_model="openai/gpt-image-1",
            quota_reserve_micros_override=50_000,
            usage_type="image_generation",
        ):
            pytest.fail("body should not run when reserve is denied")

    denied = exc_info.value
    assert (
        denied.usage_type,
        denied.used_micros,
        denied.limit_micros,
        denied.remaining_micros,
    ) == ("image_generation", 5_000_000, 5_000_000, 0)

    # The reserve was attempted exactly once; since it never held credit
    # there is nothing to finalize or release ...
    assert len(spies["reserve"]) == 1
    assert spies["finalize"] == spies["release"] == []
    # ... and no work happened, so no audit row either.
    assert spies["record"] == []
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_premium_success_finalizes_with_actual_cost(monkeypatch):
    """A successful premium call reserves, finalizes with the real cost, audits.

    The reservation is sized by ``quota_reserve_micros_override`` (50_000),
    while finalize must be given the cost actually captured in the
    accumulator (40_000) alongside the reserved amount.
    """
    from app.services.billable_calls import billable_call

    spies = _patch_isolation_layer(
        monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
    )
    user_id = uuid4()

    async with billable_call(
        user_id=user_id,
        search_space_id=42,
        billing_tier="premium",
        base_model="openai/gpt-image-1",
        quota_reserve_micros_override=50_000,
        usage_type="image_generation",
    ) as acc:
        # LiteLLM callback would normally fill this — simulate $0.04 image.
        acc.add(
            model="openai/gpt-image-1",
            prompt_tokens=0,
            completion_tokens=0,
            total_tokens=0,
            cost_micros=40_000,
            call_kind="image_generation",
        )

    assert len(spies["reserve"]) == 1
    assert spies["reserve"][0]["reserve_micros"] == 50_000
    assert len(spies["finalize"]) == 1
    assert spies["finalize"][0]["actual_micros"] == 40_000
    assert spies["finalize"][0]["reserved_micros"] == 50_000
    assert spies["release"] == []

    # Exactly one audit row, written with the actual debited cost.
    assert len(spies["record"]) == 1
    assert spies["record"][0]["cost_micros"] == 40_000

    # Reserve, finalize and the audit write each open their OWN session,
    # so at least three fake sessions must have been handed out.
    assert len(_SESSIONS_USED) >= 3
    # The stubbed TokenQuotaService methods never commit their fake session;
    # only the audit write is expected to, so require at least one commit.
    assert any(s.committed for s in _SESSIONS_USED)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_premium_failure_releases_reservation(monkeypatch):
    """If the wrapped provider call raises, the held reservation is released.

    The error must propagate to the caller, finalize must not run, and the
    full reserved amount is handed back via ``premium_release``.
    """
    from app.services.billable_calls import billable_call

    class _ProviderError(Exception):
        """Stand-in for an upstream provider failure."""

    spies = _patch_isolation_layer(
        monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
    )

    with pytest.raises(_ProviderError):
        async with billable_call(
            user_id=uuid4(),
            search_space_id=42,
            billing_tier="premium",
            base_model="openai/gpt-image-1",
            quota_reserve_micros_override=50_000,
            usage_type="image_generation",
        ):
            raise _ProviderError("OpenRouter 503")

    reserved, finalized, released = (
        spies[key] for key in ("reserve", "finalize", "release")
    )
    assert len(reserved) == 1
    assert finalized == []
    # Failure path: exactly one release, for the whole 50_000 reservation.
    assert [call["reserved_micros"] for call in released] == [50_000]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_premium_uses_estimator_when_no_micros_override(monkeypatch):
    """Without ``quota_reserve_micros_override`` the reserve size is estimated.

    ``billable_call`` must fall back to
    ``estimate_call_reserve_micros(base_model, quota_reserve_tokens)`` and
    reserve exactly the value it returns. Vision LLM calls take this path
    (token-priced models).
    """
    from app.services.billable_calls import billable_call

    spies = _patch_isolation_layer(
        monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
    )

    estimated_micros = 12_345
    estimator_calls: list[dict[str, Any]] = []

    def _fake_estimate(*, base_model, quota_reserve_tokens):
        # Log the arguments, then return a fixed, recognisable estimate.
        estimator_calls.append(
            {"base_model": base_model, "quota_reserve_tokens": quota_reserve_tokens}
        )
        return estimated_micros

    monkeypatch.setattr(
        "app.services.billable_calls.estimate_call_reserve_micros",
        _fake_estimate,
        raising=False,
    )

    async with billable_call(
        user_id=uuid4(),
        search_space_id=1,
        billing_tier="premium",
        base_model="openai/gpt-4o",
        quota_reserve_tokens=4000,
        usage_type="vision_extraction",
    ):
        pass

    assert estimator_calls == [
        {"base_model": "openai/gpt-4o", "quota_reserve_tokens": 4000}
    ]
    assert spies["reserve"][0]["reserve_micros"] == estimated_micros
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Podcast / video-presentation usage_type coverage
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_free_podcast_path_audits_with_podcast_usage_type(monkeypatch):
    """Free podcast configs must skip reserve/finalize but still emit a
    ``TokenUsage`` row tagged ``usage_type='podcast_generation'`` so we
    have full audit coverage of free-tier agent runs.

    The audit row must carry the caller's ``thread_id``, ``search_space_id``
    and ``call_details`` as given. A ``quota_reserve_micros_override`` is
    passed, yet the free tier must still not touch the quota.
    """
    from app.services.billable_calls import billable_call

    spies = _patch_isolation_layer(
        monkeypatch, reserve_result=_FakeQuotaResult(allowed=True)
    )
    user_id = uuid4()

    async with billable_call(
        user_id=user_id,
        search_space_id=42,
        billing_tier="free",
        base_model="openrouter/some-free-model",
        quota_reserve_micros_override=200_000,
        usage_type="podcast_generation",
        thread_id=99,
        call_details={"podcast_id": 7, "title": "Test Podcast"},
    ) as acc:
        # One transcript LLM call captured by the accumulator (free model,
        # so zero cost).
        acc.add(
            model="openrouter/some-free-model",
            prompt_tokens=1500,
            completion_tokens=8000,
            total_tokens=9500,
            cost_micros=0,
            call_kind="chat",
        )

    assert spies["reserve"] == []
    assert spies["finalize"] == []
    assert spies["release"] == []

    assert len(spies["record"]) == 1
    row = spies["record"][0]
    assert row["usage_type"] == "podcast_generation"
    assert row["thread_id"] == 99
    assert row["search_space_id"] == 42
    assert row["call_details"] == {"podcast_id": 7, "title": "Test Podcast"}
    # Same contract as the free image path: the audit row carries the
    # accumulated cost, here zero.
    assert row["cost_micros"] == 0
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_premium_video_denial_raises_quota_insufficient(monkeypatch):
    """A denied reservation for a premium video-presentation run must raise
    ``QuotaInsufficientError`` *before* the graph runs, and must not emit
    an audit row (no work happened)."""
    from app.services.billable_calls import (
        QuotaInsufficientError,
        billable_call,
    )

    nearly_exhausted = _FakeQuotaResult(
        allowed=False, used=4_500_000, limit=5_000_000, remaining=500_000
    )
    spies = _patch_isolation_layer(monkeypatch, reserve_result=nearly_exhausted)

    with pytest.raises(QuotaInsufficientError) as exc_info:
        async with billable_call(
            user_id=uuid4(),
            search_space_id=42,
            billing_tier="premium",
            base_model="gpt-5.4",
            quota_reserve_micros_override=1_000_000,
            usage_type="video_presentation_generation",
            thread_id=99,
            call_details={"video_presentation_id": 12, "title": "Test Video"},
        ):
            pytest.fail("body should not run when reserve is denied")

    denied = exc_info.value
    assert denied.usage_type == "video_presentation_generation"
    assert denied.remaining_micros == 500_000
    # The reserve was attempted with the full 1_000_000 override ...
    assert spies["reserve"][0]["reserve_micros"] == 1_000_000
    # ... and nothing downstream of the refused reservation ran.
    assert spies["finalize"] == spies["release"] == spies["record"] == []
|