mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-16 08:25:18 +02:00
feat: carve out billing page and show credit ledger
This commit is contained in:
parent
e33fec17db
commit
fde84387f2
13 changed files with 995 additions and 285 deletions
|
|
@ -55,6 +55,10 @@ from api.services.configuration.registry import (
|
|||
ServiceProviders,
|
||||
ServiceType,
|
||||
)
|
||||
from api.services.organization_context import (
|
||||
OrganizationContextResponse,
|
||||
get_organization_context,
|
||||
)
|
||||
from api.services.organization_preferences import (
|
||||
get_organization_preferences,
|
||||
upsert_organization_preferences,
|
||||
|
|
@ -129,6 +133,12 @@ class TelephonyConfigWarningsResponse(BaseModel):
|
|||
telnyx_missing_webhook_public_key_count: int
|
||||
|
||||
|
||||
@router.get("/context", response_model=OrganizationContextResponse)
|
||||
async def get_current_organization_context(user: UserModel = Depends(get_user)):
|
||||
"""Return organization-scoped configuration signals owned by Dograh."""
|
||||
return await get_organization_context(user)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/telephony-providers/metadata",
|
||||
response_model=TelephonyProvidersMetadataResponse,
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import json
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
|
@ -47,6 +47,36 @@ class MPSCreditPurchaseUrlResponse(BaseModel):
|
|||
checkout_url: str
|
||||
|
||||
|
||||
class MPSBillingAccountResponse(BaseModel):
|
||||
id: int
|
||||
organization_id: int
|
||||
billing_mode: str
|
||||
cached_balance_credits: float
|
||||
currency: str
|
||||
|
||||
|
||||
class MPSCreditLedgerEntryResponse(BaseModel):
|
||||
id: int
|
||||
entry_type: str
|
||||
origin: Optional[str] = None
|
||||
credits_delta: float
|
||||
balance_after: float
|
||||
amount_minor: Optional[int] = None
|
||||
amount_currency: Optional[str] = None
|
||||
payment_order_id: Optional[int] = None
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
created_at: str
|
||||
|
||||
|
||||
class MPSBillingCreditsResponse(BaseModel):
|
||||
billing_version: Literal["legacy", "v2"]
|
||||
total_credits_used: float = 0.0
|
||||
remaining_credits: float = 0.0
|
||||
total_quota: float = 0.0
|
||||
account: Optional[MPSBillingAccountResponse] = None
|
||||
ledger_entries: List[MPSCreditLedgerEntryResponse] = Field(default_factory=list)
|
||||
|
||||
|
||||
class WorkflowRunUsageResponse(BaseModel):
|
||||
id: int
|
||||
workflow_id: int
|
||||
|
|
@ -149,6 +179,102 @@ async def get_mps_credits(user: UserModel = Depends(get_user)):
|
|||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
async def _uses_mps_billing_v2(user: UserModel, organization_id: int) -> bool:
|
||||
resolved = await get_resolved_ai_model_configuration(
|
||||
user_id=user.id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
return (
|
||||
resolved.source == "organization_v2"
|
||||
and resolved.effective.managed_service_version == 2
|
||||
)
|
||||
|
||||
|
||||
async def _legacy_mps_credits_response(user: UserModel) -> MPSBillingCreditsResponse:
|
||||
if DEPLOYMENT_MODE == "oss":
|
||||
usage = await mps_service_key_client.get_usage_by_created_by(
|
||||
str(user.provider_id)
|
||||
)
|
||||
else:
|
||||
if not user.selected_organization_id:
|
||||
raise HTTPException(status_code=400, detail="No organization selected")
|
||||
usage = await mps_service_key_client.get_usage_by_organization(
|
||||
user.selected_organization_id
|
||||
)
|
||||
|
||||
total_used = float(usage.get("total_credits_used", 0.0))
|
||||
total_remaining = float(usage.get("remaining_credits", 0.0))
|
||||
return MPSBillingCreditsResponse(
|
||||
billing_version="legacy",
|
||||
total_credits_used=total_used,
|
||||
remaining_credits=total_remaining,
|
||||
total_quota=total_used + total_remaining,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/billing/credits", response_model=MPSBillingCreditsResponse)
|
||||
async def get_billing_credits(
|
||||
limit: int = Query(50, ge=1, le=100),
|
||||
user: UserModel = Depends(get_user),
|
||||
):
|
||||
"""Return legacy MPS credits or v2 billing ledger details for the org."""
|
||||
try:
|
||||
if DEPLOYMENT_MODE == "oss" or not user.selected_organization_id:
|
||||
return await _legacy_mps_credits_response(user)
|
||||
|
||||
organization_id = user.selected_organization_id
|
||||
if not await _uses_mps_billing_v2(user, organization_id):
|
||||
return await _legacy_mps_credits_response(user)
|
||||
|
||||
ledger = await mps_service_key_client.get_credit_ledger(
|
||||
organization_id=organization_id,
|
||||
limit=limit,
|
||||
created_by=str(user.provider_id),
|
||||
)
|
||||
account = ledger.get("account") or {}
|
||||
ledger_entries = ledger.get("ledger_entries") or []
|
||||
balance = float(account.get("cached_balance_credits") or 0.0)
|
||||
total_debits = sum(
|
||||
abs(float(entry.get("credits_delta") or 0.0))
|
||||
for entry in ledger_entries
|
||||
if float(entry.get("credits_delta") or 0.0) < 0
|
||||
)
|
||||
|
||||
return MPSBillingCreditsResponse(
|
||||
billing_version="v2",
|
||||
total_credits_used=total_debits,
|
||||
remaining_credits=balance,
|
||||
total_quota=balance + total_debits,
|
||||
account=MPSBillingAccountResponse(
|
||||
id=int(account["id"]),
|
||||
organization_id=int(account["organization_id"]),
|
||||
billing_mode=str(account["billing_mode"]),
|
||||
cached_balance_credits=balance,
|
||||
currency=str(account.get("currency") or "USD"),
|
||||
),
|
||||
ledger_entries=[
|
||||
MPSCreditLedgerEntryResponse(
|
||||
id=int(entry["id"]),
|
||||
entry_type=str(entry["entry_type"]),
|
||||
origin=entry.get("origin"),
|
||||
credits_delta=float(entry.get("credits_delta") or 0.0),
|
||||
balance_after=float(entry.get("balance_after") or 0.0),
|
||||
amount_minor=entry.get("amount_minor"),
|
||||
amount_currency=entry.get("amount_currency"),
|
||||
payment_order_id=entry.get("payment_order_id"),
|
||||
metadata=entry.get("metadata") or {},
|
||||
created_at=str(entry["created_at"]),
|
||||
)
|
||||
for entry in ledger_entries
|
||||
],
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"Failed to fetch billing credits: {exc}")
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.post(
|
||||
"/usage/mps-credits/purchase-url",
|
||||
response_model=MPSCreditPurchaseUrlResponse,
|
||||
|
|
@ -185,9 +311,9 @@ async def create_mps_credit_purchase_url(
|
|||
session = await mps_service_key_client.create_credit_purchase_url(
|
||||
organization_id=organization_id,
|
||||
created_by=str(user.provider_id),
|
||||
return_url=f"{UI_APP_URL.rstrip('/')}/reports",
|
||||
return_url=f"{UI_APP_URL.rstrip('/')}/billing",
|
||||
billing_details={
|
||||
"source": "dograh_reports",
|
||||
"source": "dograh_billing",
|
||||
"dograh_user_id": str(user.id),
|
||||
"dograh_provider_id": str(user.provider_id),
|
||||
},
|
||||
|
|
|
|||
|
|
@ -390,6 +390,36 @@ class MPSServiceKeyClient:
|
|||
response=response,
|
||||
)
|
||||
|
||||
async def get_credit_ledger(
|
||||
self,
|
||||
organization_id: int,
|
||||
limit: int = 50,
|
||||
created_by: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""Get the MPS v2 billing account balance and recent credit ledger."""
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
response = await client.get(
|
||||
f"{self.base_url}/api/v1/billing/accounts/{organization_id}/ledger",
|
||||
params={"limit": limit},
|
||||
headers=self._get_headers(
|
||||
organization_id=organization_id,
|
||||
created_by=created_by,
|
||||
),
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
|
||||
logger.error(
|
||||
"Failed to get MPS credit ledger: "
|
||||
f"{response.status_code} - {response.text}"
|
||||
)
|
||||
raise httpx.HTTPStatusError(
|
||||
f"Failed to get MPS credit ledger: {response.text}",
|
||||
request=response.request,
|
||||
response=response,
|
||||
)
|
||||
|
||||
async def create_correlation_id(
|
||||
self,
|
||||
*,
|
||||
|
|
|
|||
50
api/services/organization_context.py
Normal file
50
api/services/organization_context.py
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from api.db import db_client
|
||||
from api.db.models import UserModel
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
get_resolved_ai_model_configuration,
|
||||
)
|
||||
|
||||
|
||||
class OrganizationModelServicesContext(BaseModel):
|
||||
config_source: Literal["organization_v2", "legacy_user_v1", "empty"]
|
||||
has_model_configuration_v2: bool
|
||||
managed_service_version: Optional[int] = None
|
||||
uses_managed_service_v2: bool
|
||||
|
||||
|
||||
class OrganizationContextResponse(BaseModel):
|
||||
organization_id: Optional[int] = None
|
||||
organization_provider_id: Optional[str] = None
|
||||
model_services: OrganizationModelServicesContext
|
||||
|
||||
|
||||
async def get_organization_context(user: UserModel) -> OrganizationContextResponse:
|
||||
organization_id = user.selected_organization_id
|
||||
organization = (
|
||||
await db_client.get_organization_by_id(organization_id)
|
||||
if organization_id
|
||||
else None
|
||||
)
|
||||
|
||||
resolved = await get_resolved_ai_model_configuration(
|
||||
user_id=user.id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
managed_service_version = resolved.effective.managed_service_version
|
||||
|
||||
return OrganizationContextResponse(
|
||||
organization_id=organization_id,
|
||||
organization_provider_id=organization.provider_id if organization else None,
|
||||
model_services=OrganizationModelServicesContext(
|
||||
config_source=resolved.source,
|
||||
has_model_configuration_v2=resolved.source == "organization_v2",
|
||||
managed_service_version=managed_service_version,
|
||||
uses_managed_service_v2=(
|
||||
resolved.source == "organization_v2" and managed_service_version == 2
|
||||
),
|
||||
),
|
||||
)
|
||||
Loading…
Add table
Add a link
Reference in a new issue