feat: billing and credit management v2 (#429)

* feat: use mps generated correlation ID

* chore: update pipecat submodule

* feat: add credit purchase URL

* feat: carve out billing page and show credit ledger

* feat: deprecate dograh based quota tracking

* fix: remove cost calculation from dograh codebase

* fix: create mps account on migrate to v2

* chore: update pipecat
This commit is contained in:
Abhishek 2026-06-12 14:55:30 +05:30 committed by GitHub
parent 97d7103480
commit 1f1149f4d5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
80 changed files with 3335 additions and 2057 deletions

View file

@ -9,9 +9,10 @@ from api.constants import AUTH_PROVIDER, DOGRAH_MPS_SECRET_KEY, MPS_API_URL
from api.db import db_client
from api.db.models import UserModel
from api.enums import PostHogEvent
from api.schemas.user_configuration import EffectiveAIModelConfiguration
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
from api.services.auth.stack_auth import stackauth
from api.services.configuration.registry import ServiceProviders
from api.services.mps_billing import ensure_hosted_mps_billing_account_v2
from api.services.posthog_client import capture_event
from api.utils.auth import decode_jwt_token
@ -110,6 +111,19 @@ async def get_user(
# This prevents race conditions where multiple concurrent requests
# might try to create configurations
if org_was_created:
try:
await ensure_hosted_mps_billing_account_v2(
organization.id,
created_by=str(stack_user["id"]),
)
except Exception:
logger.warning(
"Failed to initialize hosted MPS billing account for "
"organization {}",
organization.id,
exc_info=True,
)
existing_cfg = await db_client.get_user_configurations(user_model.id)
if not (existing_cfg.llm or existing_cfg.tts or existing_cfg.stt):
mps_config = await create_user_configuration_with_mps_key(
@ -232,7 +246,7 @@ async def create_user_configuration_with_mps_key(
response = await client.post(
f"{MPS_API_URL}/api/v1/service-keys/",
json={
"name": f"Default Dograh Model Service Key",
"name": "Default Dograh Model Service Key",
"description": "Auto-generated key for OSS user",
"expires_in_days": 7, # Short-lived for OSS
"created_by": user_provider_id,
@ -250,7 +264,7 @@ async def create_user_configuration_with_mps_key(
response = await client.post(
f"{MPS_API_URL}/api/v1/service-keys/",
json={
"name": f"Default Dograh Model Service Key",
"name": "Default Dograh Model Service Key",
"description": f"Auto-generated key for organization {organization_id}",
"organization_id": organization_id,
"expires_in_days": 90, # Longer-lived for authenticated users
@ -285,8 +299,8 @@ async def create_user_configuration_with_mps_key(
"model": "default",
},
}
user_config = EffectiveAIModelConfiguration(**configuration)
return user_config
effective_config = EffectiveAIModelConfiguration(**configuration)
return effective_config
else:
logger.warning(
f"Failed to get MPS service key: {response.status_code} - {response.text}"

View file

@ -21,10 +21,10 @@ from api.schemas.ai_model_configuration import (
BYOKPipelineAIModelConfiguration,
BYOKRealtimeAIModelConfiguration,
DograhManagedAIModelConfiguration,
EffectiveAIModelConfiguration,
OrganizationAIModelConfigurationV2,
compile_ai_model_configuration_v2,
)
from api.schemas.user_configuration import EffectiveAIModelConfiguration
from api.services.configuration.masking import (
SERVICE_SECRET_FIELDS,
contains_masked_key,

View file

@ -8,7 +8,7 @@ from groq import Groq
# from pyneuphonic import Neuphonic
# except ImportError:
# Neuphonic = None
from api.schemas.user_configuration import (
from api.schemas.ai_model_configuration import (
EffectiveAIModelConfiguration,
)
from api.services.configuration.registry import ServiceConfig, ServiceProviders
@ -75,21 +75,21 @@ class UserConfigurationValidator:
status_list = []
status_list.extend(self._validate_service(configuration.llm, "llm"))
status_list.extend(self._validate_service(configuration.stt, "stt"))
status_list.extend(self._validate_service(configuration.tts, "tts"))
# Embeddings is optional - only validate if configured
status_list.extend(
self._validate_service(
configuration.embeddings, "embeddings", required=False
)
)
# Realtime is optional - only validate if is_realtime is enabled
if configuration.is_realtime:
status_list.extend(
self._validate_service(
configuration.realtime, "realtime", required=True
)
)
else:
status_list.extend(self._validate_service(configuration.stt, "stt"))
status_list.extend(self._validate_service(configuration.tts, "tts"))
# Embeddings is optional - only validate if configured
status_list.extend(
self._validate_service(
configuration.embeddings, "embeddings", required=False
)
)
if status_list:
raise ValueError(status_list)

View file

@ -12,7 +12,7 @@ The rules are simple:
import copy
from typing import Any, Dict, Optional
from api.schemas.user_configuration import EffectiveAIModelConfiguration
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
from api.services.configuration.registry import ServiceConfig
from api.services.integrations import get_node_secret_fields

View file

@ -7,7 +7,7 @@ stored, while honouring masked API keys.
import copy
from typing import Dict
from api.schemas.user_configuration import EffectiveAIModelConfiguration
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
from api.services.configuration.masking import (
MODEL_OVERRIDE_FIELDS,
SERVICE_SECRET_FIELDS,

View file

@ -4,7 +4,7 @@ from __future__ import annotations
import copy
from api.schemas.user_configuration import EffectiveAIModelConfiguration
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
from api.services.configuration.registry import (
REGISTRY,
ServiceType,

View file

@ -38,6 +38,7 @@ class OpenAIEmbeddingService(BaseEmbeddingService):
api_key: Optional[str] = None,
model_id: str = DEFAULT_MODEL_ID,
base_url: Optional[str] = None,
default_headers: Optional[Dict[str, str]] = None,
):
"""Initialize the OpenAI embedding service.
@ -60,6 +61,8 @@ class OpenAIEmbeddingService(BaseEmbeddingService):
field_name="base_url",
)
client_kwargs["base_url"] = base_url
if default_headers:
client_kwargs["default_headers"] = default_headers
self.client = AsyncOpenAI(**client_kwargs)
logger.info(f"OpenAI embedding service initialized with model: {model_id}")
else:

View file

@ -0,0 +1,98 @@
from __future__ import annotations
from typing import Any
from loguru import logger
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
from api.services.configuration.registry import ServiceProviders
from api.services.mps_service_key_client import mps_service_key_client
MPS_CORRELATION_ID_CONTEXT_KEY = "mps_correlation_id"
def uses_managed_model_services_v2(
ai_model_config: EffectiveAIModelConfiguration | None,
) -> bool:
if (
ai_model_config is None
or getattr(ai_model_config, "managed_service_version", None) != 2
):
return False
return any(
_is_dograh_service(getattr(ai_model_config, section_name, None))
for section_name in ("llm", "tts", "stt", "embeddings")
)
def get_mps_correlation_id(initial_context: dict[str, Any] | None) -> str | None:
if not initial_context:
return None
correlation_id = initial_context.get(MPS_CORRELATION_ID_CONTEXT_KEY)
if correlation_id is None:
return None
return str(correlation_id)
async def ensure_mps_correlation_id(
*,
ai_model_config: EffectiveAIModelConfiguration,
workflow_run_id: int,
initial_context: dict[str, Any] | None,
) -> str | None:
existing = get_mps_correlation_id(initial_context)
if existing:
return existing
if not uses_managed_model_services_v2(ai_model_config):
return None
service_key = _get_dograh_service_api_key(ai_model_config)
if not service_key:
raise ValueError(
"Managed model services v2 requires a Dograh service key before the run starts."
)
response = await mps_service_key_client.create_correlation_id(
service_key=service_key,
workflow_run_id=workflow_run_id,
)
correlation_id = response.get("correlation_id")
if not correlation_id:
raise ValueError("MPS correlation-id response did not include correlation_id")
correlation_id = str(correlation_id)
logger.info(
"Minted MPS correlation id {} for workflow run {}",
correlation_id,
workflow_run_id,
)
return correlation_id
def _is_dograh_service(service: Any) -> bool:
provider = getattr(service, "provider", None)
return (
provider == ServiceProviders.DOGRAH or provider == ServiceProviders.DOGRAH.value
)
def _get_dograh_service_api_key(
ai_model_config: EffectiveAIModelConfiguration,
) -> str | None:
for section_name in ("llm", "tts", "stt", "embeddings"):
service = getattr(ai_model_config, section_name, None)
if not _is_dograh_service(service):
continue
if hasattr(service, "get_all_api_keys"):
keys = service.get_all_api_keys()
if keys:
return keys[0]
api_key = getattr(service, "api_key", None)
if isinstance(api_key, str) and api_key:
return api_key
return None

View file

@ -0,0 +1,23 @@
from typing import Optional
from api.constants import DEPLOYMENT_MODE
from api.services.mps_service_key_client import mps_service_key_client
async def ensure_hosted_mps_billing_account_v2(
organization_id: int,
*,
created_by: Optional[str] = None,
) -> Optional[dict]:
"""Ensure hosted orgs have an MPS billing v2 account.
OSS deployments use legacy per-key quota accounting and do not create MPS
billing accounts.
"""
if DEPLOYMENT_MODE == "oss":
return None
return await mps_service_key_client.ensure_billing_account_v2(
organization_id=organization_id,
created_by=created_by,
)

View file

@ -4,6 +4,7 @@ This client communicates with the Model Proxy Service (MPS) for service key mana
Service keys are stored and managed entirely in MPS, not in the local database.
"""
import asyncio
from typing import List, Optional
import httpx
@ -353,6 +354,234 @@ class MPSServiceKeyClient:
response=response,
)
async def create_credit_purchase_url(
self,
organization_id: int,
created_by: Optional[str] = None,
return_url: Optional[str] = None,
billing_details: Optional[dict] = None,
) -> dict:
"""Create a short-lived MPS checkout URL for adding organization credits."""
payload = {
"created_by": created_by,
"return_url": return_url,
"billing_details": billing_details or {},
}
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.post(
f"{self.base_url}/api/v1/billing/accounts/{organization_id}/checkout-sessions",
json=payload,
headers=self._get_headers(
organization_id=organization_id,
created_by=created_by,
),
)
if response.status_code == 200:
return response.json()
logger.error(
"Failed to create MPS credit purchase URL: "
f"{response.status_code} - {response.text}"
)
raise httpx.HTTPStatusError(
f"Failed to create MPS credit purchase URL: {response.text}",
request=response.request,
response=response,
)
async def get_credit_ledger(
self,
organization_id: int,
page: int = 1,
limit: int = 50,
created_by: Optional[str] = None,
) -> dict:
"""Get the MPS v2 billing account balance and recent credit ledger."""
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.get(
f"{self.base_url}/api/v1/billing/accounts/{organization_id}/ledger",
params={"page": page, "limit": limit},
headers=self._get_headers(
organization_id=organization_id,
created_by=created_by,
),
)
if response.status_code == 200:
return response.json()
logger.error(
"Failed to get MPS credit ledger: "
f"{response.status_code} - {response.text}"
)
raise httpx.HTTPStatusError(
f"Failed to get MPS credit ledger: {response.text}",
request=response.request,
response=response,
)
async def get_billing_account_status(
self,
organization_id: int,
created_by: Optional[str] = None,
) -> Optional[dict]:
"""Get an existing MPS v2 billing account without creating one."""
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.get(
f"{self.base_url}/api/v1/billing/accounts/{organization_id}/status",
headers=self._get_headers(
organization_id=organization_id,
created_by=created_by,
),
)
if response.status_code == 200:
return response.json()
logger.error(
"Failed to get MPS billing account status: "
f"{response.status_code} - {response.text}"
)
raise httpx.HTTPStatusError(
f"Failed to get MPS billing account status: {response.text}",
request=response.request,
response=response,
)
async def ensure_billing_account_v2(
self,
organization_id: int,
created_by: Optional[str] = None,
) -> dict:
"""Create or return the MPS v2 billing account for an organization."""
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.get(
f"{self.base_url}/api/v1/billing/accounts/{organization_id}/balance",
headers=self._get_headers(
organization_id=organization_id,
created_by=created_by,
),
)
if response.status_code == 200:
return response.json()
logger.error(
"Failed to ensure MPS billing account v2: "
f"{response.status_code} - {response.text}"
)
raise httpx.HTTPStatusError(
f"Failed to ensure MPS billing account v2: {response.text}",
request=response.request,
response=response,
)
async def create_correlation_id(
self,
*,
service_key: str,
workflow_run_id: int | None = None,
) -> dict:
"""Mint a server-generated correlation ID for managed model services."""
payload: dict[str, int] = {}
if workflow_run_id is not None:
payload["workflow_run_id"] = workflow_run_id
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.post(
f"{self.base_url}/api/v1/service-keys/correlation-id/self",
json=payload,
headers={
"Authorization": f"Bearer {service_key}",
"Content-Type": "application/json",
},
)
if response.status_code == 200:
return response.json()
logger.error(
"Failed to create correlation ID: "
f"{response.status_code} - {response.text}"
)
raise httpx.HTTPStatusError(
f"Failed to create correlation ID: {response.text}",
request=response.request,
response=response,
)
async def report_platform_usage(
self,
*,
organization_id: int,
correlation_id: Optional[str] = None,
duration_seconds: Optional[float] = None,
workflow_run_id: int | None = None,
metadata: Optional[dict] = None,
max_attempts: int = 3,
) -> dict:
"""Report hosted Dograh platform usage for a completed workflow run."""
if DEPLOYMENT_MODE == "oss":
raise ValueError("OSS deployments must not report platform usage to MPS")
if not correlation_id and duration_seconds is None:
raise ValueError(
"Platform usage reports require correlation_id or duration_seconds"
)
payload: dict = {
"metadata": metadata or {},
}
if correlation_id:
payload["correlation_id"] = correlation_id
if duration_seconds is not None:
payload["duration_seconds"] = duration_seconds
if workflow_run_id is not None:
payload["workflow_run_id"] = workflow_run_id
max_attempts = max(1, max_attempts)
last_response: httpx.Response | None = None
async with httpx.AsyncClient(timeout=self.timeout) as client:
for attempt in range(1, max_attempts + 1):
response = await client.post(
(
f"{self.base_url}/api/v1/billing/accounts/"
f"{organization_id}/platform-usage"
),
json=payload,
headers=self._get_headers(organization_id=organization_id),
)
last_response = response
if response.status_code == 200:
return response.json()
should_retry = (
response.status_code == 409
and "usage_not_ready" in response.text
and attempt < max_attempts
)
if should_retry:
await asyncio.sleep(attempt)
continue
logger.error(
"Failed to report platform usage: "
f"{response.status_code} - {response.text}"
)
raise httpx.HTTPStatusError(
f"Failed to report platform usage: {response.text}",
request=response.request,
response=response,
)
raise httpx.HTTPStatusError(
"Failed to report platform usage",
request=last_response.request,
response=last_response,
)
async def transcribe_audio(
self,
audio_data: bytes,

View file

@ -0,0 +1,50 @@
from typing import Literal, Optional
from pydantic import BaseModel
from api.db import db_client
from api.db.models import UserModel
from api.services.configuration.ai_model_configuration import (
get_resolved_ai_model_configuration,
)
class OrganizationModelServicesContext(BaseModel):
config_source: Literal["organization_v2", "legacy_user_v1", "empty"]
has_model_configuration_v2: bool
managed_service_version: Optional[int] = None
uses_managed_service_v2: bool
class OrganizationContextResponse(BaseModel):
organization_id: Optional[int] = None
organization_provider_id: Optional[str] = None
model_services: OrganizationModelServicesContext
async def get_organization_context(user: UserModel) -> OrganizationContextResponse:
organization_id = user.selected_organization_id
organization = (
await db_client.get_organization_by_id(organization_id)
if organization_id
else None
)
resolved = await get_resolved_ai_model_configuration(
user_id=user.id,
organization_id=organization_id,
)
managed_service_version = resolved.effective.managed_service_version
return OrganizationContextResponse(
organization_id=organization_id,
organization_provider_id=organization.provider_id if organization else None,
model_services=OrganizationModelServicesContext(
config_source=resolved.source,
has_model_configuration_v2=resolved.source == "organization_v2",
managed_service_version=managed_service_version,
uses_managed_service_v2=(
resolved.source == "organization_v2" and managed_service_version == 2
),
),
)

View file

@ -162,15 +162,13 @@ async def run_pipeline_telephony(
workflow_id: Workflow being executed.
workflow_run_id: Workflow run row.
user_id: Owner of the workflow.
call_id: Provider call identifier (stored in cost_info for billing).
call_id: Provider call identifier.
transport_kwargs: Provider-specific kwargs forwarded to the transport
factory (e.g. stream_sid + call_sid for Twilio).
"""
logger.debug(f"Running {provider_name} pipeline for workflow_run {workflow_run_id}")
set_current_run_id(workflow_run_id)
await db_client.update_workflow_run(workflow_run_id, cost_info={"call_id": call_id})
workflow = await db_client.get_workflow(workflow_id, user_id)
if workflow:
set_current_org_id(workflow.organization_id)
@ -340,7 +338,7 @@ async def _run_pipeline(
if workflow_run.is_completed:
raise HTTPException(status_code=400, detail="Workflow run already completed")
merged_call_context_vars = workflow_run.initial_context
merged_call_context_vars = dict(workflow_run.initial_context or {})
# If there is some extra call_context_vars, fold them in. Persistence
# happens once below, after runtime_configuration is also resolved.
if call_context_vars:
@ -398,6 +396,19 @@ async def _run_pipeline(
else:
user_config = resolved_user_config
from api.services.managed_model_services import (
MPS_CORRELATION_ID_CONTEXT_KEY,
ensure_mps_correlation_id,
)
mps_correlation_id = await ensure_mps_correlation_id(
ai_model_config=user_config,
workflow_run_id=workflow_run_id,
initial_context=merged_call_context_vars,
)
if mps_correlation_id:
merged_call_context_vars[MPS_CORRELATION_ID_CONTEXT_KEY] = mps_correlation_id
# Detect realtime mode (speech-to-speech services like OpenAI Realtime, Gemini Live)
is_realtime = user_config.is_realtime and user_config.realtime is not None
@ -409,11 +420,23 @@ async def _run_pipeline(
# Realtime services don't implement run_inference, so create a
# separate text LLM for variable extraction and other out-of-band
# inference calls.
inference_llm = create_llm_service(user_config)
inference_llm = create_llm_service(
user_config,
correlation_id=mps_correlation_id,
)
else:
stt = create_stt_service(user_config, audio_config, keyterms=keyterms)
tts = create_tts_service(user_config, audio_config)
llm = create_llm_service(user_config)
stt = create_stt_service(
user_config,
audio_config,
keyterms=keyterms,
correlation_id=mps_correlation_id,
)
tts = create_tts_service(
user_config,
audio_config,
correlation_id=mps_correlation_id,
)
llm = create_llm_service(user_config, correlation_id=mps_correlation_id)
inference_llm = None
# Stamp the providers/models actually resolved for this run onto
@ -695,7 +718,10 @@ async def _run_pipeline(
# Create a separate LLM instance for the voicemail sub-pipeline
# (can't share with main pipeline as it would mess up frame linking)
if voicemail_config.get("use_workflow_llm", True):
voicemail_llm = create_llm_service(user_config)
voicemail_llm = create_llm_service(
user_config,
correlation_id=mps_correlation_id,
)
else:
voicemail_llm = create_llm_service_from_provider(
provider=voicemail_config.get("provider", "openai"),

View file

@ -78,7 +78,10 @@ def _validate_runtime_service_url(url: str, field_name: str) -> None:
def create_stt_service(
user_config, audio_config: "AudioConfig", keyterms: list[str] | None = None
user_config,
audio_config: "AudioConfig",
keyterms: list[str] | None = None,
correlation_id: str | None = None,
):
"""Create and return appropriate STT service based on user configuration
@ -160,6 +163,7 @@ def create_stt_service(
return DograhSTTService(
base_url=base_url,
api_key=user_config.stt.api_key,
correlation_id=correlation_id,
settings=DograhSTTSettings(
model=user_config.stt.model,
language=language,
@ -286,7 +290,9 @@ def create_stt_service(
)
def create_tts_service(user_config, audio_config: "AudioConfig"):
def create_tts_service(
user_config, audio_config: "AudioConfig", correlation_id: str | None = None
):
"""Create and return appropriate TTS service based on user configuration
Args:
@ -404,6 +410,7 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
return DograhTTSService(
base_url=base_url,
api_key=user_config.tts.api_key,
correlation_id=correlation_id,
settings=DograhTTSSettings(
model=user_config.tts.model,
voice=user_config.tts.voice,
@ -564,6 +571,7 @@ def create_llm_service_from_provider(
model: str,
api_key: str | None,
*,
correlation_id: str | None = None,
base_url: str | None = None,
endpoint: str | None = None,
aws_access_key: str | None = None,
@ -637,6 +645,7 @@ def create_llm_service_from_provider(
return DograhLLMService(
base_url=f"{MPS_API_URL}/api/v1/llm",
api_key=api_key,
correlation_id=correlation_id,
settings=OpenAILLMSettings(model=model),
)
elif provider == ServiceProviders.AWS_BEDROCK.value:
@ -851,7 +860,7 @@ def create_realtime_llm_service(user_config, audio_config: "AudioConfig"):
)
def create_llm_service(user_config):
def create_llm_service(user_config, correlation_id: str | None = None):
"""Create and return appropriate LLM service based on user configuration."""
provider = user_config.llm.provider
model = user_config.llm.model
@ -880,4 +889,10 @@ def create_llm_service(user_config):
elif provider == ServiceProviders.SARVAM.value:
kwargs["temperature"] = user_config.llm.temperature
return create_llm_service_from_provider(provider, model, api_key, **kwargs)
return create_llm_service_from_provider(
provider,
model,
api_key,
correlation_id=correlation_id,
**kwargs,
)

View file

@ -1,76 +0,0 @@
# Pricing Module
This module contains pricing models and registries for different AI services used in workflow cost calculations.
## Structure
```
pricing/
├── __init__.py # Main module exports
├── models.py # Base pricing model classes
├── llm.py # LLM pricing configurations
├── tts.py # TTS pricing configurations
├── stt.py # STT pricing configurations
├── registry.py # Combined pricing registry
└── README.md # This file
```
## Pricing Models
### TokenPricingModel
Used for LLM services that charge based on tokens:
- `prompt_token_price`: Cost per prompt token
- `completion_token_price`: Cost per completion token
- `cache_read_discount`: Discount for cache read tokens (default 50%)
- `cache_creation_multiplier`: Premium for cache creation tokens (default 25%)
### CharacterPricingModel
Used for TTS services that charge based on character count:
- `character_price`: Cost per character
### TimePricingModel
Used for STT services that charge based on time:
- `second_price`: Cost per second
## Adding New Pricing
### Adding a New LLM Model
Edit `llm.py` and add the model to the appropriate provider:
```python
ServiceProviders.OPENAI: {
"new-model": TokenPricingModel(
prompt_token_price=Decimal("2.00") / 1000000,
completion_token_price=Decimal("8.00") / 1000000,
),
# ... existing models
}
```
### Adding a New Provider
1. Add pricing configurations to the appropriate service file (llm.py, tts.py, stt.py)
2. The registry will automatically include them
### Adding a New Service Type
1. Create a new pricing file (e.g., `image.py`)
2. Define the pricing models
3. Import and add to `registry.py`
## Usage
The pricing registry is automatically imported and used by the cost calculator:
```python
from api.services.pricing import PRICING_REGISTRY
from api.services.workflow.cost_calculator import cost_calculator
# The cost calculator uses the pricing registry automatically
result = cost_calculator.calculate_total_cost(usage_info)
```
## Maintenance
- Update pricing when providers change their rates
- All prices should use `Decimal` for precision
- Include comments with current pricing from provider documentation
- Test changes with existing test suite

View file

@ -1,9 +0,0 @@
"""
Pricing module for workflow cost calculation.
This module contains pricing models and registries for different AI services.
"""
from .registry import PRICING_REGISTRY
__all__ = ["PRICING_REGISTRY"]

View file

@ -1,228 +0,0 @@
"""
Cost Calculator for Workflow Runs
This module provides a comprehensive cost calculation system for workflow runs based on usage metrics
from different AI service providers (OpenAI, Groq, Deepgram, etc.).
Features:
- Token-based pricing for LLM services with cache optimization support
- Character-based pricing for TTS services
- Time-based pricing for STT services
- Configurable pricing models that can be updated
- Support for multiple providers and models
- Automatic provider inference from model names
- JSON serialization support for database storage
Usage:
from api.tasks.cost_calculator import cost_calculator
usage_info = {
"llm": {
"processor_name|||gpt-4o": {
"prompt_tokens": 1000,
"completion_tokens": 500,
"total_tokens": 1500,
"cache_read_input_tokens": 0,
"cache_creation_input_tokens": 0
}
},
"tts": {
"processor_name|||aura-2-helena-en": 2000 # character count
}
}
cost_breakdown = cost_calculator.calculate_total_cost(usage_info)
print(f"Total cost: ${cost_breakdown['total']:.6f}")
"""
from decimal import Decimal
from typing import Any, Dict, Optional, Tuple
from api.services.configuration.registry import ServiceProviders
from api.services.pricing import PRICING_REGISTRY
from api.services.pricing.models import (
PricingModel,
)
class CostCalculator:
"""Main cost calculator class"""
def __init__(self, pricing_registry: Dict = None):
self.pricing_registry = pricing_registry or PRICING_REGISTRY
def get_pricing_model(
self, service_type: str, provider: str, model: str
) -> Optional[PricingModel]:
"""Get pricing model for a specific service, provider, and model"""
try:
service_pricing = self.pricing_registry.get(service_type, {})
# Try to get pricing for the specific provider
provider_pricing = service_pricing.get(provider, {})
pricing_model = provider_pricing.get(model) or provider_pricing.get(
"default"
)
if pricing_model:
return pricing_model
# If not found, try the "default" provider for this service type
default_provider_pricing = service_pricing.get("default", {})
return default_provider_pricing.get(model) or default_provider_pricing.get(
"default"
)
except (KeyError, AttributeError):
return None
def calculate_llm_cost(
self, provider: str, model: str, usage: Dict[str, int]
) -> Decimal:
"""Calculate cost for LLM usage"""
pricing_model = self.get_pricing_model("llm", provider, model)
if not pricing_model:
return Decimal("0")
return pricing_model.calculate_cost(usage)
def calculate_tts_cost(
self, provider: str, model: str, character_count: int
) -> Decimal:
"""Calculate cost for TTS usage"""
pricing_model = self.get_pricing_model("tts", provider, model)
if not pricing_model:
return Decimal("0")
return pricing_model.calculate_cost(character_count)
def calculate_stt_cost(self, provider: str, model: str, seconds: float) -> Decimal:
"""Calculate cost for STT usage"""
pricing_model = self.get_pricing_model("stt", provider, model)
if not pricing_model:
return Decimal("0")
return pricing_model.calculate_cost(seconds)
def calculate_total_cost(self, usage_info: Dict) -> Dict[str, Any]:
llm_cost_total = Decimal("0")
tts_cost_total = Decimal("0")
stt_cost_total = Decimal("0")
# Calculate LLM costs
llm_usage = usage_info.get("llm", {})
for key, usage in llm_usage.items():
processor, model = self._parse_key(key)
# Try to determine provider from processor name or model
provider = self._infer_provider_from_model(model, "llm")
cost = self.calculate_llm_cost(provider, model, usage)
llm_cost_total += cost
# Calculate TTS costs
tts_usage = usage_info.get("tts", {})
for key, character_count in tts_usage.items():
processor, model = self._parse_key(key)
# Handle the case where model is "None" - infer from processor
if model.lower() in ["none", "null", ""]:
provider = self._infer_provider_from_processor(processor, "tts")
model = "default" # Use default model for the provider
else:
provider = self._infer_provider_from_model(model, "tts")
cost = self.calculate_tts_cost(provider, model, character_count)
tts_cost_total += cost
# Calculate STT costs from explicit stt usage
stt_usage = usage_info.get("stt", {})
for key, seconds in stt_usage.items():
processor, model = self._parse_key(key)
provider = self._infer_provider_from_model(model, "stt")
cost = self.calculate_stt_cost(provider, model, seconds)
stt_cost_total += cost
total_cost = llm_cost_total + tts_cost_total + stt_cost_total
return {
"llm_cost": float(llm_cost_total),
"tts_cost": float(tts_cost_total),
"stt_cost": float(stt_cost_total),
"total": float(total_cost),
}
def _parse_key(self, key) -> Tuple[str, str]:
"""Parse key which is in format 'processor|||model'"""
if isinstance(key, str) and "|||" in key:
parts = key.split("|||", 1)
return parts[0], parts[1]
else:
# Fallback for backwards compatibility or malformed keys
return str(key), "unknown"
def _infer_provider_from_model(self, model: str, service_type: str) -> str:
"""Infer provider from model name"""
if not model:
return "unknown"
model_lower = model.lower()
# OpenAI models
if any(keyword in model_lower for keyword in ["gpt", "whisper", "openai"]):
return ServiceProviders.OPENAI
# Groq models
if any(keyword in model_lower for keyword in ["groq"]):
return ServiceProviders.GROQ
# Elevenlabs models
if any(keyword in model_lower for keyword in ["eleven"]):
return ServiceProviders.ELEVENLABS
# Deepgram models
if any(
keyword in model_lower
for keyword in ["deepgram", "nova", "phonecall", "general"]
):
return ServiceProviders.DEEPGRAM
# Default to first available provider for the service type
service_providers = self.pricing_registry.get(service_type, {})
if service_providers:
return list(service_providers.keys())[0]
return "unknown"
def _infer_provider_from_processor(self, processor: str, service_type: str) -> str:
"""Infer provider from processor name"""
if not processor:
return "unknown"
processor_lower = processor.lower()
# OpenAI processors
if any(keyword in processor_lower for keyword in ["openai", "gpt"]):
return ServiceProviders.OPENAI
# Groq processors
if any(keyword in processor_lower for keyword in ["groq"]):
return ServiceProviders.GROQ
# Deepgram processors
if any(keyword in processor_lower for keyword in ["deepgram"]):
return ServiceProviders.DEEPGRAM
# Default to first available provider for the service type
service_providers = self.pricing_registry.get(service_type, {})
if service_providers:
return list(service_providers.keys())[0]
return "unknown"
def update_pricing(
self, service_type: str, provider: str, model: str, pricing_model: PricingModel
):
"""Update pricing for a specific service/provider/model combination"""
if service_type not in self.pricing_registry:
self.pricing_registry[service_type] = {}
if provider not in self.pricing_registry[service_type]:
self.pricing_registry[service_type][provider] = {}
self.pricing_registry[service_type][provider][model] = pricing_model
# Global cost calculator instance
cost_calculator = CostCalculator()

View file

@ -1,44 +0,0 @@
"""
Embeddings pricing models for different providers.
Prices are per token for embedding models.
"""
from decimal import Decimal
from typing import Dict
from api.services.configuration.registry import ServiceProviders
from .models import PricingModel
class EmbeddingPricingModel(PricingModel):
"""Pricing model for token-based embedding services."""
def __init__(self, token_price: Decimal):
"""Initialize with price per token.
Args:
token_price: Cost per token for embedding
"""
self.token_price = token_price
def calculate_cost(self, token_count: int) -> Decimal:
"""Calculate cost for embedding token usage."""
return Decimal(token_count) * self.token_price
# Embeddings pricing registry
EMBEDDINGS_PRICING: Dict[str, Dict[str, EmbeddingPricingModel]] = {
ServiceProviders.OPENAI: {
"text-embedding-3-small": EmbeddingPricingModel(
token_price=Decimal("0.02") / 1_000_000, # $0.02 per 1M tokens
),
"text-embedding-3-large": EmbeddingPricingModel(
token_price=Decimal("0.13") / 1_000_000, # $0.13 per 1M tokens
),
"text-embedding-ada-002": EmbeddingPricingModel(
token_price=Decimal("0.10") / 1_000_000, # $0.10 per 1M tokens (legacy)
),
},
}

View file

@ -1,143 +0,0 @@
"""
LLM pricing models for different providers.
Prices are per 1000 tokens for most models, with some newer models priced per million tokens.
"""
from decimal import Decimal
from typing import Dict
from api.services.configuration.registry import ServiceProviders
from .models import TokenPricingModel
# LLM pricing registry
LLM_PRICING: Dict[str, Dict[str, TokenPricingModel]] = {
ServiceProviders.OPENAI: {
"gpt-3.5-turbo": TokenPricingModel(
prompt_token_price=Decimal("0.0015") / 1000, # $0.0015 per 1K tokens
completion_token_price=Decimal("0.002") / 1000, # $0.002 per 1K tokens
),
"gpt-4": TokenPricingModel(
prompt_token_price=Decimal("0.03") / 1000, # $0.03 per 1K tokens
completion_token_price=Decimal("0.06") / 1000, # $0.06 per 1K tokens
),
"gpt-4.1": TokenPricingModel(
prompt_token_price=Decimal("2.00") / 1000000, # $2.00 per 1M tokens
completion_token_price=Decimal("8.00") / 1000000, # $8.00 per 1M tokens
),
"gpt-4.1-mini": TokenPricingModel(
prompt_token_price=Decimal("0.40") / 1000000, # $0.40 per 1M tokens
completion_token_price=Decimal("1.60") / 1000000, # $1.60 per 1M tokens
),
"gpt-4.1-nano": TokenPricingModel(
prompt_token_price=Decimal("0.10") / 1000000, # $0.10 per 1M tokens
completion_token_price=Decimal("0.40") / 1000000, # $0.40 per 1M tokens
),
"gpt-4.5-preview": TokenPricingModel(
prompt_token_price=Decimal("75.00") / 1000000, # $75.00 per 1M tokens
completion_token_price=Decimal("150.00") / 1000000, # $150.00 per 1M tokens
),
"gpt-4o": TokenPricingModel(
prompt_token_price=Decimal("2.50") / 1000000, # $2.50 per 1M tokens - FIXED
completion_token_price=Decimal("10.00")
/ 1000000, # $10.00 per 1M tokens - FIXED
),
"gpt-4o-audio-preview": TokenPricingModel(
prompt_token_price=Decimal("2.50") / 1000000, # $2.50 per 1M tokens
completion_token_price=Decimal("10.00") / 1000000, # $10.00 per 1M tokens
),
"gpt-4o-realtime-preview": TokenPricingModel(
prompt_token_price=Decimal("5.00") / 1000000, # $5.00 per 1M tokens
completion_token_price=Decimal("20.00") / 1000000, # $20.00 per 1M tokens
),
"gpt-4o-mini": TokenPricingModel(
prompt_token_price=Decimal("0.15") / 1000000, # $0.15 per 1M tokens
completion_token_price=Decimal("0.60") / 1000000, # $0.60 per 1M tokens
),
"gpt-4o-mini-audio-preview": TokenPricingModel(
prompt_token_price=Decimal("0.15") / 1000000, # $0.15 per 1M tokens
completion_token_price=Decimal("0.60") / 1000000, # $0.60 per 1M tokens
),
"gpt-4o-mini-realtime-preview": TokenPricingModel(
prompt_token_price=Decimal("0.60") / 1000000, # $0.60 per 1M tokens
completion_token_price=Decimal("2.40") / 1000000, # $2.40 per 1M tokens
),
"gpt-4o-search-preview": TokenPricingModel(
prompt_token_price=Decimal("2.50") / 1000000, # $2.50 per 1M tokens
completion_token_price=Decimal("10.00") / 1000000, # $10.00 per 1M tokens
),
"gpt-4o-mini-search-preview": TokenPricingModel(
prompt_token_price=Decimal("0.15") / 1000000, # $0.15 per 1M tokens
completion_token_price=Decimal("0.60") / 1000000, # $0.60 per 1M tokens
),
"o1": TokenPricingModel(
prompt_token_price=Decimal("15.00") / 1000000, # $15.00 per 1M tokens
completion_token_price=Decimal("60.00") / 1000000, # $60.00 per 1M tokens
),
"o1-pro": TokenPricingModel(
prompt_token_price=Decimal("150.00") / 1000000, # $150.00 per 1M tokens
completion_token_price=Decimal("600.00") / 1000000, # $600.00 per 1M tokens
),
"o1-mini": TokenPricingModel(
prompt_token_price=Decimal("1.10") / 1000000, # $1.10 per 1M tokens
completion_token_price=Decimal("4.40") / 1000000, # $4.40 per 1M tokens
),
"o3": TokenPricingModel(
prompt_token_price=Decimal("10.00") / 1000000, # $10.00 per 1M tokens
completion_token_price=Decimal("40.00") / 1000000, # $40.00 per 1M tokens
),
"o3-mini": TokenPricingModel(
prompt_token_price=Decimal("1.10") / 1000000, # $1.10 per 1M tokens
completion_token_price=Decimal("4.40") / 1000000, # $4.40 per 1M tokens
),
"o4-mini": TokenPricingModel(
prompt_token_price=Decimal("1.10") / 1000000, # $1.10 per 1M tokens
completion_token_price=Decimal("4.40") / 1000000, # $4.40 per 1M tokens
),
"computer-use-preview": TokenPricingModel(
prompt_token_price=Decimal("3.00") / 1000000, # $3.00 per 1M tokens
completion_token_price=Decimal("12.00") / 1000000, # $12.00 per 1M tokens
),
"gpt-image-1": TokenPricingModel(
prompt_token_price=Decimal("5.00") / 1000000, # $5.00 per 1M tokens
completion_token_price=Decimal("0") / 1000000, # No output pricing shown
),
"codex-mini-latest": TokenPricingModel(
prompt_token_price=Decimal("1.50") / 1000000, # $1.50 per 1M tokens
completion_token_price=Decimal("6.00") / 1000000, # $6.00 per 1M tokens
),
# Transcription models
"gpt-4o-transcribe": TokenPricingModel(
prompt_token_price=Decimal("2.50") / 1000000, # $2.50 per 1M tokens
completion_token_price=Decimal("10.00") / 1000000, # $10.00 per 1M tokens
),
"gpt-4o-mini-transcribe": TokenPricingModel(
prompt_token_price=Decimal("1.25") / 1000000, # $1.25 per 1M tokens
completion_token_price=Decimal("5.00") / 1000000, # $5.00 per 1M tokens
),
# TTS models with token-based pricing
"gpt-4o-mini-tts": TokenPricingModel(
prompt_token_price=Decimal("0.60") / 1000000, # $0.60 per 1M tokens
completion_token_price=Decimal("0")
/ 1000000, # No completion tokens for TTS
),
},
ServiceProviders.GROQ: {
"llama-3.3-70b-versatile": TokenPricingModel(
prompt_token_price=Decimal("0.00059") / 1000, # $0.00059 per 1K tokens
completion_token_price=Decimal("0.00079") / 1000, # $0.00079 per 1K tokens
),
"deepseek-r1-distill-llama-70b": TokenPricingModel(
prompt_token_price=Decimal("0.00059") / 1000, # Assuming similar pricing
completion_token_price=Decimal("0.00079") / 1000,
),
},
ServiceProviders.AZURE: {
"gpt-4.1-mini": TokenPricingModel(
prompt_token_price=Decimal("0.44") / 1000000, # $0.40 per 1M tokens
completion_token_price=Decimal("8.80")
/ 1000000, # $1.60 per 1M tokens if using data zone
)
},
}

View file

@ -1,89 +0,0 @@
"""
Base pricing models for different service types.
"""
from decimal import Decimal
from enum import Enum
from typing import Any, Dict
class CostType(Enum):
LLM_TOKENS = "llm_tokens"
TTS_CHARACTERS = "tts_characters"
STT_SECONDS = "stt_seconds"
class PricingModel:
"""Base class for pricing models"""
def calculate_cost(self, usage: Any) -> Decimal:
"""Calculate cost based on usage"""
raise NotImplementedError
class TokenPricingModel(PricingModel):
"""Pricing model for token-based services (LLM)"""
def __init__(
self,
prompt_token_price: Decimal,
completion_token_price: Decimal,
cache_read_discount: Decimal = Decimal("0.5"), # 50% discount for cache reads
cache_creation_multiplier: Decimal = Decimal(
"1.25"
), # 25% premium for cache creation
):
self.prompt_token_price = prompt_token_price
self.completion_token_price = completion_token_price
self.cache_read_discount = cache_read_discount
self.cache_creation_multiplier = cache_creation_multiplier
def calculate_cost(self, usage: Dict[str, int]) -> Decimal:
"""Calculate cost for LLM token usage"""
prompt_tokens = usage.get("prompt_tokens", 0)
completion_tokens = usage.get("completion_tokens", 0)
cache_read_tokens = usage.get("cache_read_input_tokens") or 0
cache_creation_tokens = usage.get("cache_creation_input_tokens") or 0
# Base cost
prompt_cost = Decimal(prompt_tokens) * self.prompt_token_price
completion_cost = Decimal(completion_tokens) * self.completion_token_price
# Cache adjustments
cache_read_savings = (
Decimal(cache_read_tokens)
* self.prompt_token_price
* self.cache_read_discount
)
cache_creation_premium = (
Decimal(cache_creation_tokens)
* self.prompt_token_price
* (self.cache_creation_multiplier - 1)
)
total_cost = (
prompt_cost + completion_cost - cache_read_savings + cache_creation_premium
)
return max(total_cost, Decimal("0")) # Ensure non-negative
class CharacterPricingModel(PricingModel):
"""Pricing model for character-based services (TTS)"""
def __init__(self, character_price: Decimal):
self.character_price = character_price
def calculate_cost(self, character_count: int) -> Decimal:
"""Calculate cost for TTS character usage"""
return Decimal(character_count) * self.character_price
class TimePricingModel(PricingModel):
"""Pricing model for time-based services (STT)"""
def __init__(self, second_price: Decimal):
self.second_price = second_price
def calculate_cost(self, seconds: float) -> Decimal:
"""Calculate cost for STT time usage"""
return Decimal(str(seconds)) * self.second_price

View file

@ -1,18 +0,0 @@
"""
Main pricing registry that combines all service type pricing models.
"""
from typing import Dict
from .embeddings import EMBEDDINGS_PRICING
from .llm import LLM_PRICING
from .stt import STT_PRICING
from .tts import TTS_PRICING
# Combined pricing registry for all service types
PRICING_REGISTRY: Dict = {
"llm": LLM_PRICING,
"tts": TTS_PRICING,
"stt": STT_PRICING,
"embeddings": EMBEDDINGS_PRICING,
}

View file

@ -1,13 +0,0 @@
"""Format workflow run usage for public API responses."""
def format_public_usage_info(usage_info: dict | None) -> dict | None:
if not usage_info:
return None
return {
"llm": usage_info.get("llm") or {},
"tts": usage_info.get("tts") or {},
"stt": usage_info.get("stt") or {},
"call_duration_seconds": usage_info.get("call_duration_seconds"),
}

View file

@ -1,26 +0,0 @@
"""
STT (Speech-to-Text) pricing models for different providers.
Prices are per second for STT services.
"""
from decimal import Decimal
from typing import Dict
from api.services.configuration.registry import ServiceProviders
from .models import TimePricingModel
# STT pricing registry
STT_PRICING: Dict[str, Dict[str, TimePricingModel]] = {
ServiceProviders.DEEPGRAM: {
"nova-3-general": TimePricingModel(Decimal("0.0077") / 60),
"nova-2": TimePricingModel(Decimal("0.0058") / 60),
"default": TimePricingModel(Decimal("0.0077") / 60),
},
ServiceProviders.OPENAI: {
"gpt-4o-transcribe": TimePricingModel(Decimal("0.015") / 60),
"default": TimePricingModel(Decimal("0.015") / 60),
},
"default": {"default": TimePricingModel(Decimal("0.0077") / 60)},
}

View file

@ -1,30 +0,0 @@
"""
TTS (Text-to-Speech) pricing models for different providers.
Prices are per character for TTS services.
"""
from decimal import Decimal
from typing import Dict
from api.services.configuration.registry import ServiceProviders
from .models import CharacterPricingModel
# TTS pricing registry
TTS_PRICING: Dict[str, Dict[str, CharacterPricingModel]] = {
ServiceProviders.OPENAI: {
"gpt-4o-mini-tts": CharacterPricingModel(Decimal("0.6") / 1_00_00_000),
"default": CharacterPricingModel(Decimal("0.6") / 1_00_00_000),
},
ServiceProviders.DEEPGRAM: {
"aura-2": CharacterPricingModel(Decimal("0.030") / 1_000),
"aura-1": CharacterPricingModel(Decimal("0.015") / 1_000),
"default": CharacterPricingModel(Decimal("0.030") / 1_000),
},
ServiceProviders.ELEVENLABS: {
# 6400 usd per 250*1e6 characters
"default": CharacterPricingModel(Decimal("0.0256") / 1_000)
},
"default": {"default": CharacterPricingModel(Decimal("0.030") / 1_000)},
}

View file

@ -1,230 +0,0 @@
from decimal import Decimal
from loguru import logger
from api.db import db_client
from api.enums import WorkflowRunMode
from api.services.pricing.cost_calculator import cost_calculator
from api.services.telephony.factory import get_telephony_provider_for_run
async def _fetch_telephony_cost(workflow_run) -> dict | None:
"""Fetch telephony call cost. Returns a dict with cost_usd and provider_name, or None."""
if (
workflow_run.mode
not in [WorkflowRunMode.TWILIO.value, WorkflowRunMode.VONAGE.value]
or not workflow_run.cost_info
):
return None
call_id = workflow_run.cost_info.get("call_id")
if not call_id:
logger.warning(f"call_id not found in cost_info")
return None
provider_name = workflow_run.mode.lower() if workflow_run.mode else ""
workflow = await db_client.get_workflow_by_id(workflow_run.workflow_id)
if not workflow:
logger.warning("Workflow not found for workflow run")
raise Exception("Workflow not found")
provider = await get_telephony_provider_for_run(
workflow_run, workflow.organization_id
)
call_cost_info = await provider.get_call_cost(call_id)
if call_cost_info.get("status") == "error":
logger.error(
f"Failed to fetch {provider_name} call cost: {call_cost_info.get('error')}"
)
return None
cost_usd = call_cost_info.get("cost_usd", 0.0)
logger.info(
f"{provider_name.title()} call cost: ${cost_usd:.6f} USD for call {call_id}"
)
return {"cost_usd": cost_usd, "provider_name": provider_name}
async def _update_organization_usage(
org, dograh_tokens: float, duration_seconds: float, charge_usd: float | None
) -> None:
"""Update organization usage after a workflow run."""
org_id = org.id
await db_client.update_usage_after_run(
org_id, dograh_tokens, duration_seconds, charge_usd
)
if charge_usd is not None:
logger.info(
f"Updated organization usage with ${charge_usd:.2f} USD ({dograh_tokens} Dograh Tokens) and {duration_seconds}s duration for org {org_id}"
)
else:
logger.info(
f"Updated organization usage with {dograh_tokens} Dograh Tokens and {duration_seconds}s duration for org {org_id}"
)
async def _get_pricing_organization(workflow_run):
workflow = getattr(workflow_run, "workflow", None)
organization_id = getattr(workflow, "organization_id", None)
if organization_id is None and workflow and workflow.user:
organization_id = workflow.user.selected_organization_id
if organization_id is None:
return None
return await db_client.get_organization_by_id(organization_id)
async def _build_usage_cost_snapshot(
usage_info: dict | None,
*,
workflow_run=None,
include_telephony_cost: bool = False,
organization=None,
calculated_at: str | None = None,
) -> dict | None:
if not usage_info:
logger.warning("No usage info available for workflow run")
return None
cost_breakdown = cost_calculator.calculate_total_cost(usage_info)
if include_telephony_cost and workflow_run is not None:
try:
telephony_cost = await _fetch_telephony_cost(workflow_run)
if telephony_cost:
telephony_cost_usd = telephony_cost["cost_usd"]
provider_name = telephony_cost["provider_name"]
cost_breakdown["telephony_call"] = telephony_cost_usd
cost_breakdown[f"{provider_name}_call"] = telephony_cost_usd
cost_breakdown["total"] = (
float(cost_breakdown["total"]) + telephony_cost_usd
)
except Exception as e:
logger.error(f"Failed to fetch telephony call cost: {e}")
# Don't fail the whole cost calculation if telephony API fails
total_cost_usd = Decimal(str(cost_breakdown["total"]))
dograh_tokens = float(total_cost_usd * Decimal("100"))
if organization is None and workflow_run is not None:
organization = await _get_pricing_organization(workflow_run)
charge_usd = None
if organization and organization.price_per_second_usd:
duration_seconds = usage_info.get("call_duration_seconds", 0)
charge_usd = float(
Decimal(str(duration_seconds))
* Decimal(str(organization.price_per_second_usd))
)
cost_info = {
"cost_breakdown": cost_breakdown,
"total_cost_usd": float(total_cost_usd),
"dograh_token_usage": dograh_tokens,
"calculated_at": calculated_at
or (workflow_run.created_at.isoformat() if workflow_run is not None else None),
"call_duration_seconds": usage_info.get("call_duration_seconds", 0),
}
if charge_usd is not None:
cost_info["charge_usd"] = charge_usd
cost_info["price_per_second_usd"] = organization.price_per_second_usd
return cost_info
async def build_workflow_run_cost_info(workflow_run) -> dict | None:
cost_info = await _build_usage_cost_snapshot(
workflow_run.usage_info,
workflow_run=workflow_run,
include_telephony_cost=True,
calculated_at=workflow_run.created_at.isoformat(),
)
if cost_info is None:
return None
return {
**(workflow_run.cost_info or {}),
**cost_info,
}
async def save_workflow_run_cost_info(
workflow_run_id: int, cost_info: dict | None
) -> None:
if cost_info is None:
return
await db_client.update_workflow_run(run_id=workflow_run_id, cost_info=cost_info)
async def apply_workflow_run_usage_to_organization(
workflow_run, cost_info: dict | None
) -> None:
if cost_info is None:
return
org = await _get_pricing_organization(workflow_run)
if not org:
return
await _update_organization_usage(
org,
float(cost_info.get("dograh_token_usage") or 0),
float(cost_info.get("call_duration_seconds") or 0),
cost_info.get("charge_usd"),
)
async def apply_usage_delta_to_organization(
workflow_run, usage_info: dict | None
) -> dict | None:
org = await _get_pricing_organization(workflow_run)
if not org:
return None
cost_info = await _build_usage_cost_snapshot(usage_info, organization=org)
if cost_info is None:
return None
await _update_organization_usage(
org,
float(cost_info.get("dograh_token_usage") or 0),
float(cost_info.get("call_duration_seconds") or 0),
cost_info.get("charge_usd"),
)
return cost_info
async def calculate_workflow_run_cost(workflow_run_id: int):
logger.debug("Calculating cost for workflow run")
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
if not workflow_run:
logger.warning("Workflow run not found")
return
try:
cost_info = await build_workflow_run_cost_info(workflow_run)
if cost_info is None:
return
await save_workflow_run_cost_info(workflow_run_id, cost_info)
try:
await apply_workflow_run_usage_to_organization(workflow_run, cost_info)
except Exception as e:
org = await _get_pricing_organization(workflow_run)
if org:
logger.error(
f"Failed to update organization usage for org {org.id}: {e}"
)
else:
logger.error(f"Failed to update organization usage: {e}")
# Don't fail the whole cost calculation if usage update fails
logger.info(
f"Calculated cost for workflow run: ${cost_info['total_cost_usd']:.6f} USD ({cost_info['dograh_token_usage']} Dograh Tokens)"
)
except Exception as e:
logger.error(f"Error calculating cost for workflow run: {e}")
raise

View file

@ -53,7 +53,7 @@ def build_run_report_csv(runs: List[Any]) -> io.StringIO:
for run in runs:
initial = run.initial_context or {}
gathered = run.gathered_context or {}
cost = run.cost_info or {}
usage = run.usage_info or {}
call_tags = gathered.get("call_tags", [])
if isinstance(call_tags, list):
@ -67,7 +67,7 @@ def build_run_report_csv(runs: List[Any]) -> io.StringIO:
run.created_at.isoformat() if run.created_at else "",
initial.get("phone_number", ""),
gathered.get("mapped_call_disposition", ""),
cost.get("call_duration_seconds", ""),
usage.get("call_duration_seconds", ""),
]
extracted = gathered.get("extracted_variables", {})

View file

@ -66,34 +66,6 @@ async def handle_vonage_events(
logger.error(f"[run {workflow_run_id}] Workflow run not found")
return {"status": "error", "message": "Workflow run not found"}
# For a completed call that includes cost info, capture it immediately
if event_data.get("status") == "completed":
# Vonage sometimes includes price info in the webhook
if "price" in event_data or "rate" in event_data:
try:
if workflow_run.cost_info:
# Store immediate cost info if available
cost_info = workflow_run.cost_info.copy()
if "price" in event_data:
cost_info["vonage_webhook_price"] = float(event_data["price"])
if "rate" in event_data:
cost_info["vonage_webhook_rate"] = float(event_data["rate"])
if "duration" in event_data:
cost_info["vonage_webhook_duration"] = int(
event_data["duration"]
)
await db_client.update_workflow_run(
run_id=workflow_run_id, cost_info=cost_info
)
logger.info(
f"[run {workflow_run_id}] Captured Vonage cost info from webhook"
)
except Exception as e:
logger.error(
f"[run {workflow_run_id}] Failed to capture Vonage cost from webhook: {e}"
)
# Get workflow and provider
workflow = await db_client.get_workflow_by_id(workflow_run.workflow_id)
if not workflow:

View file

@ -35,6 +35,7 @@ import asyncio
from loguru import logger
from api.services.managed_model_services import MPS_CORRELATION_ID_CONTEXT_KEY
from api.services.workflow import pipecat_engine_callbacks as engine_callbacks
from api.services.workflow.mcp_tool_session import McpToolSession
from api.services.workflow.pipecat_engine_context_composer import (
@ -382,6 +383,9 @@ class PipecatEngine:
embeddings_provider=self._embeddings_provider,
embeddings_endpoint=self._embeddings_endpoint,
embeddings_api_version=self._embeddings_api_version,
correlation_id=self._call_context_vars.get(
MPS_CORRELATION_ID_CONTEXT_KEY
),
tracing_context=self._get_otel_context(),
)

View file

@ -0,0 +1,41 @@
"""Format workflow run usage for public API responses."""
def format_public_usage_info(usage_info: dict | None) -> dict | None:
if not usage_info:
return None
return {
"llm": usage_info.get("llm") or {},
"tts": usage_info.get("tts") or {},
"stt": usage_info.get("stt") or {},
"call_duration_seconds": usage_info.get("call_duration_seconds"),
}
def format_public_cost_info(
cost_info: dict | None, usage_info: dict | None
) -> dict | None:
"""Return the legacy response shape without doing local cost accounting."""
duration = None
if usage_info and usage_info.get("call_duration_seconds") is not None:
duration = int(round(usage_info.get("call_duration_seconds") or 0))
elif cost_info and cost_info.get("call_duration_seconds") is not None:
duration = int(round(cost_info.get("call_duration_seconds") or 0))
dograh_token_usage = 0
if cost_info:
if "dograh_token_usage" in cost_info:
dograh_token_usage = cost_info.get("dograh_token_usage") or 0
elif "total_cost_usd" in cost_info:
dograh_token_usage = round(
float(cost_info.get("total_cost_usd", 0)) * 100, 2
)
if duration is None and dograh_token_usage == 0:
return None
return {
"dograh_token_usage": dograh_token_usage,
"call_duration_seconds": duration,
}

View file

@ -421,7 +421,19 @@ async def execute_text_chat_pending_turn(
if user_config.llm is None:
raise ValueError("Text chat requires an LLM configuration")
llm = create_llm_service(user_config)
from api.services.managed_model_services import (
MPS_CORRELATION_ID_CONTEXT_KEY,
ensure_mps_correlation_id,
)
base_initial_context = dict(workflow_run.initial_context or {})
mps_correlation_id = await ensure_mps_correlation_id(
ai_model_config=user_config,
workflow_run_id=workflow_run_id,
initial_context=base_initial_context,
)
llm = create_llm_service(user_config, correlation_id=mps_correlation_id)
inference_llm = llm
runtime_configuration = {
@ -429,9 +441,15 @@ async def execute_text_chat_pending_turn(
"llm_model": user_config.llm.model,
}
initial_context = {
**(workflow_run.initial_context or {}),
**base_initial_context,
"runtime_configuration": runtime_configuration,
}
if mps_correlation_id:
initial_context[MPS_CORRELATION_ID_CONTEXT_KEY] = mps_correlation_id
await db_client.update_workflow_run(
workflow_run_id,
initial_context=initial_context,
)
workflow_graph = WorkflowGraph(
ReactFlowDTO.model_validate(run_definition.workflow_json)

View file

@ -4,17 +4,11 @@ from datetime import UTC, datetime
from typing import Any
from uuid import uuid4
from loguru import logger
from api.db import db_client
from api.db.models import WorkflowRunTextSessionModel
from api.db.workflow_run_text_session_client import (
WorkflowRunTextSessionRevisionConflictError,
)
from api.services.pricing.workflow_run_cost import (
apply_usage_delta_to_organization,
build_workflow_run_cost_info,
)
from api.services.workflow.text_chat_logs import (
build_text_chat_realtime_feedback_events,
)
@ -261,20 +255,6 @@ async def execute_pending_text_chat_turn(
state=execution.state,
is_completed=execution.is_completed,
)
workflow_run = await db_client.get_workflow_run_by_id(run_id)
if workflow_run:
try:
# Apply the per-turn delta so org usage tracks cumulative run cost
# without replaying the full session totals on every turn.
await apply_usage_delta_to_organization(workflow_run, execution.usage)
except Exception as e:
logger.error(
f"Failed to update organization usage for text chat run {run_id}: {e}"
)
cost_info = await build_workflow_run_cost_info(workflow_run)
if cost_info is not None:
await db_client.update_workflow_run(run_id, cost_info=cost_info)
return await _reload_text_chat_session(run_id)

View file

@ -29,6 +29,7 @@ async def retrieve_from_knowledge_base(
embeddings_provider: Optional[str] = None,
embeddings_endpoint: Optional[str] = None,
embeddings_api_version: Optional[str] = None,
correlation_id: Optional[str] = None,
tracing_context=None,
) -> Dict[str, Any]:
"""Retrieve relevant information from the knowledge base using vector similarity search.
@ -75,6 +76,7 @@ async def retrieve_from_knowledge_base(
embeddings_provider,
embeddings_endpoint,
embeddings_api_version,
correlation_id,
)
# Create span with parent context
@ -115,6 +117,7 @@ async def retrieve_from_knowledge_base(
embeddings_provider,
embeddings_endpoint,
embeddings_api_version,
correlation_id,
)
# Add result metadata to span
@ -192,6 +195,7 @@ async def retrieve_from_knowledge_base(
embeddings_provider,
embeddings_endpoint,
embeddings_api_version,
correlation_id,
)
else:
# Tracing is disabled - perform retrieval without tracing
@ -206,6 +210,7 @@ async def retrieve_from_knowledge_base(
embeddings_provider,
embeddings_endpoint,
embeddings_api_version,
correlation_id,
)
@ -220,6 +225,7 @@ async def _perform_retrieval(
embeddings_provider: Optional[str] = None,
embeddings_endpoint: Optional[str] = None,
embeddings_api_version: Optional[str] = None,
correlation_id: Optional[str] = None,
) -> Dict[str, Any]:
"""Internal function to perform the actual retrieval operation.
@ -272,11 +278,20 @@ async def _perform_retrieval(
api_version=embeddings_api_version or "2024-02-15-preview",
)
else:
default_headers = None
if (
embeddings_provider == ServiceProviders.DOGRAH.value
and correlation_id
):
default_headers = {
"X-Dograh-Correlation-Id": correlation_id,
}
embedding_service = OpenAIEmbeddingService(
db_client=db_client,
api_key=embeddings_api_key,
model_id=embeddings_model or "text-embedding-3-small",
base_url=embeddings_base_url,
default_headers=default_headers,
)
results = await embedding_service.search_similar_chunks(

View file

@ -0,0 +1,111 @@
"""Workflow-run billing hooks.
Dograh does not rate or deduct credits locally. MPS owns credit accounting.
For hosted deployments, Dograh reports completed platform usage to MPS.
When a server-minted MPS correlation id exists, MPS uses model-service usage
as the canonical duration. Otherwise Dograh reports the completed run duration.
"""
from typing import Any
from loguru import logger
from api.constants import DEPLOYMENT_MODE
from api.db import db_client
from api.services.managed_model_services import get_mps_correlation_id
from api.services.mps_service_key_client import mps_service_key_client
def _workflow_run_organization_id(workflow_run) -> int | None:
workflow = getattr(workflow_run, "workflow", None)
return getattr(workflow, "organization_id", None)
def _duration_seconds_from_usage_info(workflow_run) -> float | None:
usage_info: dict[str, Any] = getattr(workflow_run, "usage_info", None) or {}
duration = usage_info.get("call_duration_seconds")
try:
duration_seconds = float(duration)
except (TypeError, ValueError):
return None
return duration_seconds if duration_seconds > 0 else None
async def _organization_uses_mps_billing_v2(organization_id: int) -> bool:
account = await mps_service_key_client.get_billing_account_status(
organization_id=organization_id
)
return bool(account and account.get("billing_mode") == "v2")
async def report_workflow_run_platform_usage(workflow_run) -> None:
"""Report hosted platform usage for a completed workflow run to MPS."""
if DEPLOYMENT_MODE == "oss":
return
if not getattr(workflow_run, "is_completed", False):
return
organization_id = _workflow_run_organization_id(workflow_run)
if organization_id is None:
logger.warning(
"Skipping platform usage report for workflow run {}: no organization_id",
workflow_run.id,
)
return
correlation_id = get_mps_correlation_id(
getattr(workflow_run, "initial_context", None)
)
duration_seconds = (
None if correlation_id else _duration_seconds_from_usage_info(workflow_run)
)
if not correlation_id and duration_seconds is None:
logger.warning(
"Skipping platform usage report for workflow run {}: no billable duration",
workflow_run.id,
)
return
try:
if not await _organization_uses_mps_billing_v2(organization_id):
return
result = await mps_service_key_client.report_platform_usage(
organization_id=organization_id,
correlation_id=correlation_id,
duration_seconds=duration_seconds,
workflow_run_id=workflow_run.id,
metadata={
"source": "workflow_run_completion",
"workflow_id": getattr(workflow_run, "workflow_id", None),
"duration_source": (
"mps_correlation" if correlation_id else "dograh_usage_info"
),
},
)
logger.info(
"Reported platform usage for workflow run {} to MPS: {}",
workflow_run.id,
result,
)
except Exception as e:
logger.error(
"Failed to report platform usage for workflow run {}: {}",
workflow_run.id,
e,
)
async def report_completed_workflow_run_platform_usage(workflow_run_id: int) -> None:
"""Load a completed workflow run and report platform usage to MPS."""
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
if not workflow_run:
logger.warning(
"Skipping platform usage report: workflow run {} not found",
workflow_run_id,
)
return
await report_workflow_run_platform_usage(workflow_run)