feat: billing and credit management v2 (#429)

* feat: use mps generated correlation ID

* chore: update pipecat submodule

* feat: add credit purchase URL

* feat: carve out billing page and show credit ledger

* feat: deprecate dograh based quota tracking

* fix: remove cost calculation from dograh codebase

* fix: create mps account on migrate to v2

* chore: update pipecat
This commit is contained in:
Abhishek 2026-06-12 14:55:30 +05:30 committed by GitHub
parent 97d7103480
commit 1f1149f4d5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
80 changed files with 3335 additions and 2057 deletions

View file

@ -1,16 +1,16 @@
import json
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Literal, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import StreamingResponse
from loguru import logger
from pydantic import BaseModel, Field
from api.constants import DEPLOYMENT_MODE
from api.constants import DEPLOYMENT_MODE, UI_APP_URL
from api.db import db_client
from api.db.models import UserModel
from api.services.auth.depends import get_user
from api.services.auth.depends import get_user, get_user_with_selected_organization
from api.services.mps_service_key_client import mps_service_key_client
from api.services.reports import generate_usage_runs_report_csv
from api.utils.artifacts import artifact_url
@ -22,14 +22,8 @@ class CurrentUsageResponse(BaseModel):
period_start: str
period_end: str
used_dograh_tokens: float
quota_dograh_tokens: int
percentage_used: float
next_refresh_date: str
quota_enabled: bool
total_duration_seconds: int
# New USD fields
used_amount_usd: Optional[float] = None
quota_amount_usd: Optional[float] = None
currency: Optional[str] = None
price_per_second_usd: Optional[float] = None
@ -40,6 +34,61 @@ class MPSCreditsResponse(BaseModel):
total_quota: float
class MPSCreditPurchaseUrlResponse(BaseModel):
checkout_url: str
class MPSBillingAccountResponse(BaseModel):
id: int
organization_id: int
billing_mode: str
cached_balance_credits: float
currency: str
class MPSCreditLedgerEntryResponse(BaseModel):
id: int
entry_type: str
origin: Optional[str] = None
credits_delta: float
balance_after: float
amount_minor: Optional[int] = None
amount_currency: Optional[str] = None
payment_order_id: Optional[int] = None
metric_code: Optional[str] = None
correlation_id: Optional[str] = None
aggregation_key: Optional[str] = None
usage_event_id: Optional[int] = None
workflow_run_id: Optional[int] = None
workflow_id: Optional[int] = None
billable_quantity: Optional[float] = None
quantity_unit: Optional[str] = None
metadata: Dict[str, Any] = Field(default_factory=dict)
created_at: str
class MPSBillingCreditsResponse(BaseModel):
billing_version: Literal["legacy", "v2"]
total_credits_used: float = 0.0
remaining_credits: float = 0.0
total_quota: float = 0.0
account: Optional[MPSBillingAccountResponse] = None
ledger_entries: List[MPSCreditLedgerEntryResponse] = Field(default_factory=list)
total_count: int = 0
page: int = 1
limit: int = 50
total_pages: int = 0
def _optional_int(value: Any) -> Optional[int]:
if value is None:
return None
try:
return int(value)
except (TypeError, ValueError):
return None
class WorkflowRunUsageResponse(BaseModel):
id: int
workflow_id: int
@ -97,7 +146,7 @@ class DailyUsageBreakdownResponse(BaseModel):
@router.get("/usage/current-period", response_model=CurrentUsageResponse)
async def get_current_period_usage(user: UserModel = Depends(get_user)):
"""Get current billing period usage for the user's organization."""
"""Get current reporting-period usage for the user's organization."""
if not user.selected_organization_id:
raise HTTPException(status_code=400, detail="No organization selected")
@ -142,6 +191,202 @@ async def get_mps_credits(user: UserModel = Depends(get_user)):
raise HTTPException(status_code=500, detail=str(e))
async def _get_mps_billing_account_status(
user: UserModel, organization_id: int
) -> Optional[dict]:
return await mps_service_key_client.get_billing_account_status(
organization_id=organization_id,
created_by=str(user.provider_id),
)
def _is_mps_billing_v2(account: Optional[dict]) -> bool:
return bool(account and account.get("billing_mode") == "v2")
async def _legacy_mps_credits_response(user: UserModel) -> MPSBillingCreditsResponse:
if DEPLOYMENT_MODE == "oss":
usage = await mps_service_key_client.get_usage_by_created_by(
str(user.provider_id)
)
else:
if not user.selected_organization_id:
raise HTTPException(status_code=400, detail="No organization selected")
usage = await mps_service_key_client.get_usage_by_organization(
user.selected_organization_id
)
total_used = float(usage.get("total_credits_used", 0.0))
total_remaining = float(usage.get("remaining_credits", 0.0))
return MPSBillingCreditsResponse(
billing_version="legacy",
total_credits_used=total_used,
remaining_credits=total_remaining,
total_quota=total_used + total_remaining,
)
@router.get("/billing/credits", response_model=MPSBillingCreditsResponse)
async def get_billing_credits(
page: int = Query(1, ge=1),
limit: int = Query(50, ge=1, le=100),
user: UserModel = Depends(get_user),
):
"""Return legacy MPS credits or paginated v2 billing ledger details for the org."""
try:
if DEPLOYMENT_MODE == "oss" or not user.selected_organization_id:
return await _legacy_mps_credits_response(user)
organization_id = user.selected_organization_id
account_status = await _get_mps_billing_account_status(user, organization_id)
if not _is_mps_billing_v2(account_status):
return await _legacy_mps_credits_response(user)
ledger = await mps_service_key_client.get_credit_ledger(
organization_id=organization_id,
page=page,
limit=limit,
created_by=str(user.provider_id),
)
account = ledger.get("account") or {}
ledger_entries = ledger.get("ledger_entries") or []
total_count = int(ledger.get("total_count") or len(ledger_entries))
response_limit = int(ledger.get("limit") or limit)
total_pages = int(
ledger.get("total_pages")
or ((total_count + response_limit - 1) // response_limit)
)
workflow_ids_by_run_id: dict[int, int] = {}
workflow_run_ids = {
workflow_run_id
for entry in ledger_entries
if (workflow_run_id := _optional_int(entry.get("workflow_run_id")))
is not None
}
for workflow_run_id in workflow_run_ids:
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
if (
workflow_run
and workflow_run.workflow
and workflow_run.workflow.organization_id == organization_id
):
workflow_ids_by_run_id[workflow_run_id] = workflow_run.workflow_id
balance = float(account.get("cached_balance_credits") or 0.0)
total_debits = sum(
abs(float(entry.get("credits_delta") or 0.0))
for entry in ledger_entries
if float(entry.get("credits_delta") or 0.0) < 0
)
if ledger.get("total_debits_credits") is not None:
total_debits = float(ledger["total_debits_credits"])
return MPSBillingCreditsResponse(
billing_version="v2",
total_credits_used=total_debits,
remaining_credits=balance,
total_quota=balance + total_debits,
account=MPSBillingAccountResponse(
id=int(account["id"]),
organization_id=int(account["organization_id"]),
billing_mode=str(account["billing_mode"]),
cached_balance_credits=balance,
currency=str(account.get("currency") or "USD"),
),
ledger_entries=[
MPSCreditLedgerEntryResponse(
id=int(entry["id"]),
entry_type=str(entry["entry_type"]),
origin=entry.get("origin"),
credits_delta=float(entry.get("credits_delta") or 0.0),
balance_after=float(entry.get("balance_after") or 0.0),
amount_minor=entry.get("amount_minor"),
amount_currency=entry.get("amount_currency"),
payment_order_id=entry.get("payment_order_id"),
metric_code=entry.get("metric_code"),
correlation_id=entry.get("correlation_id"),
aggregation_key=entry.get("aggregation_key"),
usage_event_id=_optional_int(entry.get("usage_event_id")),
workflow_run_id=_optional_int(entry.get("workflow_run_id")),
workflow_id=workflow_ids_by_run_id.get(
_optional_int(entry.get("workflow_run_id"))
)
if entry.get("workflow_run_id") is not None
else None,
billable_quantity=float(entry["billable_quantity"])
if entry.get("billable_quantity") is not None
else None,
quantity_unit=entry.get("quantity_unit"),
metadata=entry.get("metadata") or {},
created_at=str(entry["created_at"]),
)
for entry in ledger_entries
],
total_count=total_count,
page=int(ledger.get("page") or page),
limit=response_limit,
total_pages=total_pages,
)
except HTTPException:
raise
except Exception as exc:
logger.error(f"Failed to fetch billing credits: {exc}")
raise HTTPException(status_code=500, detail=str(exc))
@router.post(
"/usage/mps-credits/purchase-url",
response_model=MPSCreditPurchaseUrlResponse,
)
async def create_mps_credit_purchase_url(
user: UserModel = Depends(get_user_with_selected_organization),
):
"""Create a checkout URL for organizations using Dograh-managed MPS v2."""
if DEPLOYMENT_MODE == "oss":
raise HTTPException(
status_code=404,
detail="Credit purchases are not available in OSS mode",
)
organization_id = user.selected_organization_id
assert organization_id is not None
account_status = await _get_mps_billing_account_status(user, organization_id)
if not _is_mps_billing_v2(account_status):
raise HTTPException(
status_code=403,
detail=(
"Credit purchases are available only for organizations using billing v2"
),
)
try:
session = await mps_service_key_client.create_credit_purchase_url(
organization_id=organization_id,
created_by=str(user.provider_id),
return_url=f"{UI_APP_URL.rstrip('/')}/billing",
billing_details={
"source": "dograh_billing",
"dograh_user_id": str(user.id),
"dograh_provider_id": str(user.provider_id),
},
)
except Exception as exc:
logger.error(f"Failed to create MPS credit purchase URL: {exc}")
raise HTTPException(
status_code=502,
detail="Failed to create credit purchase URL",
)
checkout_url = session.get("checkout_url")
if not checkout_url:
logger.error(f"MPS checkout session response missing checkout_url: {session}")
raise HTTPException(
status_code=502,
detail="MPS checkout session response missing checkout_url",
)
return MPSCreditPurchaseUrlResponse(checkout_url=checkout_url)
FILTERS_DESCRIPTION = """\
JSON-encoded array of filter objects. Each object has the shape: