fix: remove cost calculation from dograh codebase

This commit is contained in:
Abhishek Kumar 2026-06-12 13:26:33 +05:30
parent 7d4e2e06a9
commit 8f241b89d2
39 changed files with 1067 additions and 1460 deletions

View file

@ -9,6 +9,7 @@ from api.db.base_client import BaseDBClient
from api.db.filters import apply_workflow_run_filters, get_workflow_run_order_clause
from api.db.models import CampaignModel, QueuedRunModel, WorkflowRunModel
from api.schemas.workflow import WorkflowRunResponseSchema
from api.services.workflow.run_usage_response import format_public_cost_info
class CampaignClient(BaseDBClient):
@ -215,26 +216,9 @@ class CampaignClient(BaseDBClient):
"is_completed": run.is_completed,
"recording_url": run.recording_url,
"transcript_url": run.transcript_url,
"cost_info": {
"dograh_token_usage": (
run.cost_info.get("dograh_token_usage")
if run.cost_info
and "dograh_token_usage" in run.cost_info
else round(
float(run.cost_info.get("total_cost_usd", 0)) * 100,
2,
)
if run.cost_info and "total_cost_usd" in run.cost_info
else 0
),
"call_duration_seconds": int(
round(run.cost_info.get("call_duration_seconds") or 0)
)
if run.cost_info
else None,
}
if run.cost_info
else None,
"cost_info": format_public_cost_info(
run.cost_info, run.usage_info
),
"definition_id": run.definition_id,
"initial_context": run.initial_context,
"gathered_context": run.gathered_context,
@ -662,7 +646,7 @@ class CampaignClient(BaseDBClient):
async with self.async_session() as session:
conditions = [
WorkflowRunModel.is_completed.is_(True),
WorkflowRunModel.cost_info["call_duration_seconds"]
WorkflowRunModel.usage_info["call_duration_seconds"]
.as_string()
.isnot(None),
]
@ -685,6 +669,7 @@ class CampaignClient(BaseDBClient):
WorkflowRunModel.initial_context,
WorkflowRunModel.gathered_context,
WorkflowRunModel.cost_info,
WorkflowRunModel.usage_info,
WorkflowRunModel.public_access_token,
)
.where(*conditions)

View file

@ -25,7 +25,7 @@ def get_workflow_run_order_clause(
"""
# Determine sort column
if sort_by == "duration":
sort_column = WorkflowRunModel.cost_info.op("->>")(
sort_column = WorkflowRunModel.usage_info.op("->>")(
"call_duration_seconds"
).cast(Float)
else:
@ -43,7 +43,7 @@ def get_workflow_run_order_clause(
ATTRIBUTE_FIELD_MAPPING = {
"dateRange": "created_at",
"dispositionCode": "gathered_context.mapped_call_disposition",
"duration": "cost_info.call_duration_seconds",
"duration": "usage_info.call_duration_seconds",
"status": "is_completed",
"tokenUsage": "cost_info.total_cost_usd",
"runId": "id",
@ -208,7 +208,7 @@ def apply_workflow_run_filters(
min_val = value.get("min")
max_val = value.get("max")
if field == "cost_info.call_duration_seconds":
if field == "usage_info.call_duration_seconds":
# Use ->> operator for compatibility with all PostgreSQL versions
# (subscript [] only works in PostgreSQL 14+)
duration_text = cast(WorkflowRunModel.usage_info, JSONB).op("->>")(

View file

@ -96,43 +96,6 @@ class OrganizationUsageClient(BaseDBClient):
)
return cycle_result.scalar_one()
async def update_usage_after_run(
    self,
    organization_id: int,
    actual_tokens: float,
    duration_seconds: float = 0,
    charge_usd: float | None = None,
) -> None:
    """Add a finished run's tokens, duration and optional USD charge to the current cycle.

    The current usage cycle is fetched (or created) and then re-selected with
    ``SELECT ... FOR UPDATE`` inside the same session, so the increments below
    are applied while holding a row-level lock on that cycle row.

    Args:
        organization_id: Organization whose current cycle is updated.
        actual_tokens: Amount added to ``used_dograh_tokens``.
        duration_seconds: Call duration; rounded to the nearest integer before
            being added to ``total_duration_seconds``.
        charge_usd: Optional amount added to ``used_amount_usd``. A ``None``
            stored value is treated as 0 before adding.
    """
    async with self.async_session() as session:
        # commit=False keeps cycle creation in this session's open transaction.
        current_cycle = await self._get_or_create_current_cycle_impl(
            organization_id, session, commit=False
        )

        # Re-read the same row with a blocking row-level lock.
        lock_stmt = select(OrganizationUsageCycleModel).where(
            OrganizationUsageCycleModel.id == current_cycle.id
        )
        lock_stmt = lock_stmt.with_for_update(skip_locked=False)
        locked_row = (await session.execute(lock_stmt)).scalar_one()

        locked_row.used_dograh_tokens += actual_tokens
        locked_row.total_duration_seconds += int(round(duration_seconds))

        if charge_usd is not None:
            previous_usd = locked_row.used_amount_usd
            locked_row.used_amount_usd = (
                0 if previous_usd is None else previous_usd
            ) + charge_usd

        await session.commit()
async def get_current_usage(self, organization_id: int) -> dict:
"""Get current reporting-period usage information."""
async with self.async_session() as session:
@ -178,7 +141,7 @@ class OrganizationUsageClient(BaseDBClient):
.join(UserModel, WorkflowModel.user_id == UserModel.id)
.where(
UserModel.selected_organization_id == organization_id,
WorkflowRunModel.cost_info.isnot(None),
WorkflowRunModel.usage_info.isnot(None),
)
.order_by(WorkflowRunModel.created_at.desc())
)
@ -231,19 +194,8 @@ class OrganizationUsageClient(BaseDBClient):
total_tokens = 0
total_duration_seconds = 0
for run in runs:
if run.cost_info:
# Try to get dograh_token_usage first (new format)
dograh_tokens = run.cost_info.get("dograh_token_usage", 0)
# If not present, calculate from total_cost_usd (old format)
if dograh_tokens == 0 and "total_cost_usd" in run.cost_info:
dograh_tokens = round(
float(run.cost_info["total_cost_usd"]) * 100, 2
)
# Get call duration
call_duration = run.cost_info.get("call_duration_seconds", 0)
else:
dograh_tokens = 0
call_duration = 0
dograh_tokens = 0
call_duration = (run.usage_info or {}).get("call_duration_seconds", 0)
total_tokens += dograh_tokens
total_duration_seconds += int(round(call_duration))
@ -317,13 +269,14 @@ class OrganizationUsageClient(BaseDBClient):
WorkflowRunModel.initial_context,
WorkflowRunModel.gathered_context,
WorkflowRunModel.cost_info,
WorkflowRunModel.usage_info,
WorkflowRunModel.public_access_token,
)
.join(WorkflowModel, WorkflowRunModel.workflow_id == WorkflowModel.id)
.join(UserModel, WorkflowModel.user_id == UserModel.id)
.where(
UserModel.selected_organization_id == organization_id,
WorkflowRunModel.cost_info.isnot(None),
WorkflowRunModel.usage_info.isnot(None),
)
.order_by(WorkflowRunModel.created_at.desc())
)
@ -418,7 +371,7 @@ class OrganizationUsageClient(BaseDBClient):
select(
date_expr.label("date"),
func.sum(
WorkflowRunModel.cost_info["call_duration_seconds"].as_float()
WorkflowRunModel.usage_info["call_duration_seconds"].as_float()
).label("total_seconds"),
func.count(WorkflowRunModel.id).label("call_count"),
)

View file

@ -16,6 +16,7 @@ from api.db.models import (
)
from api.enums import CallType, StorageBackend
from api.schemas.workflow import WorkflowRunResponseSchema
from api.services.workflow.run_usage_response import format_public_cost_info
class WorkflowRunClient(BaseDBClient):
@ -312,26 +313,9 @@ class WorkflowRunClient(BaseDBClient):
"is_completed": run.is_completed,
"recording_url": run.recording_url,
"transcript_url": run.transcript_url,
"cost_info": {
"dograh_token_usage": (
run.cost_info.get("dograh_token_usage")
if run.cost_info
and "dograh_token_usage" in run.cost_info
else round(
float(run.cost_info.get("total_cost_usd", 0)) * 100,
2,
)
if run.cost_info and "total_cost_usd" in run.cost_info
else 0
),
"call_duration_seconds": int(
round(run.cost_info.get("call_duration_seconds") or 0)
)
if run.cost_info
else None,
}
if run.cost_info
else None,
"cost_info": format_public_cost_info(
run.cost_info, run.usage_info
),
"definition_id": run.definition_id,
"initial_context": run.initial_context,
"gathered_context": run.gathered_context,

View file

@ -11,9 +11,6 @@ from api.constants import DEPLOYMENT_MODE, UI_APP_URL
from api.db import db_client
from api.db.models import UserModel
from api.services.auth.depends import get_user, get_user_with_selected_organization
from api.services.configuration.ai_model_configuration import (
get_resolved_ai_model_configuration,
)
from api.services.mps_service_key_client import mps_service_key_client
from api.services.reports import generate_usage_runs_report_csv
from api.utils.artifacts import artifact_url
@ -58,6 +55,14 @@ class MPSCreditLedgerEntryResponse(BaseModel):
amount_minor: Optional[int] = None
amount_currency: Optional[str] = None
payment_order_id: Optional[int] = None
metric_code: Optional[str] = None
correlation_id: Optional[str] = None
aggregation_key: Optional[str] = None
usage_event_id: Optional[int] = None
workflow_run_id: Optional[int] = None
workflow_id: Optional[int] = None
billable_quantity: Optional[float] = None
quantity_unit: Optional[str] = None
metadata: Dict[str, Any] = Field(default_factory=dict)
created_at: str
@ -71,6 +76,15 @@ class MPSBillingCreditsResponse(BaseModel):
ledger_entries: List[MPSCreditLedgerEntryResponse] = Field(default_factory=list)
def _optional_int(value: Any) -> Optional[int]:
if value is None:
return None
try:
return int(value)
except (TypeError, ValueError):
return None
class WorkflowRunUsageResponse(BaseModel):
id: int
workflow_id: int
@ -173,15 +187,17 @@ async def get_mps_credits(user: UserModel = Depends(get_user)):
raise HTTPException(status_code=500, detail=str(e))
async def _uses_mps_billing_v2(user: UserModel, organization_id: int) -> bool:
resolved = await get_resolved_ai_model_configuration(
user_id=user.id,
async def _get_mps_billing_account_status(
    user: UserModel, organization_id: int
) -> Optional[dict]:
    """Fetch the organization's billing account status from MPS.

    Delegates to ``mps_service_key_client.get_billing_account_status`` and
    passes the user's ``provider_id`` (stringified) as ``created_by``.
    """
    creator = str(user.provider_id)
    account_status = await mps_service_key_client.get_billing_account_status(
        organization_id=organization_id,
        created_by=creator,
    )
    return account_status
return (
resolved.source == "organization_v2"
and resolved.effective.managed_service_version == 2
)
def _is_mps_billing_v2(account: Optional[dict]) -> bool:
return bool(account and account.get("billing_mode") == "v2")
async def _legacy_mps_credits_response(user: UserModel) -> MPSBillingCreditsResponse:
@ -217,7 +233,8 @@ async def get_billing_credits(
return await _legacy_mps_credits_response(user)
organization_id = user.selected_organization_id
if not await _uses_mps_billing_v2(user, organization_id):
account_status = await _get_mps_billing_account_status(user, organization_id)
if not _is_mps_billing_v2(account_status):
return await _legacy_mps_credits_response(user)
ledger = await mps_service_key_client.get_credit_ledger(
@ -227,6 +244,22 @@ async def get_billing_credits(
)
account = ledger.get("account") or {}
ledger_entries = ledger.get("ledger_entries") or []
workflow_ids_by_run_id: dict[int, int] = {}
workflow_run_ids = {
workflow_run_id
for entry in ledger_entries
if (workflow_run_id := _optional_int(entry.get("workflow_run_id")))
is not None
}
for workflow_run_id in workflow_run_ids:
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
if (
workflow_run
and workflow_run.workflow
and workflow_run.workflow.organization_id == organization_id
):
workflow_ids_by_run_id[workflow_run_id] = workflow_run.workflow_id
balance = float(account.get("cached_balance_credits") or 0.0)
total_debits = sum(
abs(float(entry.get("credits_delta") or 0.0))
@ -256,6 +289,20 @@ async def get_billing_credits(
amount_minor=entry.get("amount_minor"),
amount_currency=entry.get("amount_currency"),
payment_order_id=entry.get("payment_order_id"),
metric_code=entry.get("metric_code"),
correlation_id=entry.get("correlation_id"),
aggregation_key=entry.get("aggregation_key"),
usage_event_id=_optional_int(entry.get("usage_event_id")),
workflow_run_id=_optional_int(entry.get("workflow_run_id")),
workflow_id=workflow_ids_by_run_id.get(
_optional_int(entry.get("workflow_run_id"))
)
if entry.get("workflow_run_id") is not None
else None,
billable_quantity=float(entry["billable_quantity"])
if entry.get("billable_quantity") is not None
else None,
quantity_unit=entry.get("quantity_unit"),
metadata=entry.get("metadata") or {},
created_at=str(entry["created_at"]),
)
@ -285,19 +332,12 @@ async def create_mps_credit_purchase_url(
organization_id = user.selected_organization_id
assert organization_id is not None
resolved = await get_resolved_ai_model_configuration(
user_id=user.id,
organization_id=organization_id,
)
if (
resolved.source != "organization_v2"
or resolved.effective.managed_service_version != 2
):
account_status = await _get_mps_billing_account_status(user, organization_id)
if not _is_mps_billing_v2(account_status):
raise HTTPException(
status_code=403,
detail=(
"Credit purchases are available only for organizations using "
"Dograh managed model configuration v2"
"Credit purchases are available only for organizations using billing v2"
),
)

View file

@ -41,12 +41,15 @@ from api.services.configuration.resolve import (
)
from api.services.mps_service_key_client import mps_service_key_client
from api.services.posthog_client import capture_event
from api.services.pricing.run_usage_response import format_public_usage_info
from api.services.reports import generate_workflow_report_csv
from api.services.storage import storage_fs
from api.services.workflow.dto import ReactFlowDTO, sanitize_workflow_definition
from api.services.workflow.duplicate import duplicate_workflow
from api.services.workflow.errors import ItemKind, WorkflowError
from api.services.workflow.run_usage_response import (
format_public_cost_info,
format_public_usage_info,
)
from api.services.workflow.trigger_paths import (
TriggerPathIssue,
ensure_trigger_paths,
@ -1266,22 +1269,7 @@ async def get_workflow_run(
"transcript_public_url": artifact_url(public_access_token, "transcript"),
"recording_public_url": artifact_url(public_access_token, "recording"),
"public_access_token": public_access_token,
"cost_info": {
"dograh_token_usage": (
run.cost_info.get("dograh_token_usage")
if run.cost_info and "dograh_token_usage" in run.cost_info
else round(float(run.cost_info.get("total_cost_usd", 0)) * 100, 2)
if run.cost_info and "total_cost_usd" in run.cost_info
else 0
),
"call_duration_seconds": int(
round(run.cost_info.get("call_duration_seconds"))
)
if run.cost_info and run.cost_info.get("call_duration_seconds") is not None
else None,
}
if run.cost_info
else None,
"cost_info": format_public_cost_info(run.cost_info, run.usage_info),
"usage_info": format_public_usage_info(run.usage_info),
"created_at": run.created_at,
"definition_id": run.definition_id,

View file

@ -75,21 +75,21 @@ class UserConfigurationValidator:
status_list = []
status_list.extend(self._validate_service(configuration.llm, "llm"))
status_list.extend(self._validate_service(configuration.stt, "stt"))
status_list.extend(self._validate_service(configuration.tts, "tts"))
# Embeddings is optional - only validate if configured
status_list.extend(
self._validate_service(
configuration.embeddings, "embeddings", required=False
)
)
# Realtime is optional - only validate if is_realtime is enabled
if configuration.is_realtime:
status_list.extend(
self._validate_service(
configuration.realtime, "realtime", required=True
)
)
else:
status_list.extend(self._validate_service(configuration.stt, "stt"))
status_list.extend(self._validate_service(configuration.tts, "tts"))
# Embeddings is optional - only validate if configured
status_list.extend(
self._validate_service(
configuration.embeddings, "embeddings", required=False
)
)
if status_list:
raise ValueError(status_list)

View file

@ -4,6 +4,7 @@ This client communicates with the Model Proxy Service (MPS) for service key mana
Service keys are stored and managed entirely in MPS, not in the local database.
"""
import asyncio
from typing import List, Optional
import httpx
@ -420,6 +421,34 @@ class MPSServiceKeyClient:
response=response,
)
async def get_billing_account_status(
    self,
    organization_id: int,
    created_by: Optional[str] = None,
) -> Optional[dict]:
    """Fetch the billing account status for an organization from MPS.

    Issues ``GET /api/v1/billing/accounts/{organization_id}/status`` and
    returns the decoded JSON body when the response status is 200.

    NOTE(review): the return type is Optional, but this method never returns
    ``None`` itself; presumably the endpoint can answer with a JSON ``null``
    body when no account exists - confirm against the MPS API.

    Args:
        organization_id: Organization whose billing account is queried.
        created_by: Optional identifier forwarded to ``_get_headers``.

    Raises:
        httpx.HTTPStatusError: For any non-200 response (logged beforehand).
    """
    async with httpx.AsyncClient(timeout=self.timeout) as client:
        status_url = (
            f"{self.base_url}/api/v1/billing/accounts/{organization_id}/status"
        )
        request_headers = self._get_headers(
            organization_id=organization_id,
            created_by=created_by,
        )
        response = await client.get(status_url, headers=request_headers)

        if response.status_code != 200:
            logger.error(
                "Failed to get MPS billing account status: "
                f"{response.status_code} - {response.text}"
            )
            raise httpx.HTTPStatusError(
                f"Failed to get MPS billing account status: {response.text}",
                request=response.request,
                response=response,
            )

        return response.json()
async def create_correlation_id(
self,
*,
@ -454,6 +483,76 @@ class MPSServiceKeyClient:
response=response,
)
async def report_platform_usage(
    self,
    *,
    organization_id: int,
    correlation_id: Optional[str] = None,
    duration_seconds: Optional[float] = None,
    workflow_run_id: int | None = None,
    metadata: Optional[dict] = None,
    max_attempts: int = 3,
) -> dict:
    """Report hosted Dograh platform usage for a completed workflow run.

    POSTs to ``/api/v1/billing/accounts/{organization_id}/platform-usage``.
    A 409 response whose body contains ``usage_not_ready`` is retried after
    sleeping ``attempt`` seconds (1s, 2s, ...), for at most ``max_attempts``
    attempts in total; values below 1 are treated as 1.

    Args:
        organization_id: Organization the usage is billed to.
        correlation_id: Sent as ``correlation_id`` when truthy.
        duration_seconds: Sent as ``duration_seconds`` when not ``None``.
        workflow_run_id: Sent as ``workflow_run_id`` when not ``None``.
        metadata: Sent as ``metadata``; an empty dict is sent when falsy.
        max_attempts: Upper bound on POST attempts.

    Returns:
        The decoded JSON body of the first HTTP 200 response.

    Raises:
        ValueError: When ``DEPLOYMENT_MODE`` is ``"oss"``, or when neither
            ``correlation_id`` nor ``duration_seconds`` is supplied.
        httpx.HTTPStatusError: On any non-200 response that is not retried.
    """
    if DEPLOYMENT_MODE == "oss":
        raise ValueError("OSS deployments must not report platform usage to MPS")

    has_usage_reference = bool(correlation_id) or duration_seconds is not None
    if not has_usage_reference:
        raise ValueError(
            "Platform usage reports require correlation_id or duration_seconds"
        )

    # Key insertion order is kept as metadata, correlation_id,
    # duration_seconds, workflow_run_id.
    body: dict = {"metadata": metadata or {}}
    if correlation_id:
        body["correlation_id"] = correlation_id
    if duration_seconds is not None:
        body["duration_seconds"] = duration_seconds
    if workflow_run_id is not None:
        body["workflow_run_id"] = workflow_run_id

    max_attempts = max(1, max_attempts)
    final_response: httpx.Response | None = None

    async with httpx.AsyncClient(timeout=self.timeout) as client:
        for attempt in range(1, max_attempts + 1):
            endpoint = (
                f"{self.base_url}/api/v1/billing/accounts/"
                f"{organization_id}/platform-usage"
            )
            response = await client.post(
                endpoint,
                json=body,
                headers=self._get_headers(organization_id=organization_id),
            )
            final_response = response

            if response.status_code == 200:
                return response.json()

            usage_pending = (
                response.status_code == 409 and "usage_not_ready" in response.text
            )
            if usage_pending and attempt < max_attempts:
                # Linear backoff: wait as many seconds as the attempt number.
                await asyncio.sleep(attempt)
                continue

            logger.error(
                "Failed to report platform usage: "
                f"{response.status_code} - {response.text}"
            )
            raise httpx.HTTPStatusError(
                f"Failed to report platform usage: {response.text}",
                request=response.request,
                response=response,
            )

    # Defensive fallback: every loop iteration above returns, retries or raises.
    raise httpx.HTTPStatusError(
        "Failed to report platform usage",
        request=final_response.request,
        response=final_response,
    )
async def transcribe_audio(
self,
audio_data: bytes,

View file

@ -162,15 +162,13 @@ async def run_pipeline_telephony(
workflow_id: Workflow being executed.
workflow_run_id: Workflow run row.
user_id: Owner of the workflow.
call_id: Provider call identifier (stored in cost_info for billing).
call_id: Provider call identifier.
transport_kwargs: Provider-specific kwargs forwarded to the transport
factory (e.g. stream_sid + call_sid for Twilio).
"""
logger.debug(f"Running {provider_name} pipeline for workflow_run {workflow_run_id}")
set_current_run_id(workflow_run_id)
await db_client.update_workflow_run(workflow_run_id, cost_info={"call_id": call_id})
workflow = await db_client.get_workflow(workflow_id, user_id)
if workflow:
set_current_org_id(workflow.organization_id)

View file

@ -1,76 +0,0 @@
# Pricing Module
This module contains pricing models and registries for different AI services used in workflow cost calculations.
## Structure
```
pricing/
├── __init__.py # Main module exports
├── models.py # Base pricing model classes
├── llm.py # LLM pricing configurations
├── tts.py # TTS pricing configurations
├── stt.py # STT pricing configurations
├── registry.py # Combined pricing registry
└── README.md # This file
```
## Pricing Models
### TokenPricingModel
Used for LLM services that charge based on tokens:
- `prompt_token_price`: Cost per prompt token
- `completion_token_price`: Cost per completion token
- `cache_read_discount`: Discount for cache read tokens (default 50%)
- `cache_creation_multiplier`: Premium for cache creation tokens (default 25%)
### CharacterPricingModel
Used for TTS services that charge based on character count:
- `character_price`: Cost per character
### TimePricingModel
Used for STT services that charge based on time:
- `second_price`: Cost per second
## Adding New Pricing
### Adding a New LLM Model
Edit `llm.py` and add the model to the appropriate provider:
```python
ServiceProviders.OPENAI: {
"new-model": TokenPricingModel(
prompt_token_price=Decimal("2.00") / 1000000,
completion_token_price=Decimal("8.00") / 1000000,
),
# ... existing models
}
```
### Adding a New Provider
1. Add pricing configurations to the appropriate service file (llm.py, tts.py, stt.py)
2. The registry will automatically include them
### Adding a New Service Type
1. Create a new pricing file (e.g., `image.py`)
2. Define the pricing models
3. Import and add to `registry.py`
## Usage
The pricing registry is automatically imported and used by the cost calculator:
```python
from api.services.pricing import PRICING_REGISTRY
from api.services.workflow.cost_calculator import cost_calculator
# The cost calculator uses the pricing registry automatically
result = cost_calculator.calculate_total_cost(usage_info)
```
## Maintenance
- Update pricing when providers change their rates
- All prices should use `Decimal` for precision
- Include comments with current pricing from provider documentation
- Test changes with existing test suite

View file

@ -1,9 +0,0 @@
"""
Pricing module for workflow cost calculation.
This module contains pricing models and registries for different AI services.
"""
from .registry import PRICING_REGISTRY
__all__ = ["PRICING_REGISTRY"]

View file

@ -1,228 +0,0 @@
"""
Cost Calculator for Workflow Runs
This module provides a comprehensive cost calculation system for workflow runs based on usage metrics
from different AI service providers (OpenAI, Groq, Deepgram, etc.).
Features:
- Token-based pricing for LLM services with cache optimization support
- Character-based pricing for TTS services
- Time-based pricing for STT services
- Configurable pricing models that can be updated
- Support for multiple providers and models
- Automatic provider inference from model names
- JSON serialization support for database storage
Usage:
from api.tasks.cost_calculator import cost_calculator
usage_info = {
"llm": {
"processor_name|||gpt-4o": {
"prompt_tokens": 1000,
"completion_tokens": 500,
"total_tokens": 1500,
"cache_read_input_tokens": 0,
"cache_creation_input_tokens": 0
}
},
"tts": {
"processor_name|||aura-2-helena-en": 2000 # character count
}
}
cost_breakdown = cost_calculator.calculate_total_cost(usage_info)
print(f"Total cost: ${cost_breakdown['total']:.6f}")
"""
from decimal import Decimal
from typing import Any, Dict, Optional, Tuple
from api.services.configuration.registry import ServiceProviders
from api.services.pricing import PRICING_REGISTRY
from api.services.pricing.models import (
PricingModel,
)
class CostCalculator:
    """Compute per-service and total cost of a workflow run from usage metrics.

    Pricing is looked up in a nested registry shaped as
    ``{service_type: {provider: {model: pricing_model}}}``. The key
    ``"default"`` may be used at the provider and/or model level as a fallback.
    """

    def __init__(self, pricing_registry: Optional[Dict] = None):
        """Create a calculator.

        Args:
            pricing_registry: Registry to use. Only when this is ``None`` is
                the shared ``PRICING_REGISTRY`` used; an explicitly supplied
                registry (even an empty one) is kept, so ``update_pricing``
                never mutates the shared registry by accident.
        """
        self.pricing_registry = (
            PRICING_REGISTRY if pricing_registry is None else pricing_registry
        )

    def get_pricing_model(
        self, service_type: str, provider: str, model: str
    ) -> Optional[PricingModel]:
        """Get pricing model for a specific service, provider, and model.

        Lookup order: the provider's entry for ``model``, the provider's
        ``"default"`` entry, then the same two lookups under the ``"default"``
        provider. Returns ``None`` when nothing matches or the registry is
        malformed.
        """
        try:
            service_pricing = self.pricing_registry.get(service_type, {})

            provider_pricing = service_pricing.get(provider, {})
            pricing_model = provider_pricing.get(model) or provider_pricing.get(
                "default"
            )
            if pricing_model:
                return pricing_model

            fallback_pricing = service_pricing.get("default", {})
            return fallback_pricing.get(model) or fallback_pricing.get("default")
        except (KeyError, AttributeError):
            # A registry level that is not a mapping is treated as "no pricing".
            return None

    def _cost_for(
        self, service_type: str, provider: str, model: str, quantity: Any
    ) -> Decimal:
        """Price ``quantity`` with the matching pricing model, or 0 if none exists."""
        pricing_model = self.get_pricing_model(service_type, provider, model)
        if not pricing_model:
            return Decimal("0")
        return pricing_model.calculate_cost(quantity)

    def calculate_llm_cost(
        self, provider: str, model: str, usage: Dict[str, int]
    ) -> Decimal:
        """Calculate cost for LLM usage (``usage`` is the token-count dict)."""
        return self._cost_for("llm", provider, model, usage)

    def calculate_tts_cost(
        self, provider: str, model: str, character_count: int
    ) -> Decimal:
        """Calculate cost for TTS usage."""
        return self._cost_for("tts", provider, model, character_count)

    def calculate_stt_cost(self, provider: str, model: str, seconds: float) -> Decimal:
        """Calculate cost for STT usage."""
        return self._cost_for("stt", provider, model, seconds)

    def calculate_total_cost(self, usage_info: Dict) -> Dict[str, Any]:
        """Sum LLM, TTS and STT costs for one run.

        Args:
            usage_info: Mapping with optional ``"llm"``, ``"tts"`` and ``"stt"``
                sections, each keyed by ``"processor|||model"``. A missing or
                ``None`` section counts as no usage.

        Returns:
            Dict with float values under ``llm_cost``, ``tts_cost``,
            ``stt_cost`` and ``total``.
        """
        llm_cost_total = Decimal("0")
        tts_cost_total = Decimal("0")
        stt_cost_total = Decimal("0")

        for key, usage in (usage_info.get("llm") or {}).items():
            _, model = self._parse_key(key)
            provider = self._infer_provider_from_model(model, "llm")
            llm_cost_total += self.calculate_llm_cost(provider, model, usage)

        for key, character_count in (usage_info.get("tts") or {}).items():
            processor, model = self._parse_key(key)
            if model.lower() in ("none", "null", ""):
                # No usable model name: infer the provider from the processor
                # and price with that provider's default model.
                provider = self._infer_provider_from_processor(processor, "tts")
                model = "default"
            else:
                provider = self._infer_provider_from_model(model, "tts")
            tts_cost_total += self.calculate_tts_cost(
                provider, model, character_count
            )

        for key, seconds in (usage_info.get("stt") or {}).items():
            _, model = self._parse_key(key)
            provider = self._infer_provider_from_model(model, "stt")
            stt_cost_total += self.calculate_stt_cost(provider, model, seconds)

        total_cost = llm_cost_total + tts_cost_total + stt_cost_total
        return {
            "llm_cost": float(llm_cost_total),
            "tts_cost": float(tts_cost_total),
            "stt_cost": float(stt_cost_total),
            "total": float(total_cost),
        }

    def _parse_key(self, key) -> Tuple[str, str]:
        """Parse key which is in format 'processor|||model'.

        Keys without the separator (or non-string keys) yield
        ``(str(key), "unknown")``.
        """
        if isinstance(key, str) and "|||" in key:
            processor, model = key.split("|||", 1)
            return processor, model
        return str(key), "unknown"

    @staticmethod
    def _contains_any(text: str, keywords: Tuple[str, ...]) -> bool:
        """Return True when any of ``keywords`` occurs as a substring of ``text``."""
        return any(keyword in text for keyword in keywords)

    def _first_provider(self, service_type: str) -> str:
        """Return the first provider registered for ``service_type``, else "unknown"."""
        service_providers = self.pricing_registry.get(service_type, {})
        if service_providers:
            return next(iter(service_providers))
        return "unknown"

    def _infer_provider_from_model(self, model: str, service_type: str) -> str:
        """Infer provider from model name via case-insensitive keyword matching."""
        if not model:
            return "unknown"

        model_lower = model.lower()
        if self._contains_any(model_lower, ("gpt", "whisper", "openai")):
            return ServiceProviders.OPENAI
        if self._contains_any(model_lower, ("groq",)):
            return ServiceProviders.GROQ
        if self._contains_any(model_lower, ("eleven",)):
            return ServiceProviders.ELEVENLABS
        if self._contains_any(
            model_lower, ("deepgram", "nova", "phonecall", "general")
        ):
            return ServiceProviders.DEEPGRAM

        # No keyword matched: default to the first provider for this service.
        return self._first_provider(service_type)

    def _infer_provider_from_processor(self, processor: str, service_type: str) -> str:
        """Infer provider from processor name via case-insensitive keyword matching."""
        if not processor:
            return "unknown"

        processor_lower = processor.lower()
        if self._contains_any(processor_lower, ("openai", "gpt")):
            return ServiceProviders.OPENAI
        if self._contains_any(processor_lower, ("groq",)):
            return ServiceProviders.GROQ
        if self._contains_any(processor_lower, ("deepgram",)):
            return ServiceProviders.DEEPGRAM

        # No keyword matched: default to the first provider for this service.
        return self._first_provider(service_type)

    def update_pricing(
        self, service_type: str, provider: str, model: str, pricing_model: PricingModel
    ):
        """Update pricing for a specific service/provider/model combination.

        Missing ``service_type`` / ``provider`` levels are created on demand.
        """
        provider_models = self.pricing_registry.setdefault(
            service_type, {}
        ).setdefault(provider, {})
        provider_models[model] = pricing_model
# Global cost calculator instance
cost_calculator = CostCalculator()

View file

@ -1,44 +0,0 @@
"""
Embeddings pricing models for different providers.
Prices are per token for embedding models.
"""
from decimal import Decimal
from typing import Dict
from api.services.configuration.registry import ServiceProviders
from .models import PricingModel
class EmbeddingPricingModel(PricingModel):
    """Per-token pricing for embedding models."""

    def __init__(self, token_price: Decimal):
        """Store the rate charged for one embedded token.

        Args:
            token_price: Cost per token for embedding
        """
        self.token_price = token_price

    def calculate_cost(self, token_count: int) -> Decimal:
        """Return ``token_count`` multiplied by the per-token price."""
        tokens = Decimal(token_count)
        return tokens * self.token_price
# Embeddings pricing registry
EMBEDDINGS_PRICING: Dict[str, Dict[str, EmbeddingPricingModel]] = {
ServiceProviders.OPENAI: {
"text-embedding-3-small": EmbeddingPricingModel(
token_price=Decimal("0.02") / 1_000_000, # $0.02 per 1M tokens
),
"text-embedding-3-large": EmbeddingPricingModel(
token_price=Decimal("0.13") / 1_000_000, # $0.13 per 1M tokens
),
"text-embedding-ada-002": EmbeddingPricingModel(
token_price=Decimal("0.10") / 1_000_000, # $0.10 per 1M tokens (legacy)
),
},
}

View file

@ -1,143 +0,0 @@
"""
LLM pricing models for different providers.
Prices are stored per token: each entry divides the provider's quoted rate by its unit (per 1K tokens for older models, per 1M tokens for newer ones).
"""
from decimal import Decimal
from typing import Dict
from api.services.configuration.registry import ServiceProviders
from .models import TokenPricingModel
# LLM pricing registry: provider -> model name -> TokenPricingModel.
# Every quoted price is divided by its quoting unit (1000 or 1000000 tokens)
# so the stored rates are always the price of a single token.
LLM_PRICING: Dict[str, Dict[str, TokenPricingModel]] = {
    ServiceProviders.OPENAI: {
        "gpt-3.5-turbo": TokenPricingModel(
            prompt_token_price=Decimal("0.0015") / 1000,  # $0.0015 per 1K tokens
            completion_token_price=Decimal("0.002") / 1000,  # $0.002 per 1K tokens
        ),
        "gpt-4": TokenPricingModel(
            prompt_token_price=Decimal("0.03") / 1000,  # $0.03 per 1K tokens
            completion_token_price=Decimal("0.06") / 1000,  # $0.06 per 1K tokens
        ),
        "gpt-4.1": TokenPricingModel(
            prompt_token_price=Decimal("2.00") / 1000000,  # $2.00 per 1M tokens
            completion_token_price=Decimal("8.00") / 1000000,  # $8.00 per 1M tokens
        ),
        "gpt-4.1-mini": TokenPricingModel(
            prompt_token_price=Decimal("0.40") / 1000000,  # $0.40 per 1M tokens
            completion_token_price=Decimal("1.60") / 1000000,  # $1.60 per 1M tokens
        ),
        "gpt-4.1-nano": TokenPricingModel(
            prompt_token_price=Decimal("0.10") / 1000000,  # $0.10 per 1M tokens
            completion_token_price=Decimal("0.40") / 1000000,  # $0.40 per 1M tokens
        ),
        "gpt-4.5-preview": TokenPricingModel(
            prompt_token_price=Decimal("75.00") / 1000000,  # $75.00 per 1M tokens
            completion_token_price=Decimal("150.00") / 1000000,  # $150.00 per 1M tokens
        ),
        "gpt-4o": TokenPricingModel(
            prompt_token_price=Decimal("2.50") / 1000000,  # $2.50 per 1M tokens
            completion_token_price=Decimal("10.00")
            / 1000000,  # $10.00 per 1M tokens
        ),
        "gpt-4o-audio-preview": TokenPricingModel(
            prompt_token_price=Decimal("2.50") / 1000000,  # $2.50 per 1M tokens
            completion_token_price=Decimal("10.00") / 1000000,  # $10.00 per 1M tokens
        ),
        "gpt-4o-realtime-preview": TokenPricingModel(
            prompt_token_price=Decimal("5.00") / 1000000,  # $5.00 per 1M tokens
            completion_token_price=Decimal("20.00") / 1000000,  # $20.00 per 1M tokens
        ),
        "gpt-4o-mini": TokenPricingModel(
            prompt_token_price=Decimal("0.15") / 1000000,  # $0.15 per 1M tokens
            completion_token_price=Decimal("0.60") / 1000000,  # $0.60 per 1M tokens
        ),
        "gpt-4o-mini-audio-preview": TokenPricingModel(
            prompt_token_price=Decimal("0.15") / 1000000,  # $0.15 per 1M tokens
            completion_token_price=Decimal("0.60") / 1000000,  # $0.60 per 1M tokens
        ),
        "gpt-4o-mini-realtime-preview": TokenPricingModel(
            prompt_token_price=Decimal("0.60") / 1000000,  # $0.60 per 1M tokens
            completion_token_price=Decimal("2.40") / 1000000,  # $2.40 per 1M tokens
        ),
        "gpt-4o-search-preview": TokenPricingModel(
            prompt_token_price=Decimal("2.50") / 1000000,  # $2.50 per 1M tokens
            completion_token_price=Decimal("10.00") / 1000000,  # $10.00 per 1M tokens
        ),
        "gpt-4o-mini-search-preview": TokenPricingModel(
            prompt_token_price=Decimal("0.15") / 1000000,  # $0.15 per 1M tokens
            completion_token_price=Decimal("0.60") / 1000000,  # $0.60 per 1M tokens
        ),
        "o1": TokenPricingModel(
            prompt_token_price=Decimal("15.00") / 1000000,  # $15.00 per 1M tokens
            completion_token_price=Decimal("60.00") / 1000000,  # $60.00 per 1M tokens
        ),
        "o1-pro": TokenPricingModel(
            prompt_token_price=Decimal("150.00") / 1000000,  # $150.00 per 1M tokens
            completion_token_price=Decimal("600.00") / 1000000,  # $600.00 per 1M tokens
        ),
        "o1-mini": TokenPricingModel(
            prompt_token_price=Decimal("1.10") / 1000000,  # $1.10 per 1M tokens
            completion_token_price=Decimal("4.40") / 1000000,  # $4.40 per 1M tokens
        ),
        "o3": TokenPricingModel(
            prompt_token_price=Decimal("10.00") / 1000000,  # $10.00 per 1M tokens
            completion_token_price=Decimal("40.00") / 1000000,  # $40.00 per 1M tokens
        ),
        "o3-mini": TokenPricingModel(
            prompt_token_price=Decimal("1.10") / 1000000,  # $1.10 per 1M tokens
            completion_token_price=Decimal("4.40") / 1000000,  # $4.40 per 1M tokens
        ),
        "o4-mini": TokenPricingModel(
            prompt_token_price=Decimal("1.10") / 1000000,  # $1.10 per 1M tokens
            completion_token_price=Decimal("4.40") / 1000000,  # $4.40 per 1M tokens
        ),
        "computer-use-preview": TokenPricingModel(
            prompt_token_price=Decimal("3.00") / 1000000,  # $3.00 per 1M tokens
            completion_token_price=Decimal("12.00") / 1000000,  # $12.00 per 1M tokens
        ),
        "gpt-image-1": TokenPricingModel(
            prompt_token_price=Decimal("5.00") / 1000000,  # $5.00 per 1M tokens
            completion_token_price=Decimal("0") / 1000000,  # No output pricing shown
        ),
        "codex-mini-latest": TokenPricingModel(
            prompt_token_price=Decimal("1.50") / 1000000,  # $1.50 per 1M tokens
            completion_token_price=Decimal("6.00") / 1000000,  # $6.00 per 1M tokens
        ),
        # Transcription models
        "gpt-4o-transcribe": TokenPricingModel(
            prompt_token_price=Decimal("2.50") / 1000000,  # $2.50 per 1M tokens
            completion_token_price=Decimal("10.00") / 1000000,  # $10.00 per 1M tokens
        ),
        "gpt-4o-mini-transcribe": TokenPricingModel(
            prompt_token_price=Decimal("1.25") / 1000000,  # $1.25 per 1M tokens
            completion_token_price=Decimal("5.00") / 1000000,  # $5.00 per 1M tokens
        ),
        # TTS models with token-based pricing
        "gpt-4o-mini-tts": TokenPricingModel(
            prompt_token_price=Decimal("0.60") / 1000000,  # $0.60 per 1M tokens
            completion_token_price=Decimal("0")
            / 1000000,  # No completion tokens for TTS
        ),
    },
    ServiceProviders.GROQ: {
        "llama-3.3-70b-versatile": TokenPricingModel(
            prompt_token_price=Decimal("0.00059") / 1000,  # $0.00059 per 1K tokens
            completion_token_price=Decimal("0.00079") / 1000,  # $0.00079 per 1K tokens
        ),
        "deepseek-r1-distill-llama-70b": TokenPricingModel(
            prompt_token_price=Decimal("0.00059") / 1000,  # Assuming similar pricing
            completion_token_price=Decimal("0.00079") / 1000,
        ),
    },
    ServiceProviders.AZURE: {
        # NOTE(review): the coded values below do not match the prices the
        # earlier comments quoted ($0.40 prompt / $1.60 completion per 1M
        # tokens "if using data zone"). The prompt rate looks like a 10% uplift
        # (0.44); the same uplift on completion would be 1.76, not 8.80.
        # Confirm the intended completion price.
        "gpt-4.1-mini": TokenPricingModel(
            prompt_token_price=Decimal("0.44") / 1000000,  # $0.44 per 1M tokens as coded
            completion_token_price=Decimal("8.80")
            / 1000000,  # $8.80 per 1M tokens as coded - see NOTE(review) above
        )
    },
}

View file

@ -1,89 +0,0 @@
"""
Base pricing models for different service types.
"""
from decimal import Decimal
from enum import Enum
from typing import Any, Dict
class CostType(Enum):
    """Billable usage units; each member's value is its string identifier."""

    LLM_TOKENS = "llm_tokens"
    TTS_CHARACTERS = "tts_characters"
    STT_SECONDS = "stt_seconds"
class PricingModel:
    """Base class for pricing models"""

    def calculate_cost(self, usage: Any) -> Decimal:
        """Calculate cost based on usage.

        Subclasses are expected to override this; the base implementation
        always raises ``NotImplementedError``.
        """
        raise NotImplementedError
class TokenPricingModel(PricingModel):
    """Pricing model for token-based services (LLM).

    Prompt and completion tokens are billed at separate per-token rates.
    Cached prompt tokens adjust the prompt cost: cache reads are discounted
    and cache creation carries a premium.
    """

    def __init__(
        self,
        prompt_token_price: Decimal,
        completion_token_price: Decimal,
        cache_read_discount: Decimal = Decimal("0.5"),  # 50% discount for cache reads
        cache_creation_multiplier: Decimal = Decimal(
            "1.25"
        ),  # 25% premium for cache creation
    ):
        """Store the per-token rates and the cache adjustment factors.

        Args:
            prompt_token_price: Price of one prompt (input) token.
            completion_token_price: Price of one completion (output) token.
            cache_read_discount: Fraction of the prompt price refunded for
                each cache-read token.
            cache_creation_multiplier: Multiplier applied to the prompt price
                for cache-creation tokens; only the part above 1 is added.
        """
        self.prompt_token_price = prompt_token_price
        self.completion_token_price = completion_token_price
        self.cache_read_discount = cache_read_discount
        self.cache_creation_multiplier = cache_creation_multiplier

    def calculate_cost(self, usage: Dict[str, int]) -> Decimal:
        """Calculate cost for LLM token usage.

        Args:
            usage: Token counters. Reads ``prompt_tokens``,
                ``completion_tokens``, ``cache_read_input_tokens`` and
                ``cache_creation_input_tokens``; a missing or ``None`` counter
                counts as zero.

        Returns:
            The total cost, never below zero.
        """
        # `or 0` makes a counter that is present but None behave like a missing
        # one; Decimal(None) would otherwise raise TypeError.
        prompt_tokens = usage.get("prompt_tokens") or 0
        completion_tokens = usage.get("completion_tokens") or 0
        cache_read_tokens = usage.get("cache_read_input_tokens") or 0
        cache_creation_tokens = usage.get("cache_creation_input_tokens") or 0

        # Base cost
        prompt_cost = Decimal(prompt_tokens) * self.prompt_token_price
        completion_cost = Decimal(completion_tokens) * self.completion_token_price

        # Cache adjustments
        cache_read_savings = (
            Decimal(cache_read_tokens)
            * self.prompt_token_price
            * self.cache_read_discount
        )
        cache_creation_premium = (
            Decimal(cache_creation_tokens)
            * self.prompt_token_price
            * (self.cache_creation_multiplier - 1)
        )

        total_cost = (
            prompt_cost + completion_cost - cache_read_savings + cache_creation_premium
        )
        return max(total_cost, Decimal("0"))  # Ensure non-negative
class CharacterPricingModel(PricingModel):
    """Character-priced service (TTS): cost is characters times a flat rate."""

    def __init__(self, character_price: Decimal):
        """Store the price of a single character."""
        self.character_price = character_price

    def calculate_cost(self, character_count: int) -> Decimal:
        """Return the cost of ``character_count`` characters."""
        characters = Decimal(character_count)
        return characters * self.character_price
class TimePricingModel(PricingModel):
    """Time-priced service (STT): cost is seconds times a flat rate."""

    def __init__(self, second_price: Decimal):
        """Store the price of one second."""
        self.second_price = second_price

    def calculate_cost(self, seconds: float) -> Decimal:
        """Return the cost of ``seconds`` seconds of usage."""
        # Convert through str() so the Decimal is built from the float's
        # printed form rather than passing the float to Decimal directly.
        elapsed = Decimal(str(seconds))
        return elapsed * self.second_price

View file

@ -1,18 +0,0 @@
"""
Main pricing registry that combines all service type pricing models.
"""
from typing import Dict
from .embeddings import EMBEDDINGS_PRICING
from .llm import LLM_PRICING
from .stt import STT_PRICING
from .tts import TTS_PRICING
# Combined pricing registry for all service types.
# Keys are service-type names; each value is that service's own pricing table
# imported from the sibling module of the same name.
PRICING_REGISTRY: Dict = {
    "llm": LLM_PRICING,
    "tts": TTS_PRICING,
    "stt": STT_PRICING,
    "embeddings": EMBEDDINGS_PRICING,
}

View file

@ -1,13 +0,0 @@
"""Format workflow run usage for public API responses."""
def format_public_usage_info(usage_info: dict | None) -> dict | None:
if not usage_info:
return None
return {
"llm": usage_info.get("llm") or {},
"tts": usage_info.get("tts") or {},
"stt": usage_info.get("stt") or {},
"call_duration_seconds": usage_info.get("call_duration_seconds"),
}

View file

@ -1,26 +0,0 @@
"""
STT (Speech-to-Text) pricing models for different providers.
Prices are per second for STT services.
"""
from decimal import Decimal
from typing import Dict
from api.services.configuration.registry import ServiceProviders
from .models import TimePricingModel
# STT pricing registry: provider -> model name -> TimePricingModel.
# Each numerator is the price for 60 seconds of audio; dividing by 60 stores
# the per-second rate. "default" entries exist at both the model level and
# the provider level.
STT_PRICING: Dict[str, Dict[str, TimePricingModel]] = {
    ServiceProviders.DEEPGRAM: {
        "nova-3-general": TimePricingModel(Decimal("0.0077") / 60),
        "nova-2": TimePricingModel(Decimal("0.0058") / 60),
        "default": TimePricingModel(Decimal("0.0077") / 60),
    },
    ServiceProviders.OPENAI: {
        "gpt-4o-transcribe": TimePricingModel(Decimal("0.015") / 60),
        "default": TimePricingModel(Decimal("0.015") / 60),
    },
    "default": {"default": TimePricingModel(Decimal("0.0077") / 60)},
}

View file

@ -1,30 +0,0 @@
"""
TTS (Text-to-Speech) pricing models for different providers.
Prices are per character for TTS services.
"""
from decimal import Decimal
from typing import Dict
from api.services.configuration.registry import ServiceProviders
from .models import CharacterPricingModel
# TTS pricing registry: provider -> model name -> CharacterPricingModel.
# Stored rates are the price of a single character. "default" entries exist
# at both the model level and the provider level.
TTS_PRICING: Dict[str, Dict[str, CharacterPricingModel]] = {
    ServiceProviders.OPENAI: {
        # NOTE(review): 1_00_00_000 is 10,000,000 (not 1,000,000), so this is
        # $0.6 per 10M characters - confirm the divisor is intended.
        "gpt-4o-mini-tts": CharacterPricingModel(Decimal("0.6") / 1_00_00_000),
        "default": CharacterPricingModel(Decimal("0.6") / 1_00_00_000),
    },
    ServiceProviders.DEEPGRAM: {
        "aura-2": CharacterPricingModel(Decimal("0.030") / 1_000),
        "aura-1": CharacterPricingModel(Decimal("0.015") / 1_000),
        "default": CharacterPricingModel(Decimal("0.030") / 1_000),
    },
    ServiceProviders.ELEVENLABS: {
        # 6400 usd per 250*1e6 characters, i.e. $0.0256 per 1,000 characters
        "default": CharacterPricingModel(Decimal("0.0256") / 1_000)
    },
    "default": {"default": CharacterPricingModel(Decimal("0.030") / 1_000)},
}

View file

@ -1,230 +0,0 @@
from decimal import Decimal
from loguru import logger
from api.db import db_client
from api.enums import WorkflowRunMode
from api.services.pricing.cost_calculator import cost_calculator
from api.services.telephony.factory import get_telephony_provider_for_run
async def _fetch_telephony_cost(workflow_run) -> dict | None:
    """Look up what the telephony provider charged for this run's call.

    Returns ``{"cost_usd": ..., "provider_name": ...}``, or ``None`` when the
    run is not a Twilio/Vonage run, has no ``cost_info``, has no ``call_id``
    in it, or the provider reports an error status.

    Raises:
        Exception: If the run's workflow cannot be loaded.
    """
    telephony_modes = (WorkflowRunMode.TWILIO.value, WorkflowRunMode.VONAGE.value)
    if workflow_run.mode not in telephony_modes or not workflow_run.cost_info:
        return None

    call_id = workflow_run.cost_info.get("call_id")
    if not call_id:
        logger.warning("call_id not found in cost_info")
        return None

    # The run mode doubles as the provider name used in logs and result keys.
    provider_name = workflow_run.mode.lower() if workflow_run.mode else ""

    workflow = await db_client.get_workflow_by_id(workflow_run.workflow_id)
    if not workflow:
        logger.warning("Workflow not found for workflow run")
        raise Exception("Workflow not found")

    provider = await get_telephony_provider_for_run(
        workflow_run, workflow.organization_id
    )
    call_cost_info = await provider.get_call_cost(call_id)

    if call_cost_info.get("status") == "error":
        error_detail = call_cost_info.get("error")
        logger.error(f"Failed to fetch {provider_name} call cost: {error_detail}")
        return None

    cost_usd = call_cost_info.get("cost_usd", 0.0)
    provider_title = provider_name.title()
    logger.info(f"{provider_title} call cost: ${cost_usd:.6f} USD for call {call_id}")
    return {"cost_usd": cost_usd, "provider_name": provider_name}
async def _update_organization_usage(
    org, dograh_tokens: float, duration_seconds: float, charge_usd: float | None
) -> None:
    """Persist one run's usage against its organization and log the update.

    Args:
        org: Organization record; only ``org.id`` is read.
        dograh_tokens: Dograh Tokens consumed by the run.
        duration_seconds: Call duration to add.
        charge_usd: USD charge to record, or ``None`` when none was computed.
    """
    org_id = org.id
    await db_client.update_usage_after_run(
        org_id, dograh_tokens, duration_seconds, charge_usd
    )

    # Only the leading part of the log line depends on whether a USD charge
    # was computed.
    if charge_usd is None:
        usage_summary = f"{dograh_tokens} Dograh Tokens"
    else:
        usage_summary = f"${charge_usd:.2f} USD ({dograh_tokens} Dograh Tokens)"
    logger.info(
        f"Updated organization usage with {usage_summary} and {duration_seconds}s duration for org {org_id}"
    )
async def _get_pricing_organization(workflow_run):
    """Load the organization whose pricing applies to ``workflow_run``.

    Uses the workflow's ``organization_id`` first and falls back to the
    workflow user's ``selected_organization_id``. Returns ``None`` when neither
    yields an id.
    """
    workflow = getattr(workflow_run, "workflow", None)
    org_id = getattr(workflow, "organization_id", None)

    if org_id is None:
        if workflow and workflow.user:
            org_id = workflow.user.selected_organization_id
        if org_id is None:
            return None

    return await db_client.get_organization_by_id(org_id)
async def _build_usage_cost_snapshot(
    usage_info: dict | None,
    *,
    workflow_run=None,
    include_telephony_cost: bool = False,
    organization=None,
    calculated_at: str | None = None,
) -> dict | None:
    """Build a ``cost_info`` snapshot from a usage snapshot.

    Args:
        usage_info: Usage counters passed to the cost calculator; also the
            source of ``call_duration_seconds``.
        workflow_run: Run the usage belongs to. Needed for the telephony
            lookup, for resolving the organization when none is given, and as
            the fallback source of ``calculated_at``.
        include_telephony_cost: When true (and ``workflow_run`` is given), add
            the telephony provider's call cost to the breakdown and the total.
        organization: Organization whose ``price_per_second_usd`` drives the
            charge; resolved from ``workflow_run`` when omitted.
        calculated_at: Timestamp to record; defaults to the run's
            ``created_at`` in ISO format.

    Returns:
        The snapshot dict, or ``None`` when ``usage_info`` is empty.
    """
    if not usage_info:
        logger.warning("No usage info available for workflow run")
        return None

    cost_breakdown = cost_calculator.calculate_total_cost(usage_info)

    if include_telephony_cost and workflow_run is not None:
        try:
            telephony_cost = await _fetch_telephony_cost(workflow_run)
            if telephony_cost:
                telephony_cost_usd = telephony_cost["cost_usd"]
                provider_name = telephony_cost["provider_name"]
                cost_breakdown["telephony_call"] = telephony_cost_usd
                cost_breakdown[f"{provider_name}_call"] = telephony_cost_usd
                cost_breakdown["total"] = (
                    float(cost_breakdown["total"]) + telephony_cost_usd
                )
        except Exception as e:
            logger.error(f"Failed to fetch telephony call cost: {e}")
            # Don't fail the whole cost calculation if telephony API fails

    total_cost_usd = Decimal(str(cost_breakdown["total"]))
    # One Dograh Token corresponds to one hundredth of a USD.
    dograh_tokens = float(total_cost_usd * Decimal("100"))

    if organization is None and workflow_run is not None:
        organization = await _get_pricing_organization(workflow_run)

    charge_usd = None
    if organization and organization.price_per_second_usd:
        # `or 0` covers a duration stored as None; Decimal(str(None)) would
        # raise decimal.InvalidOperation.
        duration_seconds = usage_info.get("call_duration_seconds") or 0
        charge_usd = float(
            Decimal(str(duration_seconds))
            * Decimal(str(organization.price_per_second_usd))
        )

    cost_info = {
        "cost_breakdown": cost_breakdown,
        "total_cost_usd": float(total_cost_usd),
        "dograh_token_usage": dograh_tokens,
        "calculated_at": calculated_at
        or (workflow_run.created_at.isoformat() if workflow_run is not None else None),
        # Stored as found in usage_info; readers apply their own `or 0`.
        "call_duration_seconds": usage_info.get("call_duration_seconds", 0),
    }
    if charge_usd is not None:
        cost_info["charge_usd"] = charge_usd
        cost_info["price_per_second_usd"] = organization.price_per_second_usd
    return cost_info
async def build_workflow_run_cost_info(workflow_run) -> dict | None:
    """Compute a run's full cost snapshot, telephony cost included.

    The fresh snapshot is layered over the run's existing ``cost_info`` so
    keys it does not produce are kept. Returns ``None`` when no snapshot
    could be built.
    """
    snapshot = await _build_usage_cost_snapshot(
        workflow_run.usage_info,
        workflow_run=workflow_run,
        include_telephony_cost=True,
        calculated_at=workflow_run.created_at.isoformat(),
    )
    if snapshot is None:
        return None

    merged = dict(workflow_run.cost_info or {})
    merged.update(snapshot)
    return merged
async def save_workflow_run_cost_info(
    workflow_run_id: int, cost_info: dict | None
) -> None:
    """Persist ``cost_info`` on the workflow run; does nothing when it is ``None``."""
    if cost_info is not None:
        await db_client.update_workflow_run(
            run_id=workflow_run_id, cost_info=cost_info
        )
async def apply_workflow_run_usage_to_organization(
    workflow_run, cost_info: dict | None
) -> None:
    """Add a run's already-computed cost snapshot to its organization's usage.

    Does nothing when ``cost_info`` is ``None`` or no organization can be
    resolved for the run.
    """
    if cost_info is None:
        return

    org = await _get_pricing_organization(workflow_run)
    if not org:
        return

    # Missing or None counters count as zero.
    tokens = float(cost_info.get("dograh_token_usage") or 0)
    duration = float(cost_info.get("call_duration_seconds") or 0)
    await _update_organization_usage(org, tokens, duration, cost_info.get("charge_usd"))
async def apply_usage_delta_to_organization(
    workflow_run, usage_info: dict | None
) -> dict | None:
    """Price a usage delta and add it to the run's organization.

    Returns the cost snapshot built for the delta, or ``None`` when the
    organization cannot be resolved or no snapshot could be built.
    """
    org = await _get_pricing_organization(workflow_run)
    if not org:
        return None

    delta_cost = await _build_usage_cost_snapshot(usage_info, organization=org)
    if delta_cost is None:
        return None

    # Missing or None counters count as zero.
    tokens = float(delta_cost.get("dograh_token_usage") or 0)
    duration = float(delta_cost.get("call_duration_seconds") or 0)
    await _update_organization_usage(
        org, tokens, duration, delta_cost.get("charge_usd")
    )
    return delta_cost
async def calculate_workflow_run_cost(workflow_run_id: int):
    """Compute, store and bill the cost of one workflow run.

    Loads the run, builds and saves its cost snapshot, then adds the usage to
    the organization. A failure in the organization update is logged and
    swallowed; any other failure is logged and re-raised.
    """
    logger.debug("Calculating cost for workflow run")
    workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
    if not workflow_run:
        logger.warning("Workflow run not found")
        return

    try:
        cost_info = await build_workflow_run_cost_info(workflow_run)
        if cost_info is None:
            return
        await save_workflow_run_cost_info(workflow_run_id, cost_info)

        try:
            await apply_workflow_run_usage_to_organization(workflow_run, cost_info)
        except Exception as e:
            # Don't fail the whole cost calculation if usage update fails
            org = await _get_pricing_organization(workflow_run)
            org_suffix = f" for org {org.id}" if org else ""
            logger.error(f"Failed to update organization usage{org_suffix}: {e}")

        total_usd = cost_info["total_cost_usd"]
        tokens = cost_info["dograh_token_usage"]
        logger.info(
            f"Calculated cost for workflow run: ${total_usd:.6f} USD ({tokens} Dograh Tokens)"
        )
    except Exception as e:
        logger.error(f"Error calculating cost for workflow run: {e}")
        raise

View file

@ -53,7 +53,7 @@ def build_run_report_csv(runs: List[Any]) -> io.StringIO:
for run in runs:
initial = run.initial_context or {}
gathered = run.gathered_context or {}
cost = run.cost_info or {}
usage = run.usage_info or {}
call_tags = gathered.get("call_tags", [])
if isinstance(call_tags, list):
@ -67,7 +67,7 @@ def build_run_report_csv(runs: List[Any]) -> io.StringIO:
run.created_at.isoformat() if run.created_at else "",
initial.get("phone_number", ""),
gathered.get("mapped_call_disposition", ""),
cost.get("call_duration_seconds", ""),
usage.get("call_duration_seconds", ""),
]
extracted = gathered.get("extracted_variables", {})

View file

@ -66,34 +66,6 @@ async def handle_vonage_events(
logger.error(f"[run {workflow_run_id}] Workflow run not found")
return {"status": "error", "message": "Workflow run not found"}
# For a completed call that includes cost info, capture it immediately
if event_data.get("status") == "completed":
# Vonage sometimes includes price info in the webhook
if "price" in event_data or "rate" in event_data:
try:
if workflow_run.cost_info:
# Store immediate cost info if available
cost_info = workflow_run.cost_info.copy()
if "price" in event_data:
cost_info["vonage_webhook_price"] = float(event_data["price"])
if "rate" in event_data:
cost_info["vonage_webhook_rate"] = float(event_data["rate"])
if "duration" in event_data:
cost_info["vonage_webhook_duration"] = int(
event_data["duration"]
)
await db_client.update_workflow_run(
run_id=workflow_run_id, cost_info=cost_info
)
logger.info(
f"[run {workflow_run_id}] Captured Vonage cost info from webhook"
)
except Exception as e:
logger.error(
f"[run {workflow_run_id}] Failed to capture Vonage cost from webhook: {e}"
)
# Get workflow and provider
workflow = await db_client.get_workflow_by_id(workflow_run.workflow_id)
if not workflow:

View file

@ -0,0 +1,41 @@
"""Format workflow run usage for public API responses."""
def format_public_usage_info(usage_info: dict | None) -> dict | None:
if not usage_info:
return None
return {
"llm": usage_info.get("llm") or {},
"tts": usage_info.get("tts") or {},
"stt": usage_info.get("stt") or {},
"call_duration_seconds": usage_info.get("call_duration_seconds"),
}
def format_public_cost_info(
cost_info: dict | None, usage_info: dict | None
) -> dict | None:
"""Return the legacy response shape without doing local cost accounting."""
duration = None
if usage_info and usage_info.get("call_duration_seconds") is not None:
duration = int(round(usage_info.get("call_duration_seconds") or 0))
elif cost_info and cost_info.get("call_duration_seconds") is not None:
duration = int(round(cost_info.get("call_duration_seconds") or 0))
dograh_token_usage = 0
if cost_info:
if "dograh_token_usage" in cost_info:
dograh_token_usage = cost_info.get("dograh_token_usage") or 0
elif "total_cost_usd" in cost_info:
dograh_token_usage = round(
float(cost_info.get("total_cost_usd", 0)) * 100, 2
)
if duration is None and dograh_token_usage == 0:
return None
return {
"dograh_token_usage": dograh_token_usage,
"call_duration_seconds": duration,
}

View file

@ -4,17 +4,11 @@ from datetime import UTC, datetime
from typing import Any
from uuid import uuid4
from loguru import logger
from api.db import db_client
from api.db.models import WorkflowRunTextSessionModel
from api.db.workflow_run_text_session_client import (
WorkflowRunTextSessionRevisionConflictError,
)
from api.services.pricing.workflow_run_cost import (
apply_usage_delta_to_organization,
build_workflow_run_cost_info,
)
from api.services.workflow.text_chat_logs import (
build_text_chat_realtime_feedback_events,
)
@ -261,20 +255,6 @@ async def execute_pending_text_chat_turn(
state=execution.state,
is_completed=execution.is_completed,
)
workflow_run = await db_client.get_workflow_run_by_id(run_id)
if workflow_run:
try:
# Apply the per-turn delta so org usage tracks cumulative run cost
# without replaying the full session totals on every turn.
await apply_usage_delta_to_organization(workflow_run, execution.usage)
except Exception as e:
logger.error(
f"Failed to update organization usage for text chat run {run_id}: {e}"
)
cost_info = await build_workflow_run_cost_info(workflow_run)
if cost_info is not None:
await db_client.update_workflow_run(run_id, cost_info=cost_info)
return await _reload_text_chat_session(run_id)

View file

@ -0,0 +1,111 @@
"""Workflow-run billing hooks.
Dograh does not rate or deduct credits locally. MPS owns credit accounting.
For hosted deployments, Dograh reports completed platform usage to MPS.
When a server-minted MPS correlation id exists, MPS uses model-service usage
as the canonical duration. Otherwise Dograh reports the completed run duration.
"""
from typing import Any
from loguru import logger
from api.constants import DEPLOYMENT_MODE
from api.db import db_client
from api.services.managed_model_services import get_mps_correlation_id
from api.services.mps_service_key_client import mps_service_key_client
def _workflow_run_organization_id(workflow_run) -> int | None:
workflow = getattr(workflow_run, "workflow", None)
return getattr(workflow, "organization_id", None)
def _duration_seconds_from_usage_info(workflow_run) -> float | None:
usage_info: dict[str, Any] = getattr(workflow_run, "usage_info", None) or {}
duration = usage_info.get("call_duration_seconds")
try:
duration_seconds = float(duration)
except (TypeError, ValueError):
return None
return duration_seconds if duration_seconds > 0 else None
async def _organization_uses_mps_billing_v2(organization_id: int) -> bool:
    """Return True when the organization's billing account reports mode "v2".

    A missing or empty account status counts as not using v2 billing.
    """
    account = await mps_service_key_client.get_billing_account_status(
        organization_id=organization_id
    )
    if not account:
        return False
    return bool(account.get("billing_mode") == "v2")
async def report_workflow_run_platform_usage(workflow_run) -> None:
    """Report hosted platform usage for a completed workflow run to MPS.

    Nothing is reported for OSS deployments, for runs that are not completed,
    for runs without an organization id, or when there is neither an MPS
    correlation id nor a positive duration in ``usage_info``. Organizations
    not on billing mode "v2" are skipped. Errors from the billing client are
    logged, never raised.
    """
    if DEPLOYMENT_MODE == "oss" or not getattr(workflow_run, "is_completed", False):
        return

    organization_id = _workflow_run_organization_id(workflow_run)
    if organization_id is None:
        logger.warning(
            "Skipping platform usage report for workflow run {}: no organization_id",
            workflow_run.id,
        )
        return

    correlation_id = get_mps_correlation_id(
        getattr(workflow_run, "initial_context", None)
    )
    if correlation_id:
        # With a correlation id no local duration is sent.
        duration_seconds = None
        duration_source = "mps_correlation"
    else:
        duration_seconds = _duration_seconds_from_usage_info(workflow_run)
        duration_source = "dograh_usage_info"
        if duration_seconds is None:
            logger.warning(
                "Skipping platform usage report for workflow run {}: no billable duration",
                workflow_run.id,
            )
            return

    try:
        uses_v2 = await _organization_uses_mps_billing_v2(organization_id)
        if not uses_v2:
            return

        report_metadata = {
            "source": "workflow_run_completion",
            "workflow_id": getattr(workflow_run, "workflow_id", None),
            "duration_source": duration_source,
        }
        result = await mps_service_key_client.report_platform_usage(
            organization_id=organization_id,
            correlation_id=correlation_id,
            duration_seconds=duration_seconds,
            workflow_run_id=workflow_run.id,
            metadata=report_metadata,
        )
        logger.info(
            "Reported platform usage for workflow run {} to MPS: {}",
            workflow_run.id,
            result,
        )
    except Exception as e:
        logger.error(
            "Failed to report platform usage for workflow run {}: {}",
            workflow_run.id,
            e,
        )
async def report_completed_workflow_run_platform_usage(workflow_run_id: int) -> None:
    """Load a workflow run by id and report its platform usage to MPS.

    Logs a warning and returns when the run does not exist.
    """
    workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
    if workflow_run:
        await report_workflow_run_platform_usage(workflow_run)
        return

    logger.warning(
        "Skipping platform usage report: workflow run {} not found",
        workflow_run_id,
    )

View file

@ -45,10 +45,8 @@ from api.tasks.campaign_tasks import (
)
from api.tasks.knowledge_base_processing import process_knowledge_base_document
from api.tasks.run_integrations import run_integrations_post_workflow_run
from api.tasks.s3_upload import (
process_workflow_completion,
upload_voicemail_audio_to_s3,
)
from api.tasks.s3_upload import upload_voicemail_audio_to_s3
from api.tasks.workflow_completion import process_workflow_completion
class WorkerSettings:

View file

@ -1,13 +1,9 @@
import os
from typing import Optional
from loguru import logger
from pipecat.utils.run_context import set_current_run_id
from api.db import db_client
from api.services.pricing.workflow_run_cost import calculate_workflow_run_cost
from api.services.storage import get_current_storage_backend, storage_fs
from api.tasks.run_integrations import run_integrations_post_workflow_run
from api.services.storage import storage_fs
async def upload_voicemail_audio_to_s3(
@ -69,110 +65,3 @@ async def upload_voicemail_audio_to_s3(
logger.warning(
f"Failed to clean up temp voicemail audio file {temp_file_path}: {e}"
)
async def _upload_run_artifact(
    workflow_run_id: int,
    temp_path: Optional[str],
    *,
    label: str,
    remote_path: str,
    url_field: str,
    storage_backend,
) -> None:
    """Upload one temp artifact for a run, record its URL, then delete the temp file.

    Args:
        workflow_run_id: Run the artifact belongs to.
        temp_path: Local temp file; nothing happens when it is falsy.
        label: Lower-case artifact name used in log messages ("audio",
            "transcript").
        remote_path: Destination path in storage, also stored on the run.
        url_field: Name of the workflow-run field that receives
            ``remote_path``.
        storage_backend: Current storage backend; its ``name`` is logged and
            its ``value`` is stored on the run.
    """
    if not temp_path:
        return

    title = label.capitalize()
    try:
        if os.path.exists(temp_path):
            file_size = os.path.getsize(temp_path)
            logger.debug(f"{title} file size: {file_size} bytes")
            logger.info(
                f"Uploading {label} to {storage_backend.name} - workflow_run_id: {workflow_run_id}"
            )
            await storage_fs.aupload_file(temp_path, remote_path)
            await db_client.update_workflow_run(
                run_id=workflow_run_id,
                storage_backend=storage_backend.value,
                **{url_field: remote_path},
            )
            logger.info(f"Successfully uploaded {label}: {remote_path}")
        else:
            logger.warning(f"{title} temp file not found: {temp_path}")
    except Exception as e:
        # Best effort: a failed upload must not stop the remaining steps.
        logger.error(f"Error uploading {label} for workflow {workflow_run_id}: {e}")
    finally:
        if os.path.exists(temp_path):
            try:
                os.remove(temp_path)
                logger.debug(f"Cleaned up temp {label} file: {temp_path}")
            except Exception as e:
                logger.warning(f"Failed to clean up temp {label} file: {e}")


async def process_workflow_completion(
    _ctx,
    workflow_run_id: int,
    audio_temp_path: Optional[str] = None,
    transcript_temp_path: Optional[str] = None,
):
    """Process workflow completion: upload artifacts, run integrations, cost.

    The steps run sequentially in one task so that integrations run after
    the uploads complete and the cost is calculated after integrations. Each
    step logs its own failure and lets the following steps run.

    Args:
        _ctx: ARQ context (only forwarded to the integrations step)
        workflow_run_id: The workflow run ID
        audio_temp_path: Optional path to temp audio file
        transcript_temp_path: Optional path to temp transcript file
    """
    run_id = str(workflow_run_id)
    set_current_run_id(run_id)
    logger.info(f"Processing workflow completion for run {workflow_run_id}")
    storage_backend = get_current_storage_backend()

    # Step 1: Upload audio if provided
    await _upload_run_artifact(
        workflow_run_id,
        audio_temp_path,
        label="audio",
        remote_path=f"recordings/{workflow_run_id}.wav",
        url_field="recording_url",
        storage_backend=storage_backend,
    )

    # Step 2: Upload transcript if provided
    await _upload_run_artifact(
        workflow_run_id,
        transcript_temp_path,
        label="transcript",
        remote_path=f"transcripts/{workflow_run_id}.txt",
        url_field="transcript_url",
        storage_backend=storage_backend,
    )

    # Step 3: Run integrations including QA analysis (after uploads are complete)
    try:
        await run_integrations_post_workflow_run(_ctx, workflow_run_id)
    except Exception as e:
        logger.error(f"Error running integrations for workflow {workflow_run_id}: {e}")

    # Step 4: Calculate cost after integrations (so QA token usage is included)
    try:
        await calculate_workflow_run_cost(workflow_run_id)
    except Exception as e:
        logger.error(f"Error calculating cost for workflow {workflow_run_id}: {e}")

    logger.info(f"Completed workflow completion processing for run {workflow_run_id}")

View file

@ -0,0 +1,121 @@
import os
from typing import Optional
from loguru import logger
from pipecat.utils.run_context import set_current_run_id
from api.db import db_client
from api.services.storage import get_current_storage_backend, storage_fs
from api.services.workflow_run_billing import (
report_completed_workflow_run_platform_usage,
)
from api.tasks.run_integrations import run_integrations_post_workflow_run
async def _upload_run_artifact(
    workflow_run_id: int,
    temp_path: Optional[str],
    *,
    label: str,
    remote_path: str,
    url_field: str,
    storage_backend,
) -> None:
    """Upload one temp artifact for a run, record its URL, then delete the temp file.

    Args:
        workflow_run_id: Run the artifact belongs to.
        temp_path: Local temp file; nothing happens when it is falsy.
        label: Lower-case artifact name used in log messages ("audio",
            "transcript").
        remote_path: Destination path in storage, also stored on the run.
        url_field: Name of the workflow-run field that receives
            ``remote_path``.
        storage_backend: Current storage backend; its ``name`` is logged and
            its ``value`` is stored on the run.
    """
    if not temp_path:
        return

    title = label.capitalize()
    try:
        if os.path.exists(temp_path):
            file_size = os.path.getsize(temp_path)
            logger.debug(f"{title} file size: {file_size} bytes")
            logger.info(
                f"Uploading {label} to {storage_backend.name} - workflow_run_id: {workflow_run_id}"
            )
            await storage_fs.aupload_file(temp_path, remote_path)
            await db_client.update_workflow_run(
                run_id=workflow_run_id,
                storage_backend=storage_backend.value,
                **{url_field: remote_path},
            )
            logger.info(f"Successfully uploaded {label}: {remote_path}")
        else:
            logger.warning(f"{title} temp file not found: {temp_path}")
    except Exception as e:
        # Best effort: a failed upload must not stop the remaining steps.
        logger.error(f"Error uploading {label} for workflow {workflow_run_id}: {e}")
    finally:
        if os.path.exists(temp_path):
            try:
                os.remove(temp_path)
                logger.debug(f"Cleaned up temp {label} file: {temp_path}")
            except Exception as e:
                logger.warning(f"Failed to clean up temp {label} file: {e}")


async def process_workflow_completion(
    _ctx,
    workflow_run_id: int,
    audio_temp_path: Optional[str] = None,
    transcript_temp_path: Optional[str] = None,
):
    """Process workflow completion: upload artifacts, run integrations, report usage.

    The steps run sequentially in one task so that integrations run after
    the uploads complete and platform usage is reported to MPS last. Each
    step logs its own failure and lets the following steps run.

    Args:
        _ctx: ARQ context (only forwarded to the integrations step)
        workflow_run_id: The workflow run ID
        audio_temp_path: Optional path to temp audio file
        transcript_temp_path: Optional path to temp transcript file
    """
    run_id = str(workflow_run_id)
    set_current_run_id(run_id)
    logger.info(f"Processing workflow completion for run {workflow_run_id}")
    storage_backend = get_current_storage_backend()

    # Step 1: Upload audio if provided
    await _upload_run_artifact(
        workflow_run_id,
        audio_temp_path,
        label="audio",
        remote_path=f"recordings/{workflow_run_id}.wav",
        url_field="recording_url",
        storage_backend=storage_backend,
    )

    # Step 2: Upload transcript if provided
    await _upload_run_artifact(
        workflow_run_id,
        transcript_temp_path,
        label="transcript",
        remote_path=f"transcripts/{workflow_run_id}.txt",
        url_field="transcript_url",
        storage_backend=storage_backend,
    )

    # Step 3: Run integrations including QA analysis (after uploads are complete)
    try:
        await run_integrations_post_workflow_run(_ctx, workflow_run_id)
    except Exception as e:
        logger.error(f"Error running integrations for workflow {workflow_run_id}: {e}")

    # Step 4: Notify MPS after completion. MPS owns credit accounting.
    try:
        await report_completed_workflow_run_platform_usage(workflow_run_id)
    except Exception as e:
        logger.error(
            f"Error reporting platform usage for workflow {workflow_run_id}: {e}"
        )

    logger.info(f"Completed workflow completion processing for run {workflow_run_id}")

View file

@ -15,6 +15,7 @@ from api.services.configuration.ai_model_configuration import (
merge_ai_model_configuration_v2_secrets,
migrate_workflow_configuration_model_override_to_v2,
)
from api.services.configuration.check_validity import UserConfigurationValidator
from api.services.configuration.masking import mask_key
from api.services.configuration.registry import (
DeepgramSTTConfiguration,
@ -22,6 +23,8 @@ from api.services.configuration.registry import (
DograhSTTService,
DograhTTSService,
ElevenlabsTTSConfiguration,
GoogleLLMService,
GoogleRealtimeLLMConfiguration,
OpenAIEmbeddingsConfiguration,
OpenAILLMService,
)
@ -93,6 +96,67 @@ def test_byok_v2_rejects_dograh_provider():
)
@pytest.mark.asyncio
async def test_byok_realtime_validator_does_not_require_stt_or_tts():
    """A BYOK realtime configuration compiles without STT/TTS and validates as ok."""
    realtime_settings = {
        "provider": "google_realtime",
        "api_key": "google-realtime-key",
        "model": "gemini-3.1-flash-live-preview",
        "voice": "Puck",
        "language": "en",
    }
    llm_settings = {
        "provider": "google",
        "api_key": "google-llm-key",
        "model": "gemini-2.0-flash",
    }
    raw_configuration = {
        "mode": "byok",
        "byok": {
            "mode": "realtime",
            "realtime": {
                "realtime": realtime_settings,
                "llm": llm_settings,
            },
        },
    }

    config = OrganizationAIModelConfigurationV2.model_validate(raw_configuration)
    effective = compile_ai_model_configuration_v2(config)

    # The compiled realtime configuration carries no STT or TTS service.
    assert effective.is_realtime is True
    assert effective.stt is None
    assert effective.tts is None

    validation_result = await UserConfigurationValidator().validate(effective)
    assert validation_result == {"status": [{"model": "all", "message": "ok"}]}
@pytest.mark.asyncio
async def test_pipeline_validator_requires_stt_and_tts_when_not_realtime():
    """With is_realtime=False, validation raises and lists STT and TTS as missing."""
    llm_service = GoogleLLMService(
        provider="google",
        api_key="google-llm-key",
        model="gemini-2.0-flash",
    )
    realtime_service = GoogleRealtimeLLMConfiguration(
        provider="google_realtime",
        api_key="google-realtime-key",
        model="gemini-3.1-flash-live-preview",
        voice="Puck",
        language="en",
    )
    effective = EffectiveAIModelConfiguration(
        llm=llm_service,
        realtime=realtime_service,
        is_realtime=False,
    )

    with pytest.raises(ValueError) as exc_info:
        await UserConfigurationValidator().validate(effective)

    # The first exception argument is the list of per-model problems.
    expected_problems = [
        {"model": "stt", "message": "API key is missing"},
        {"model": "tts", "message": "API key is missing"},
    ]
    assert exc_info.value.args[0] == expected_problems
def test_masked_dograh_key_is_preserved_when_saving_same_mode():
existing = OrganizationAIModelConfigurationV2(
mode="dograh",

View file

@ -1,31 +0,0 @@
from api.services.pricing.cost_calculator import cost_calculator
def test_cost_calculator():
    """Verify per-service and total costs for a sample usage payload.

    All float results are compared with an absolute tolerance instead of
    ``==``, so the test does not depend on the exact order of floating-point
    operations performed inside the calculator.
    """
    tolerance = 1e-10
    sample_usage = {
        "llm": {
            "OpenAILLMService#0|||gpt-4.1-mini": {
                "prompt_tokens": 45380,
                "completion_tokens": 496,
                "total_tokens": 45876,
                "cache_read_input_tokens": 0,
                "cache_creation_input_tokens": 0,
            }
        },
        "tts": {"ElevenLabsTTSService#0|||eleven_flash_v2_5": 2399},
        "stt": {"DeepgramSTTService#0|||nova-3-general": 177.21536946296692},
        "call_duration_seconds": 179,
    }

    result = cost_calculator.calculate_total_cost(sample_usage)

    expected_costs = {
        "llm_cost": 45380 * 0.40 / 1_000_000 + 496 * 1.60 / 1_000_000,
        "tts_cost": 2399 * 0.0256 / 1_000,
        "stt_cost": 177.21536946296692 / 60 * 0.0077,
    }
    for cost_key, expected in expected_costs.items():
        assert abs(result[cost_key] - expected) < tolerance, cost_key

    # The total must equal the sum of the individual service costs.
    summed = result["llm_cost"] + result["tts_cost"] + result["stt_cost"]
    assert abs(result["total"] - summed) < tolerance

View file

@ -128,3 +128,152 @@ async def test_create_correlation_id_uses_bearer_auth(monkeypatch):
},
)
]
@pytest.mark.asyncio
async def test_get_billing_account_status_uses_hosted_org_auth(monkeypatch):
    """The billing-status GET carries the secret key and organization id headers."""
    recorded_requests = []

    class RecordingAsyncClient:
        """Stand-in for httpx.AsyncClient that records GET requests."""

        def __init__(self, timeout):
            self.timeout = timeout

        async def __aenter__(self):
            return self

        async def __aexit__(self, exc_type, exc, tb):
            return None

        async def get(self, url, headers):
            recorded_requests.append(("GET", url, headers))
            return _Response(200, {"organization_id": 42, "billing_mode": "v2"})

    monkeypatch.setattr(
        "api.services.mps_service_key_client.httpx.AsyncClient",
        RecordingAsyncClient,
    )
    monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas")
    monkeypatch.setattr(
        "api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret"
    )

    client = MPSServiceKeyClient()
    status = await client.get_billing_account_status(organization_id=42)
    assert status == {
        "organization_id": 42,
        "billing_mode": "v2",
    }

    expected_headers = {
        "Content-Type": "application/json",
        "X-Secret-Key": "mps-secret",
        "X-Organization-Id": "42",
    }
    expected_request = (
        "GET",
        f"{client.base_url}/api/v1/billing/accounts/42/status",
        expected_headers,
    )
    assert recorded_requests == [expected_request]
@pytest.mark.asyncio
async def test_report_platform_usage_uses_hosted_secret_auth(monkeypatch):
    """Platform usage is POSTed with the correlation id and hosted auth headers."""
    recorded_requests = []

    class RecordingAsyncClient:
        """Stand-in for httpx.AsyncClient that records POST requests."""

        def __init__(self, timeout):
            self.timeout = timeout

        async def __aenter__(self):
            return self

        async def __aexit__(self, exc_type, exc, tb):
            return None

        async def post(self, url, json, headers):
            recorded_requests.append(("POST", url, json, headers))
            return _Response(200, {"metered": True})

    monkeypatch.setattr(
        "api.services.mps_service_key_client.httpx.AsyncClient",
        RecordingAsyncClient,
    )
    monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas")
    monkeypatch.setattr(
        "api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret"
    )

    client = MPSServiceKeyClient()
    response = await client.report_platform_usage(
        organization_id=42,
        correlation_id="mps-corr-123",
        workflow_run_id=123,
        metadata={"source": "workflow_run_completion"},
    )
    assert response == {"metered": True}

    expected_payload = {
        "correlation_id": "mps-corr-123",
        "workflow_run_id": 123,
        "metadata": {"source": "workflow_run_completion"},
    }
    expected_headers = {
        "Content-Type": "application/json",
        "X-Secret-Key": "mps-secret",
        "X-Organization-Id": "42",
    }
    expected_request = (
        "POST",
        f"{client.base_url}/api/v1/billing/accounts/42/platform-usage",
        expected_payload,
        expected_headers,
    )
    assert recorded_requests == [expected_request]
@pytest.mark.asyncio
async def test_report_platform_usage_sends_duration_without_correlation(monkeypatch):
    """Without a correlation id, the POST body carries duration_seconds instead."""
    recorded_requests = []

    class RecordingAsyncClient:
        """Stand-in for httpx.AsyncClient that records POST requests."""

        def __init__(self, timeout):
            self.timeout = timeout

        async def __aenter__(self):
            return self

        async def __aexit__(self, exc_type, exc, tb):
            return None

        async def post(self, url, json, headers):
            recorded_requests.append(("POST", url, json, headers))
            return _Response(200, {"metered": True})

    monkeypatch.setattr(
        "api.services.mps_service_key_client.httpx.AsyncClient",
        RecordingAsyncClient,
    )
    monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas")
    monkeypatch.setattr(
        "api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret"
    )

    client = MPSServiceKeyClient()
    response = await client.report_platform_usage(
        organization_id=42,
        duration_seconds=87.0,
        workflow_run_id=123,
        metadata={"source": "workflow_run_completion"},
    )
    assert response == {"metered": True}

    # The recorded body has no "correlation_id" key at all.
    expected_payload = {
        "duration_seconds": 87.0,
        "workflow_run_id": 123,
        "metadata": {"source": "workflow_run_completion"},
    }
    expected_headers = {
        "Content-Type": "application/json",
        "X-Secret-Key": "mps-secret",
        "X-Organization-Id": "42",
    }
    expected_request = (
        "POST",
        f"{client.base_url}/api/v1/billing/accounts/42/platform-usage",
        expected_payload,
        expected_headers,
    )
    assert recorded_requests == [expected_request]

View file

@ -0,0 +1,33 @@
from types import SimpleNamespace
from unittest.mock import AsyncMock
import pytest
from api.routes import organization_usage
def test_is_mps_billing_v2_depends_only_on_account_mode():
    """Only an account status whose billing_mode is "v2" is treated as billing v2."""
    cases = [
        ({"billing_mode": "v2"}, True),
        ({"billing_mode": "v1"}, False),
        ({"billing_mode": "shadow"}, False),
        (None, False),
    ]
    for account_status, expected in cases:
        assert organization_usage._is_mps_billing_v2(account_status) is expected
@pytest.mark.asyncio
async def test_get_mps_billing_account_status_uses_user_provider_id(monkeypatch):
    """The status lookup passes the user's provider_id as ``created_by``."""
    status_mock = AsyncMock(return_value={"billing_mode": "v2"})
    monkeypatch.setattr(
        organization_usage.mps_service_key_client,
        "get_billing_account_status",
        status_mock,
    )
    user = SimpleNamespace(provider_id="provider-123")

    account_status = await organization_usage._get_mps_billing_account_status(
        user, 42
    )

    assert account_status == {"billing_mode": "v2"}
    status_mock.assert_awaited_once_with(
        organization_id=42,
        created_by="provider-123",
    )

View file

@ -1,4 +1,4 @@
from api.services.pricing.run_usage_response import format_public_usage_info
from api.services.workflow.run_usage_response import format_public_usage_info
def test_format_public_usage_info():

View file

@ -0,0 +1,212 @@
from types import SimpleNamespace
from unittest.mock import AsyncMock
import pytest
from api.services import workflow_run_billing as workflow_run_billing_mod
from api.services.workflow_run_billing import (
report_completed_workflow_run_platform_usage,
report_workflow_run_platform_usage,
)
def _make_workflow_run():
return SimpleNamespace(
id=123,
workflow_id=456,
is_completed=True,
initial_context={"mps_correlation_id": "mps-corr-123"},
usage_info={"call_duration_seconds": 87},
workflow=SimpleNamespace(
organization_id=42,
user=SimpleNamespace(selected_organization_id=42),
),
)
@pytest.mark.asyncio
async def test_report_workflow_run_platform_usage_reports_hosted_completion(
    monkeypatch,
):
    """A run with an MPS correlation id is reported by correlation, with no duration."""
    run = _make_workflow_run()
    status_mock = AsyncMock(return_value={"billing_mode": "v2"})
    usage_mock = AsyncMock(return_value={"metered": True})
    mps_client = workflow_run_billing_mod.mps_service_key_client
    monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
    monkeypatch.setattr(mps_client, "get_billing_account_status", status_mock)
    monkeypatch.setattr(mps_client, "report_platform_usage", usage_mock)

    await report_workflow_run_platform_usage(run)

    expected_metadata = {
        "source": "workflow_run_completion",
        "workflow_id": run.workflow_id,
        "duration_source": "mps_correlation",
    }
    usage_mock.assert_awaited_once_with(
        organization_id=42,
        correlation_id="mps-corr-123",
        duration_seconds=None,
        workflow_run_id=run.id,
        metadata=expected_metadata,
    )
@pytest.mark.asyncio
async def test_report_workflow_run_platform_usage_reports_duration_without_correlation(
    monkeypatch,
):
    """Without a correlation id, the duration from usage_info is reported instead."""
    run = _make_workflow_run()
    run.initial_context = {}
    status_mock = AsyncMock(return_value={"billing_mode": "v2"})
    usage_mock = AsyncMock(return_value={"metered": True})
    mps_client = workflow_run_billing_mod.mps_service_key_client
    monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
    monkeypatch.setattr(mps_client, "get_billing_account_status", status_mock)
    monkeypatch.setattr(mps_client, "report_platform_usage", usage_mock)

    await report_workflow_run_platform_usage(run)

    expected_metadata = {
        "source": "workflow_run_completion",
        "workflow_id": run.workflow_id,
        "duration_source": "dograh_usage_info",
    }
    usage_mock.assert_awaited_once_with(
        organization_id=42,
        correlation_id=None,
        duration_seconds=87.0,
        workflow_run_id=run.id,
        metadata=expected_metadata,
    )
@pytest.mark.asyncio
async def test_report_workflow_run_platform_usage_skips_non_v2_account(monkeypatch):
    """An account in billing mode "v1" is looked up but no usage is reported."""
    run = _make_workflow_run()
    status_mock = AsyncMock(return_value={"billing_mode": "v1"})
    usage_mock = AsyncMock()
    mps_client = workflow_run_billing_mod.mps_service_key_client
    monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
    monkeypatch.setattr(mps_client, "get_billing_account_status", status_mock)
    monkeypatch.setattr(mps_client, "report_platform_usage", usage_mock)

    await report_workflow_run_platform_usage(run)

    status_mock.assert_awaited_once_with(organization_id=42)
    usage_mock.assert_not_awaited()
@pytest.mark.asyncio
async def test_report_workflow_run_platform_usage_skips_missing_duration_without_correlation(
    monkeypatch,
):
    """With neither correlation id nor duration, no MPS call is made at all."""
    run = _make_workflow_run()
    run.initial_context = {}
    run.usage_info = {}
    status_mock = AsyncMock(return_value={"billing_mode": "v2"})
    usage_mock = AsyncMock()
    mps_client = workflow_run_billing_mod.mps_service_key_client
    monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
    monkeypatch.setattr(mps_client, "get_billing_account_status", status_mock)
    monkeypatch.setattr(mps_client, "report_platform_usage", usage_mock)

    await report_workflow_run_platform_usage(run)

    # Neither the status lookup nor the usage report is awaited.
    status_mock.assert_not_awaited()
    usage_mock.assert_not_awaited()
@pytest.mark.asyncio
async def test_report_workflow_run_platform_usage_skips_oss(monkeypatch):
    """No platform usage is reported when DEPLOYMENT_MODE is "oss"."""
    run = _make_workflow_run()
    usage_mock = AsyncMock()
    monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "oss")
    monkeypatch.setattr(
        workflow_run_billing_mod.mps_service_key_client,
        "report_platform_usage",
        usage_mock,
    )

    await report_workflow_run_platform_usage(run)

    usage_mock.assert_not_awaited()
@pytest.mark.asyncio
async def test_report_workflow_run_platform_usage_skips_incomplete(monkeypatch):
    """No platform usage is reported for a run whose is_completed flag is False."""
    run = _make_workflow_run()
    run.is_completed = False
    usage_mock = AsyncMock()
    monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
    monkeypatch.setattr(
        workflow_run_billing_mod.mps_service_key_client,
        "report_platform_usage",
        usage_mock,
    )

    await report_workflow_run_platform_usage(run)

    usage_mock.assert_not_awaited()
@pytest.mark.asyncio
async def test_report_completed_workflow_run_platform_usage_loads_run(monkeypatch):
    """The id-based entry point loads the run by id and then reports usage once."""
    run = _make_workflow_run()
    load_run_mock = AsyncMock(return_value=run)
    status_mock = AsyncMock(return_value={"billing_mode": "v2"})
    usage_mock = AsyncMock(return_value={"metered": True})
    mps_client = workflow_run_billing_mod.mps_service_key_client
    monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
    monkeypatch.setattr(
        workflow_run_billing_mod.db_client,
        "get_workflow_run_by_id",
        load_run_mock,
    )
    monkeypatch.setattr(mps_client, "get_billing_account_status", status_mock)
    monkeypatch.setattr(mps_client, "report_platform_usage", usage_mock)

    await report_completed_workflow_run_platform_usage(run.id)

    load_run_mock.assert_awaited_once_with(run.id)
    usage_mock.assert_awaited_once()

View file

@ -1,181 +0,0 @@
from datetime import UTC, datetime
from types import SimpleNamespace
from unittest.mock import AsyncMock
import pytest
from api.services.pricing import workflow_run_cost as workflow_run_cost_mod
from api.services.pricing.workflow_run_cost import (
apply_usage_delta_to_organization,
build_workflow_run_cost_info,
calculate_workflow_run_cost,
)
def _make_workflow_run():
return SimpleNamespace(
id=123,
workflow_id=456,
mode="textchat",
created_at=datetime.now(UTC),
usage_info={
"llm": {},
"tts": {},
"stt": {},
"call_duration_seconds": 7,
},
cost_info={},
workflow=SimpleNamespace(
organization_id=42,
user=SimpleNamespace(selected_organization_id=42),
),
)
@pytest.mark.asyncio
async def test_build_workflow_run_cost_info_does_not_update_org_usage(monkeypatch):
    """Building cost info returns the expected fields without touching org usage."""
    run = _make_workflow_run()
    organization = SimpleNamespace(id=42, price_per_second_usd=1.5)
    org_lookup_mock = AsyncMock(return_value=organization)
    usage_update_mock = AsyncMock()
    db = workflow_run_cost_mod.db_client
    monkeypatch.setattr(db, "get_organization_by_id", org_lookup_mock)
    monkeypatch.setattr(db, "update_usage_after_run", usage_update_mock)

    cost_info = await build_workflow_run_cost_info(run)

    assert cost_info is not None
    assert cost_info["call_duration_seconds"] == 7
    for required_key in ("cost_breakdown", "dograh_token_usage"):
        assert required_key in cost_info
    # 7 seconds at 1.5 USD per second.
    assert cost_info["charge_usd"] == 10.5
    usage_update_mock.assert_not_called()
@pytest.mark.asyncio
async def test_calculate_workflow_run_cost_keeps_org_usage_side_effect_in_wrapper(
    monkeypatch,
):
    """The wrapper saves cost info on the run and updates organization usage once."""
    run = _make_workflow_run()
    organization = SimpleNamespace(id=42, price_per_second_usd=None)
    org_lookup_mock = AsyncMock(return_value=organization)
    run_update_mock = AsyncMock()
    usage_update_mock = AsyncMock()
    db = workflow_run_cost_mod.db_client
    monkeypatch.setattr(db, "get_workflow_run_by_id", AsyncMock(return_value=run))
    monkeypatch.setattr(db, "get_organization_by_id", org_lookup_mock)
    monkeypatch.setattr(db, "update_workflow_run", run_update_mock)
    monkeypatch.setattr(db, "update_usage_after_run", usage_update_mock)

    await calculate_workflow_run_cost(run.id)

    run_update_mock.assert_awaited_once()
    persisted = run_update_mock.await_args.kwargs
    assert persisted["run_id"] == run.id
    assert "cost_breakdown" in persisted["cost_info"]
    usage_update_mock.assert_awaited_once()
@pytest.mark.asyncio
async def test_apply_usage_delta_to_organization_uses_incremental_costs(
    monkeypatch,
):
    """Two usage deltas are billed separately and add up to the merged-usage cost."""

    def usage_payload(prompt_tokens, completion_tokens, duration_seconds):
        # Usage dict with a single LLM entry and empty TTS/STT sections.
        return {
            "llm": {
                "OpenAILLMService#0|||gpt-4.1-mini": {
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": completion_tokens,
                    "total_tokens": prompt_tokens + completion_tokens,
                    "cache_read_input_tokens": 0,
                    "cache_creation_input_tokens": 0,
                }
            },
            "tts": {},
            "stt": {},
            "call_duration_seconds": duration_seconds,
        }

    workflow_run = _make_workflow_run()
    workflow_run.cost_info = {"call_id": "preserve-me"}
    usage_delta_one = usage_payload(1_000, 100, 3)
    usage_delta_two = usage_payload(2_000, 50, 4)
    # Token counts and duration of the two deltas added together.
    merged_usage = usage_payload(3_000, 150, 7)

    organization = SimpleNamespace(id=42, price_per_second_usd=1.5)
    get_org = AsyncMock(return_value=organization)
    update_usage = AsyncMock()
    db = workflow_run_cost_mod.db_client
    monkeypatch.setattr(db, "get_organization_by_id", get_org)
    monkeypatch.setattr(db, "update_usage_after_run", update_usage)

    first_delta = await apply_usage_delta_to_organization(workflow_run, usage_delta_one)
    second_delta = await apply_usage_delta_to_organization(
        workflow_run, usage_delta_two
    )
    total_workflow_run = SimpleNamespace(**workflow_run.__dict__)
    total_workflow_run.usage_info = merged_usage
    total_cost = await build_workflow_run_cost_info(total_workflow_run)

    assert first_delta is not None
    assert second_delta is not None
    assert total_cost is not None
    assert update_usage.await_count == 2

    # Each organization update carries only that delta's tokens, seconds, and charge.
    expected_updates = ((first_delta, 3.0), (second_delta, 4.0))
    for recorded_call, (delta, seconds) in zip(
        update_usage.await_args_list, expected_updates
    ):
        assert recorded_call.args == (
            42,
            delta["dograh_token_usage"],
            seconds,
            delta["charge_usd"],
        )

    combined_tokens = (
        first_delta["dograh_token_usage"] + second_delta["dograh_token_usage"]
    )
    assert combined_tokens == pytest.approx(total_cost["dograh_token_usage"])
    assert (
        first_delta["charge_usd"] + second_delta["charge_usd"]
        == total_cost["charge_usd"]
    )

View file

@ -176,11 +176,7 @@ async def test_text_chat_session_creation_executes_initial_assistant_turn(
assert "Start" in (created["gathered_context"] or {}).get("nodes_visited", [])
workflow_run = await db_session.get_workflow_run_by_id(created["workflow_run_id"])
assert workflow_run is not None
assert workflow_run.cost_info[
"call_duration_seconds"
] == workflow_run.usage_info.get("call_duration_seconds", 0)
assert "cost_breakdown" in workflow_run.cost_info
assert "dograh_token_usage" in workflow_run.cost_info
assert "call_duration_seconds" in workflow_run.usage_info
assert _log_texts(run_payload["logs"], "rtf-bot-text") == [
"Hello from the workflow tester."
]
@ -296,11 +292,7 @@ async def test_text_chat_message_executes_assistant_turn(
assert "Start" in (payload["gathered_context"] or {}).get("nodes_visited", [])
workflow_run = await db_session.get_workflow_run_by_id(created["workflow_run_id"])
assert workflow_run is not None
assert workflow_run.cost_info[
"call_duration_seconds"
] == workflow_run.usage_info.get("call_duration_seconds", 0)
assert "cost_breakdown" in workflow_run.cost_info
assert "dograh_token_usage" in workflow_run.cost_info
assert "call_duration_seconds" in workflow_run.usage_info
assert _log_texts(run_payload["logs"], "rtf-user-transcription") == ["Hi there"]
assert _log_texts(run_payload["logs"], "rtf-bot-text") == [
"Welcome to the workflow tester.",

View file

@ -1,11 +1,13 @@
"use client";
import { CircleDollarSign, CreditCard, RefreshCw } from "lucide-react";
import Link from "next/link";
import { useCallback, useEffect, useMemo, useState } from "react";
import { toast } from "sonner";
import { createMpsCreditPurchaseUrlApiV1OrganizationsUsageMpsCreditsPurchaseUrlPost, getBillingCreditsApiV1OrganizationsBillingCreditsGet } from "@/client/sdk.gen";
import type { MpsBillingCreditsResponse } from "@/client/types.gen";
import type { MpsBillingCreditsResponse, MpsCreditLedgerEntryResponse } from "@/client/types.gen";
import { Badge } from "@/components/ui/badge";
import { Button } from "@/components/ui/button";
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card";
import { Progress } from "@/components/ui/progress";
@ -19,7 +21,6 @@ import {
TableRow,
} from "@/components/ui/table";
import { useAppConfig } from "@/context/AppConfigContext";
import { useOrgConfig } from "@/context/OrgConfigContext";
import { useAuth } from "@/lib/auth";
const formatCredits = (value: number | null | undefined) => (
@ -50,18 +51,58 @@ const formatDate = (value: string) => (
})
);
// Display labels for ledger metric codes. getLedgerEntryLabel falls back to
// formatTitleCase(metric_code) for any code that is not listed here.
const metricLabels: Record<string, string> = {
voice_minutes: "Voice usage",
platform_usage: "Platform usage",
};
/**
 * Turn a snake_case identifier into a title-cased label
 * (underscores become spaces, the first letter of each word is upper-cased).
 * Empty, null, and undefined values render as "-".
 */
const formatTitleCase = (value: string | null | undefined) => {
    if (!value) {
        return "-";
    }
    const spaced = value.replaceAll("_", " ");
    return spaced.replace(/\b\w/g, (letter) => letter.toUpperCase());
};
/**
 * Pick the activity label for a ledger entry.
 *
 * A metric code takes precedence (known label, else title-cased code);
 * otherwise the label is derived from the entry type.
 */
const getLedgerEntryLabel = (entry: MpsCreditLedgerEntryResponse) => {
    const { metric_code: metricCode, entry_type: entryType } = entry;
    if (metricCode) {
        return metricLabels[metricCode] ?? formatTitleCase(metricCode);
    }
    switch (entryType) {
        case "grant":
            return "Credit grant";
        case "purchase":
            return "Credit purchase";
        default:
            return formatTitleCase(entryType);
    }
};
/**
 * Format an entry's billable quantity with its unit ("minute" is shortened
 * to "min"). Returns null when the quantity is null/undefined or the unit
 * is missing.
 */
const formatBillableQuantity = (entry: MpsCreditLedgerEntryResponse) => {
    const { billable_quantity: quantity, quantity_unit: rawUnit } = entry;
    if (quantity == null || !rawUnit) {
        return null;
    }
    const unit = rawUnit === "minute" ? "min" : rawUnit;
    return `${formatCredits(quantity)} ${unit}`;
};
/**
 * Build the link to a workflow run page for a ledger entry, or null when
 * either the workflow id or the run id is missing.
 */
const getRunHref = (entry: MpsCreditLedgerEntryResponse) => {
    const { workflow_id: workflowId, workflow_run_id: runId } = entry;
    return workflowId && runId ? `/workflow/${workflowId}/run/${runId}` : null;
};
export default function BillingPage() {
const auth = useAuth();
const { config } = useAppConfig();
const { orgContext, loading: orgConfigLoading } = useOrgConfig();
const [credits, setCredits] = useState<MpsBillingCreditsResponse | null>(null);
const [loading, setLoading] = useState(true);
const [refreshing, setRefreshing] = useState(false);
const [purchasing, setPurchasing] = useState(false);
const isManagedServiceV2 = Boolean(orgContext?.model_services.uses_managed_service_v2);
const isBillingV2 = isManagedServiceV2 && credits?.billing_version === "v2";
const canPurchaseCredits = isManagedServiceV2 && config?.deploymentMode !== "oss";
const isBillingV2 = credits?.billing_version === "v2";
const canPurchaseCredits = isBillingV2 && config?.deploymentMode !== "oss";
const totalQuota = credits?.total_quota ?? 0;
const remainingCredits = credits?.remaining_credits ?? 0;
const usedCredits = credits?.total_credits_used ?? 0;
@ -132,7 +173,7 @@ export default function BillingPage() {
}
};
if (loading || orgConfigLoading) {
if (loading) {
return (
<div className="container mx-auto p-6 space-y-6">
<div className="space-y-2">
@ -206,13 +247,14 @@ export default function BillingPage() {
</CardHeader>
<CardContent>
{ledgerEntries.length > 0 ? (
<div className="bg-card border rounded-lg overflow-hidden shadow-sm">
<div className="bg-card border rounded-lg overflow-x-auto shadow-sm">
<Table>
<TableHeader>
<TableRow className="bg-muted/50">
<TableHead>Date</TableHead>
<TableHead>Type</TableHead>
<TableHead>Activity</TableHead>
<TableHead>Origin</TableHead>
<TableHead>Run</TableHead>
<TableHead className="text-right">Delta</TableHead>
<TableHead className="text-right">Balance</TableHead>
<TableHead className="text-right">Amount</TableHead>
@ -221,11 +263,39 @@ export default function BillingPage() {
<TableBody>
{ledgerEntries.map((entry) => {
const delta = entry.credits_delta ?? 0;
const runHref = getRunHref(entry);
const billableQuantity = formatBillableQuantity(entry);
return (
<TableRow key={entry.id}>
<TableCell>{formatDate(entry.created_at)}</TableCell>
<TableCell className="capitalize">{entry.entry_type.replaceAll("_", " ")}</TableCell>
<TableCell>{entry.origin || "-"}</TableCell>
<TableCell>
<div className="flex flex-col gap-1">
<span className="font-medium">{getLedgerEntryLabel(entry)}</span>
{billableQuantity && (
<span className="text-xs text-muted-foreground">{billableQuantity}</span>
)}
</div>
</TableCell>
<TableCell>
{entry.origin ? (
<Badge variant="secondary">{formatTitleCase(entry.origin)}</Badge>
) : (
"-"
)}
</TableCell>
<TableCell>
{entry.workflow_run_id ? (
runHref ? (
<Link className="font-medium text-primary hover:underline" href={runHref}>
#{entry.workflow_run_id}
</Link>
) : (
<span>#{entry.workflow_run_id}</span>
)
) : (
"-"
)}
</TableCell>
<TableCell className={`text-right font-medium ${delta >= 0 ? "text-green-600" : "text-destructive"}`}>
{delta >= 0 ? "+" : ""}
{formatCredits(delta)}
@ -251,7 +321,6 @@ export default function BillingPage() {
<Card>
<CardHeader>
<CardTitle>Credit Usage</CardTitle>
<CardDescription>Current legacy MPS credit allocation.</CardDescription>
</CardHeader>
<CardContent className="space-y-4">
<Progress value={usagePercent} />

View file

@ -3176,6 +3176,38 @@ export type MpsCreditLedgerEntryResponse = {
* Payment Order Id
*/
payment_order_id?: number | null;
/**
* Metric Code
*/
metric_code?: string | null;
/**
* Correlation Id
*/
correlation_id?: string | null;
/**
* Aggregation Key
*/
aggregation_key?: string | null;
/**
* Usage Event Id
*/
usage_event_id?: number | null;
/**
* Workflow Run Id
*/
workflow_run_id?: number | null;
/**
* Workflow Id
*/
workflow_id?: number | null;
/**
* Billable Quantity
*/
billable_quantity?: number | null;
/**
* Quantity Unit
*/
quantity_unit?: string | null;
/**
* Metadata
*/

View file

@ -2,18 +2,33 @@
* Extract a human-readable message from a backend error response.
*
* The generated API client returns `{ error }` on failure (it does not throw),
* and FastAPI shapes that error as either `{ detail: string }` (HTTPException)
* or `{ detail: [{ msg, loc, ... }] }` (422 validation). This normalizes both
* to a single string so it can be rendered or thrown directly never pass the
* raw `detail` to React, as the 422 array crashes rendering.
* and FastAPI shapes that error as `{ detail: string }`, `{ detail:
* [{ msg, loc, ... }] }`, or backend validation arrays like `{ detail:
* [{ model, message }] }`. This normalizes those to a single string so it can
* be rendered or thrown directly.
*/
export function detailFromError(err: unknown, fallback = "Request failed"): string {
// A bare string error is already the message.
if (typeof err === "string") return err;
const e = err as { detail?: unknown };
// A string `detail` is returned verbatim.
if (typeof e?.detail === "string") return e.detail;
if (Array.isArray(e?.detail) && e.detail.length > 0) {
// NOTE(review): as shown here, a truthy `msg` on the first entry returns
// immediately, so the aggregation below only runs when the first entry has
// no `msg`. This may be residue of the previous implementation in this diff
// view — confirm against the actual file.
const first = e.detail[0] as { msg?: string };
if (first?.msg) return first.msg;
// Collect one message per array entry: strings are used as-is; objects
// contribute `message` (preferred) or `msg`, prefixed with "model: " when a
// non-empty string `model` is present. Entries yielding no message are dropped.
const messages = e.detail
.map((item) => {
if (typeof item === "string") return item;
if (!item || typeof item !== "object") return null;
const detail = item as { message?: unknown; msg?: unknown; model?: unknown };
const message = typeof detail.message === "string"
? detail.message
: typeof detail.msg === "string"
? detail.msg
: null;
if (!message) return null;
return typeof detail.model === "string" && detail.model
? `${detail.model}: ${message}`
: message;
})
.filter((message): message is string => Boolean(message));
// Multiple messages are joined one per line.
if (messages.length > 0) return messages.join("\n");
}
// Nothing usable was found in the error shape.
return fallback;
}