mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-22 08:38:13 +02:00
feat: billing and credit management v2 (#429)
* feat: use mps generated correlation ID * chore: update pipecat submodule * feat: add credit purchase URL * feat: carve out billing page and show credit ledger * feat: deprecate dograh based quota tracking * fix: remove cost calculation from dograh codebase * fix: create mps account on migrate to v2 * chore: update pipecat
This commit is contained in:
parent
97d7103480
commit
1f1149f4d5
80 changed files with 3335 additions and 2057 deletions
|
|
@ -384,7 +384,7 @@ async def search_chunks(
|
|||
user_id=user.id,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
user_config = resolved_config.effective
|
||||
effective_config = resolved_config.effective
|
||||
embeddings_api_key = None
|
||||
embeddings_model = None
|
||||
embeddings_provider = None
|
||||
|
|
@ -392,17 +392,17 @@ async def search_chunks(
|
|||
embeddings_endpoint = None
|
||||
embeddings_api_version = None
|
||||
|
||||
if user_config.embeddings:
|
||||
embeddings_api_key = user_config.embeddings.api_key
|
||||
embeddings_model = user_config.embeddings.model
|
||||
embeddings_provider = getattr(user_config.embeddings, "provider", None)
|
||||
embeddings_endpoint = getattr(user_config.embeddings, "endpoint", None)
|
||||
if effective_config.embeddings:
|
||||
embeddings_api_key = effective_config.embeddings.api_key
|
||||
embeddings_model = effective_config.embeddings.model
|
||||
embeddings_provider = getattr(effective_config.embeddings, "provider", None)
|
||||
embeddings_endpoint = getattr(effective_config.embeddings, "endpoint", None)
|
||||
embeddings_base_url = apply_managed_embeddings_base_url(
|
||||
provider=embeddings_provider,
|
||||
base_url=getattr(user_config.embeddings, "base_url", None),
|
||||
base_url=getattr(effective_config.embeddings, "base_url", None),
|
||||
)
|
||||
embeddings_api_version = getattr(
|
||||
user_config.embeddings, "api_version", None
|
||||
effective_config.embeddings, "api_version", None
|
||||
)
|
||||
|
||||
# Initialize embedding service based on provider
|
||||
|
|
|
|||
|
|
@ -5,7 +5,11 @@ from loguru import logger
|
|||
from pydantic import BaseModel
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from api.constants import DEFAULT_CAMPAIGN_RETRY_CONFIG, DEFAULT_ORG_CONCURRENCY_LIMIT
|
||||
from api.constants import (
|
||||
DEFAULT_CAMPAIGN_RETRY_CONFIG,
|
||||
DEFAULT_ORG_CONCURRENCY_LIMIT,
|
||||
DEPLOYMENT_MODE,
|
||||
)
|
||||
from api.db import db_client
|
||||
from api.db.models import UserModel
|
||||
from api.db.telephony_configuration_client import TelephonyConfigurationInUseError
|
||||
|
|
@ -55,6 +59,11 @@ from api.services.configuration.registry import (
|
|||
ServiceProviders,
|
||||
ServiceType,
|
||||
)
|
||||
from api.services.mps_billing import ensure_hosted_mps_billing_account_v2
|
||||
from api.services.organization_context import (
|
||||
OrganizationContextResponse,
|
||||
get_organization_context,
|
||||
)
|
||||
from api.services.organization_preferences import (
|
||||
get_organization_preferences,
|
||||
upsert_organization_preferences,
|
||||
|
|
@ -129,6 +138,12 @@ class TelephonyConfigWarningsResponse(BaseModel):
|
|||
telnyx_missing_webhook_public_key_count: int
|
||||
|
||||
|
||||
@router.get("/context", response_model=OrganizationContextResponse)
|
||||
async def get_current_organization_context(user: UserModel = Depends(get_user)):
|
||||
"""Return organization-scoped configuration signals owned by Dograh."""
|
||||
return await get_organization_context(user)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/telephony-providers/metadata",
|
||||
response_model=TelephonyProvidersMetadataResponse,
|
||||
|
|
@ -349,6 +364,23 @@ async def migrate_model_configuration_v2(
|
|||
except ValueError as exc:
|
||||
raise HTTPException(status_code=422, detail=exc.args[0])
|
||||
|
||||
if DEPLOYMENT_MODE != "oss":
|
||||
try:
|
||||
await ensure_hosted_mps_billing_account_v2(
|
||||
organization_id,
|
||||
created_by=str(user.provider_id),
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"Failed to initialize MPS billing v2 account for organization {}: {}",
|
||||
organization_id,
|
||||
exc,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="Failed to initialize MPS billing v2 account",
|
||||
)
|
||||
|
||||
await upsert_organization_ai_model_configuration_v2(
|
||||
organization_id,
|
||||
configuration,
|
||||
|
|
|
|||
|
|
@ -1,16 +1,16 @@
|
|||
import json
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from api.constants import DEPLOYMENT_MODE
|
||||
from api.constants import DEPLOYMENT_MODE, UI_APP_URL
|
||||
from api.db import db_client
|
||||
from api.db.models import UserModel
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.auth.depends import get_user, get_user_with_selected_organization
|
||||
from api.services.mps_service_key_client import mps_service_key_client
|
||||
from api.services.reports import generate_usage_runs_report_csv
|
||||
from api.utils.artifacts import artifact_url
|
||||
|
|
@ -22,14 +22,8 @@ class CurrentUsageResponse(BaseModel):
|
|||
period_start: str
|
||||
period_end: str
|
||||
used_dograh_tokens: float
|
||||
quota_dograh_tokens: int
|
||||
percentage_used: float
|
||||
next_refresh_date: str
|
||||
quota_enabled: bool
|
||||
total_duration_seconds: int
|
||||
# New USD fields
|
||||
used_amount_usd: Optional[float] = None
|
||||
quota_amount_usd: Optional[float] = None
|
||||
currency: Optional[str] = None
|
||||
price_per_second_usd: Optional[float] = None
|
||||
|
||||
|
|
@ -40,6 +34,61 @@ class MPSCreditsResponse(BaseModel):
|
|||
total_quota: float
|
||||
|
||||
|
||||
class MPSCreditPurchaseUrlResponse(BaseModel):
|
||||
checkout_url: str
|
||||
|
||||
|
||||
class MPSBillingAccountResponse(BaseModel):
|
||||
id: int
|
||||
organization_id: int
|
||||
billing_mode: str
|
||||
cached_balance_credits: float
|
||||
currency: str
|
||||
|
||||
|
||||
class MPSCreditLedgerEntryResponse(BaseModel):
|
||||
id: int
|
||||
entry_type: str
|
||||
origin: Optional[str] = None
|
||||
credits_delta: float
|
||||
balance_after: float
|
||||
amount_minor: Optional[int] = None
|
||||
amount_currency: Optional[str] = None
|
||||
payment_order_id: Optional[int] = None
|
||||
metric_code: Optional[str] = None
|
||||
correlation_id: Optional[str] = None
|
||||
aggregation_key: Optional[str] = None
|
||||
usage_event_id: Optional[int] = None
|
||||
workflow_run_id: Optional[int] = None
|
||||
workflow_id: Optional[int] = None
|
||||
billable_quantity: Optional[float] = None
|
||||
quantity_unit: Optional[str] = None
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
created_at: str
|
||||
|
||||
|
||||
class MPSBillingCreditsResponse(BaseModel):
|
||||
billing_version: Literal["legacy", "v2"]
|
||||
total_credits_used: float = 0.0
|
||||
remaining_credits: float = 0.0
|
||||
total_quota: float = 0.0
|
||||
account: Optional[MPSBillingAccountResponse] = None
|
||||
ledger_entries: List[MPSCreditLedgerEntryResponse] = Field(default_factory=list)
|
||||
total_count: int = 0
|
||||
page: int = 1
|
||||
limit: int = 50
|
||||
total_pages: int = 0
|
||||
|
||||
|
||||
def _optional_int(value: Any) -> Optional[int]:
|
||||
if value is None:
|
||||
return None
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
|
||||
class WorkflowRunUsageResponse(BaseModel):
|
||||
id: int
|
||||
workflow_id: int
|
||||
|
|
@ -97,7 +146,7 @@ class DailyUsageBreakdownResponse(BaseModel):
|
|||
|
||||
@router.get("/usage/current-period", response_model=CurrentUsageResponse)
|
||||
async def get_current_period_usage(user: UserModel = Depends(get_user)):
|
||||
"""Get current billing period usage for the user's organization."""
|
||||
"""Get current reporting-period usage for the user's organization."""
|
||||
if not user.selected_organization_id:
|
||||
raise HTTPException(status_code=400, detail="No organization selected")
|
||||
|
||||
|
|
@ -142,6 +191,202 @@ async def get_mps_credits(user: UserModel = Depends(get_user)):
|
|||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
async def _get_mps_billing_account_status(
|
||||
user: UserModel, organization_id: int
|
||||
) -> Optional[dict]:
|
||||
return await mps_service_key_client.get_billing_account_status(
|
||||
organization_id=organization_id,
|
||||
created_by=str(user.provider_id),
|
||||
)
|
||||
|
||||
|
||||
def _is_mps_billing_v2(account: Optional[dict]) -> bool:
|
||||
return bool(account and account.get("billing_mode") == "v2")
|
||||
|
||||
|
||||
async def _legacy_mps_credits_response(user: UserModel) -> MPSBillingCreditsResponse:
|
||||
if DEPLOYMENT_MODE == "oss":
|
||||
usage = await mps_service_key_client.get_usage_by_created_by(
|
||||
str(user.provider_id)
|
||||
)
|
||||
else:
|
||||
if not user.selected_organization_id:
|
||||
raise HTTPException(status_code=400, detail="No organization selected")
|
||||
usage = await mps_service_key_client.get_usage_by_organization(
|
||||
user.selected_organization_id
|
||||
)
|
||||
|
||||
total_used = float(usage.get("total_credits_used", 0.0))
|
||||
total_remaining = float(usage.get("remaining_credits", 0.0))
|
||||
return MPSBillingCreditsResponse(
|
||||
billing_version="legacy",
|
||||
total_credits_used=total_used,
|
||||
remaining_credits=total_remaining,
|
||||
total_quota=total_used + total_remaining,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/billing/credits", response_model=MPSBillingCreditsResponse)
|
||||
async def get_billing_credits(
|
||||
page: int = Query(1, ge=1),
|
||||
limit: int = Query(50, ge=1, le=100),
|
||||
user: UserModel = Depends(get_user),
|
||||
):
|
||||
"""Return legacy MPS credits or paginated v2 billing ledger details for the org."""
|
||||
try:
|
||||
if DEPLOYMENT_MODE == "oss" or not user.selected_organization_id:
|
||||
return await _legacy_mps_credits_response(user)
|
||||
|
||||
organization_id = user.selected_organization_id
|
||||
account_status = await _get_mps_billing_account_status(user, organization_id)
|
||||
if not _is_mps_billing_v2(account_status):
|
||||
return await _legacy_mps_credits_response(user)
|
||||
|
||||
ledger = await mps_service_key_client.get_credit_ledger(
|
||||
organization_id=organization_id,
|
||||
page=page,
|
||||
limit=limit,
|
||||
created_by=str(user.provider_id),
|
||||
)
|
||||
account = ledger.get("account") or {}
|
||||
ledger_entries = ledger.get("ledger_entries") or []
|
||||
total_count = int(ledger.get("total_count") or len(ledger_entries))
|
||||
response_limit = int(ledger.get("limit") or limit)
|
||||
total_pages = int(
|
||||
ledger.get("total_pages")
|
||||
or ((total_count + response_limit - 1) // response_limit)
|
||||
)
|
||||
workflow_ids_by_run_id: dict[int, int] = {}
|
||||
workflow_run_ids = {
|
||||
workflow_run_id
|
||||
for entry in ledger_entries
|
||||
if (workflow_run_id := _optional_int(entry.get("workflow_run_id")))
|
||||
is not None
|
||||
}
|
||||
for workflow_run_id in workflow_run_ids:
|
||||
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
|
||||
if (
|
||||
workflow_run
|
||||
and workflow_run.workflow
|
||||
and workflow_run.workflow.organization_id == organization_id
|
||||
):
|
||||
workflow_ids_by_run_id[workflow_run_id] = workflow_run.workflow_id
|
||||
|
||||
balance = float(account.get("cached_balance_credits") or 0.0)
|
||||
total_debits = sum(
|
||||
abs(float(entry.get("credits_delta") or 0.0))
|
||||
for entry in ledger_entries
|
||||
if float(entry.get("credits_delta") or 0.0) < 0
|
||||
)
|
||||
if ledger.get("total_debits_credits") is not None:
|
||||
total_debits = float(ledger["total_debits_credits"])
|
||||
|
||||
return MPSBillingCreditsResponse(
|
||||
billing_version="v2",
|
||||
total_credits_used=total_debits,
|
||||
remaining_credits=balance,
|
||||
total_quota=balance + total_debits,
|
||||
account=MPSBillingAccountResponse(
|
||||
id=int(account["id"]),
|
||||
organization_id=int(account["organization_id"]),
|
||||
billing_mode=str(account["billing_mode"]),
|
||||
cached_balance_credits=balance,
|
||||
currency=str(account.get("currency") or "USD"),
|
||||
),
|
||||
ledger_entries=[
|
||||
MPSCreditLedgerEntryResponse(
|
||||
id=int(entry["id"]),
|
||||
entry_type=str(entry["entry_type"]),
|
||||
origin=entry.get("origin"),
|
||||
credits_delta=float(entry.get("credits_delta") or 0.0),
|
||||
balance_after=float(entry.get("balance_after") or 0.0),
|
||||
amount_minor=entry.get("amount_minor"),
|
||||
amount_currency=entry.get("amount_currency"),
|
||||
payment_order_id=entry.get("payment_order_id"),
|
||||
metric_code=entry.get("metric_code"),
|
||||
correlation_id=entry.get("correlation_id"),
|
||||
aggregation_key=entry.get("aggregation_key"),
|
||||
usage_event_id=_optional_int(entry.get("usage_event_id")),
|
||||
workflow_run_id=_optional_int(entry.get("workflow_run_id")),
|
||||
workflow_id=workflow_ids_by_run_id.get(
|
||||
_optional_int(entry.get("workflow_run_id"))
|
||||
)
|
||||
if entry.get("workflow_run_id") is not None
|
||||
else None,
|
||||
billable_quantity=float(entry["billable_quantity"])
|
||||
if entry.get("billable_quantity") is not None
|
||||
else None,
|
||||
quantity_unit=entry.get("quantity_unit"),
|
||||
metadata=entry.get("metadata") or {},
|
||||
created_at=str(entry["created_at"]),
|
||||
)
|
||||
for entry in ledger_entries
|
||||
],
|
||||
total_count=total_count,
|
||||
page=int(ledger.get("page") or page),
|
||||
limit=response_limit,
|
||||
total_pages=total_pages,
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"Failed to fetch billing credits: {exc}")
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.post(
|
||||
"/usage/mps-credits/purchase-url",
|
||||
response_model=MPSCreditPurchaseUrlResponse,
|
||||
)
|
||||
async def create_mps_credit_purchase_url(
|
||||
user: UserModel = Depends(get_user_with_selected_organization),
|
||||
):
|
||||
"""Create a checkout URL for organizations using Dograh-managed MPS v2."""
|
||||
if DEPLOYMENT_MODE == "oss":
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Credit purchases are not available in OSS mode",
|
||||
)
|
||||
|
||||
organization_id = user.selected_organization_id
|
||||
assert organization_id is not None
|
||||
account_status = await _get_mps_billing_account_status(user, organization_id)
|
||||
if not _is_mps_billing_v2(account_status):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=(
|
||||
"Credit purchases are available only for organizations using billing v2"
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
session = await mps_service_key_client.create_credit_purchase_url(
|
||||
organization_id=organization_id,
|
||||
created_by=str(user.provider_id),
|
||||
return_url=f"{UI_APP_URL.rstrip('/')}/billing",
|
||||
billing_details={
|
||||
"source": "dograh_billing",
|
||||
"dograh_user_id": str(user.id),
|
||||
"dograh_provider_id": str(user.provider_id),
|
||||
},
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"Failed to create MPS credit purchase URL: {exc}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="Failed to create credit purchase URL",
|
||||
)
|
||||
|
||||
checkout_url = session.get("checkout_url")
|
||||
if not checkout_url:
|
||||
logger.error(f"MPS checkout session response missing checkout_url: {session}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="MPS checkout session response missing checkout_url",
|
||||
)
|
||||
return MPSCreditPurchaseUrlResponse(checkout_url=checkout_url)
|
||||
|
||||
|
||||
FILTERS_DESCRIPTION = """\
|
||||
JSON-encoded array of filter objects. Each object has the shape:
|
||||
|
||||
|
|
|
|||
|
|
@ -41,12 +41,15 @@ from api.services.configuration.resolve import (
|
|||
)
|
||||
from api.services.mps_service_key_client import mps_service_key_client
|
||||
from api.services.posthog_client import capture_event
|
||||
from api.services.pricing.run_usage_response import format_public_usage_info
|
||||
from api.services.reports import generate_workflow_report_csv
|
||||
from api.services.storage import storage_fs
|
||||
from api.services.workflow.dto import ReactFlowDTO, sanitize_workflow_definition
|
||||
from api.services.workflow.duplicate import duplicate_workflow
|
||||
from api.services.workflow.errors import ItemKind, WorkflowError
|
||||
from api.services.workflow.run_usage_response import (
|
||||
format_public_cost_info,
|
||||
format_public_usage_info,
|
||||
)
|
||||
from api.services.workflow.trigger_paths import (
|
||||
TriggerPathIssue,
|
||||
ensure_trigger_paths,
|
||||
|
|
@ -1053,13 +1056,15 @@ async def update_workflow(
|
|||
user_id=user.id,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
user_config = resolved_config.effective
|
||||
effective_config = resolved_config.effective
|
||||
try:
|
||||
enriched_overrides = enrich_overrides_with_api_keys(
|
||||
workflow_configurations["model_overrides"],
|
||||
user_config,
|
||||
effective_config,
|
||||
)
|
||||
effective = resolve_effective_config(
|
||||
effective_config, enriched_overrides
|
||||
)
|
||||
effective = resolve_effective_config(user_config, enriched_overrides)
|
||||
if resolved_config.source == "organization_v2":
|
||||
v2_override = convert_legacy_ai_model_configuration_to_v2(effective)
|
||||
await UserConfigurationValidator().validate(
|
||||
|
|
@ -1264,22 +1269,7 @@ async def get_workflow_run(
|
|||
"transcript_public_url": artifact_url(public_access_token, "transcript"),
|
||||
"recording_public_url": artifact_url(public_access_token, "recording"),
|
||||
"public_access_token": public_access_token,
|
||||
"cost_info": {
|
||||
"dograh_token_usage": (
|
||||
run.cost_info.get("dograh_token_usage")
|
||||
if run.cost_info and "dograh_token_usage" in run.cost_info
|
||||
else round(float(run.cost_info.get("total_cost_usd", 0)) * 100, 2)
|
||||
if run.cost_info and "total_cost_usd" in run.cost_info
|
||||
else 0
|
||||
),
|
||||
"call_duration_seconds": int(
|
||||
round(run.cost_info.get("call_duration_seconds"))
|
||||
)
|
||||
if run.cost_info and run.cost_info.get("call_duration_seconds") is not None
|
||||
else None,
|
||||
}
|
||||
if run.cost_info
|
||||
else None,
|
||||
"cost_info": format_public_cost_info(run.cost_info, run.usage_info),
|
||||
"usage_info": format_public_usage_info(run.usage_info),
|
||||
"created_at": run.created_at,
|
||||
"definition_id": run.definition_id,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue