""" Per-call billable wrapper for image generation, vision LLM extraction, and any other short-lived premium operation that must charge against the user's shared premium credit pool. The ``billable_call`` async context manager encapsulates the standard "reserve → execute → finalize / release → record audit row" lifecycle in a single primitive so callers (the image-generation REST route and the vision-LLM wrapper used during indexing) don't have to re-implement it. KEY DESIGN POINTS (issue A, B): 1. **Session isolation.** ``billable_call`` takes *no* ``db_session`` argument. All ``TokenQuotaService.premium_*`` calls and the audit-row insert each run inside their own ``shielded_async_session()``. This guarantees that a quota commit/rollback can never accidentally flush or roll back rows the caller has staged in the request's main session (e.g. a freshly-created ``ImageGeneration`` row). 2. **ContextVar safety.** The accumulator is scoped via :func:`scoped_turn` (which uses ``ContextVar.reset(token)``), so a nested ``billable_call`` inside an outer chat turn cannot corrupt the chat turn's accumulator. 3. **Free configs are still audited.** Free calls bypass the reserve / finalize dance entirely but still record a ``TokenUsage`` audit row with the LiteLLM-reported ``cost_micros``. This keeps the cost-attribution pipeline complete for analytics even when nothing is debited. 4. **Quota denial raises ``QuotaInsufficientError``.** The route handler is responsible for translating that into HTTP 402. We *do not* catch the denial inside ``billable_call`` — letting it propagate also prevents the image-generation route from creating an ``ImageGeneration`` row for a request that never actually ran. """ from __future__ import annotations import logging from collections.abc import AsyncIterator from contextlib import asynccontextmanager from typing import Any from uuid import UUID, uuid4 from sqlalchemy.ext.asyncio import AsyncSession from app.config import config from app.db import shielded_async_session from app.services.token_quota_service import ( TokenQuotaService, estimate_call_reserve_micros, ) from app.services.token_tracking_service import ( TurnTokenAccumulator, record_token_usage, scoped_turn, ) logger = logging.getLogger(__name__) class QuotaInsufficientError(Exception): """Raised when ``TokenQuotaService.premium_reserve`` denies a billable call because the user has exhausted their premium credit pool. The route handler should catch this and return HTTP 402 Payment Required (or the equivalent for the surface area). Outside of the HTTP layer (e.g. the ``QuotaCheckedVisionLLM`` wrapper used during indexing) callers may catch this and degrade gracefully — e.g. fall back to OCR when vision is unavailable. """ def __init__( self, *, usage_type: str, used_micros: int, limit_micros: int, remaining_micros: int, ) -> None: self.usage_type = usage_type self.used_micros = used_micros self.limit_micros = limit_micros self.remaining_micros = remaining_micros super().__init__( f"Premium credit exhausted for {usage_type}: " f"used={used_micros} limit={limit_micros} remaining={remaining_micros} (micro-USD)" ) @asynccontextmanager async def billable_call( *, user_id: UUID, search_space_id: int, billing_tier: str, base_model: str, quota_reserve_tokens: int | None = None, quota_reserve_micros_override: int | None = None, usage_type: str, thread_id: int | None = None, message_id: int | None = None, call_details: dict[str, Any] | None = None, ) -> AsyncIterator[TurnTokenAccumulator]: """Wrap a single billable LLM/image call. Args: user_id: Owner of the credit pool to debit. For vision-LLM during indexing this is the *search-space owner* (issue M), not the triggering user. search_space_id: Required — recorded on the ``TokenUsage`` audit row. billing_tier: ``"premium"`` debits; anything else (``"free"``) skips the reserve/finalize dance but still records an audit row with the captured cost. base_model: Used by :func:`estimate_call_reserve_micros` to compute a worst-case reservation from LiteLLM's pricing table. quota_reserve_tokens: Optional per-config override for the chat-style reserve estimator (vision LLM uses this). quota_reserve_micros_override: Optional flat micro-USD reservation (image generation uses this — its cost shape is per-image, not per-token). usage_type: ``"image_generation"`` / ``"vision_extraction"`` / etc. Recorded on the ``TokenUsage`` row. thread_id, message_id: Optional FK columns on ``TokenUsage``. call_details: Optional per-call metadata (model name, parameters) forwarded to ``record_token_usage``. Yields: The ``TurnTokenAccumulator`` scoped to this call. The caller invokes the underlying LLM/image API while inside the ``async with``; the ``TokenTrackingCallback`` populates the accumulator automatically. Raises: QuotaInsufficientError: when premium and ``premium_reserve`` denies. """ is_premium = billing_tier == "premium" async with scoped_turn() as acc: # ---------- Free path: just audit ------------------------------- if not is_premium: try: yield acc finally: # Always audit, even on exception, so we capture cost when # provider returns successfully but the caller raises later. try: async with shielded_async_session() as audit_session: await record_token_usage( audit_session, usage_type=usage_type, search_space_id=search_space_id, user_id=user_id, prompt_tokens=acc.total_prompt_tokens, completion_tokens=acc.total_completion_tokens, total_tokens=acc.grand_total, cost_micros=acc.total_cost_micros, model_breakdown=acc.per_message_summary(), call_details=call_details, thread_id=thread_id, message_id=message_id, ) await audit_session.commit() except Exception: logger.exception( "[billable_call] free-path audit insert failed for " "usage_type=%s user_id=%s", usage_type, user_id, ) return # ---------- Premium path: reserve → execute → finalize ---------- if quota_reserve_micros_override is not None: reserve_micros = max(1, int(quota_reserve_micros_override)) else: reserve_micros = estimate_call_reserve_micros( base_model=base_model or "", quota_reserve_tokens=quota_reserve_tokens, ) request_id = str(uuid4()) async with shielded_async_session() as quota_session: reserve_result = await TokenQuotaService.premium_reserve( db_session=quota_session, user_id=user_id, request_id=request_id, reserve_micros=reserve_micros, ) if not reserve_result.allowed: logger.info( "[billable_call] reserve DENIED user=%s usage_type=%s " "reserve=%d used=%d limit=%d remaining=%d", user_id, usage_type, reserve_micros, reserve_result.used, reserve_result.limit, reserve_result.remaining, ) raise QuotaInsufficientError( usage_type=usage_type, used_micros=reserve_result.used, limit_micros=reserve_result.limit, remaining_micros=reserve_result.remaining, ) logger.info( "[billable_call] reserve OK user=%s usage_type=%s reserve_micros=%d " "(remaining=%d)", user_id, usage_type, reserve_micros, reserve_result.remaining, ) try: yield acc except BaseException: # Release on any failure (including QuotaInsufficientError raised # from a downstream call, asyncio cancellation, etc.). We use # BaseException so cancellation also releases. try: async with shielded_async_session() as quota_session: await TokenQuotaService.premium_release( db_session=quota_session, user_id=user_id, reserved_micros=reserve_micros, ) except Exception: logger.exception( "[billable_call] premium_release failed for user=%s " "reserve_micros=%d (reservation will be GC'd by quota " "reconciliation if/when implemented)", user_id, reserve_micros, ) raise # ---------- Success: finalize + audit ---------------------------- actual_micros = acc.total_cost_micros try: async with shielded_async_session() as quota_session: final_result = await TokenQuotaService.premium_finalize( db_session=quota_session, user_id=user_id, request_id=request_id, actual_micros=actual_micros, reserved_micros=reserve_micros, ) logger.info( "[billable_call] finalize user=%s usage_type=%s actual=%d " "reserved=%d → used=%d/%d (remaining=%d)", user_id, usage_type, actual_micros, reserve_micros, final_result.used, final_result.limit, final_result.remaining, ) except Exception: # Last-ditch: if finalize itself fails, we must at least release # so the reservation doesn't leak. logger.exception( "[billable_call] premium_finalize failed for user=%s; " "attempting release", user_id, ) try: async with shielded_async_session() as quota_session: await TokenQuotaService.premium_release( db_session=quota_session, user_id=user_id, reserved_micros=reserve_micros, ) except Exception: logger.exception( "[billable_call] release after finalize failure ALSO failed " "for user=%s", user_id, ) try: async with shielded_async_session() as audit_session: await record_token_usage( audit_session, usage_type=usage_type, search_space_id=search_space_id, user_id=user_id, prompt_tokens=acc.total_prompt_tokens, completion_tokens=acc.total_completion_tokens, total_tokens=acc.grand_total, cost_micros=actual_micros, model_breakdown=acc.per_message_summary(), call_details=call_details, thread_id=thread_id, message_id=message_id, ) await audit_session.commit() except Exception: logger.exception( "[billable_call] premium-path audit insert failed for " "usage_type=%s user_id=%s (debit was applied)", usage_type, user_id, ) async def _resolve_agent_billing_for_search_space( session: AsyncSession, search_space_id: int, *, thread_id: int | None = None, ) -> tuple[UUID, str, str]: """Resolve ``(owner_user_id, billing_tier, base_model)`` for the search-space agent LLM. Used by Celery tasks (podcast generation, video presentation) to bill the search-space owner's premium credit pool when the agent LLM is premium. Resolution rules mirror chat at ``stream_new_chat.py:2294-2351``: - Search space not found / no ``agent_llm_id``: raise ``ValueError``. - **Auto mode** (``id == AUTO_FASTEST_ID == 0``): * ``thread_id`` is set: delegate to ``resolve_or_get_pinned_llm_config_id`` (the same call chat uses) and recurse into the resolved id. Reuses chat's existing pin if present so the same model bills for chat + downstream podcast/video. If the user is not premium-eligible, the pin service auto-restricts to free deployments — denial only happens later in ``billable_call.premium_reserve`` if the pin really is premium and credit ran out mid-flow. * ``thread_id`` is None: fallback to ``("free", "auto")``. Forward-compat for any future direct-API path; today both Celery tasks always pass ``thread_id``. - **Negative id** (global YAML / OpenRouter): ``cfg["billing_tier"]`` (defaults to ``"free"`` via ``app/config/__init__.py:52`` setdefault), ``base_model = litellm_params.get("base_model") or model_name`` — NOT provider-prefixed, matching chat's cost-map lookup convention. - **Positive id** (user BYOK ``NewLLMConfig``): always free (matches ``AgentConfig.from_new_llm_config`` which hard-codes ``billing_tier="free"``); ``base_model`` from ``litellm_params`` or ``model_name``. Note on imports: ``llm_service``, ``auto_model_pin_service``, and ``llm_router_service`` are imported lazily inside the function body to avoid hoisting litellm side-effects (``litellm.callbacks = [token_tracker]``, ``litellm.drop_params``, etc.) into ``billable_calls.py``'s module load path. """ from sqlalchemy import select from app.db import NewLLMConfig, SearchSpace result = await session.execute( select(SearchSpace).where(SearchSpace.id == search_space_id) ) search_space = result.scalars().first() if search_space is None: raise ValueError(f"Search space {search_space_id} not found") agent_llm_id = search_space.agent_llm_id if agent_llm_id is None: raise ValueError( f"Search space {search_space_id} has no agent_llm_id configured" ) owner_user_id: UUID = search_space.user_id from app.services.auto_model_pin_service import ( AUTO_FASTEST_ID, resolve_or_get_pinned_llm_config_id, ) if agent_llm_id == AUTO_FASTEST_ID: if thread_id is None: return owner_user_id, "free", "auto" try: resolution = await resolve_or_get_pinned_llm_config_id( session, thread_id=thread_id, search_space_id=search_space_id, user_id=str(owner_user_id), selected_llm_config_id=AUTO_FASTEST_ID, ) except ValueError: logger.warning( "[agent_billing] Auto-mode pin resolution failed for " "search_space=%s thread=%s; falling back to free", search_space_id, thread_id, exc_info=True, ) return owner_user_id, "free", "auto" agent_llm_id = resolution.resolved_llm_config_id if agent_llm_id < 0: from app.services.llm_service import get_global_llm_config cfg = get_global_llm_config(agent_llm_id) or {} billing_tier = str(cfg.get("billing_tier", "free")).lower() litellm_params = cfg.get("litellm_params") or {} base_model = litellm_params.get("base_model") or cfg.get("model_name") or "" return owner_user_id, billing_tier, base_model nlc_result = await session.execute( select(NewLLMConfig).where( NewLLMConfig.id == agent_llm_id, NewLLMConfig.search_space_id == search_space_id, ) ) nlc = nlc_result.scalars().first() base_model = "" if nlc is not None: litellm_params = nlc.litellm_params or {} base_model = litellm_params.get("base_model") or nlc.model_name or "" return owner_user_id, "free", base_model __all__ = [ "QuotaInsufficientError", "_resolve_agent_billing_for_search_space", "billable_call", ] # Re-export the config knob so callers don't have to import config just for # the default image reserve. DEFAULT_IMAGE_RESERVE_MICROS = config.QUOTA_DEFAULT_IMAGE_RESERVE_MICROS