mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-28 08:49:42 +02:00
fix: remove cost calculation from dograh codebase
This commit is contained in:
parent
7d4e2e06a9
commit
8f241b89d2
39 changed files with 1067 additions and 1460 deletions
|
|
@ -75,21 +75,21 @@ class UserConfigurationValidator:
|
|||
status_list = []
|
||||
|
||||
status_list.extend(self._validate_service(configuration.llm, "llm"))
|
||||
status_list.extend(self._validate_service(configuration.stt, "stt"))
|
||||
status_list.extend(self._validate_service(configuration.tts, "tts"))
|
||||
# Embeddings is optional - only validate if configured
|
||||
status_list.extend(
|
||||
self._validate_service(
|
||||
configuration.embeddings, "embeddings", required=False
|
||||
)
|
||||
)
|
||||
# Realtime is optional - only validate if is_realtime is enabled
|
||||
if configuration.is_realtime:
|
||||
status_list.extend(
|
||||
self._validate_service(
|
||||
configuration.realtime, "realtime", required=True
|
||||
)
|
||||
)
|
||||
else:
|
||||
status_list.extend(self._validate_service(configuration.stt, "stt"))
|
||||
status_list.extend(self._validate_service(configuration.tts, "tts"))
|
||||
# Embeddings is optional - only validate if configured
|
||||
status_list.extend(
|
||||
self._validate_service(
|
||||
configuration.embeddings, "embeddings", required=False
|
||||
)
|
||||
)
|
||||
|
||||
if status_list:
|
||||
raise ValueError(status_list)
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ This client communicates with the Model Proxy Service (MPS) for service key mana
|
|||
Service keys are stored and managed entirely in MPS, not in the local database.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import List, Optional
|
||||
|
||||
import httpx
|
||||
|
|
@ -420,6 +421,34 @@ class MPSServiceKeyClient:
|
|||
response=response,
|
||||
)
|
||||
|
||||
async def get_billing_account_status(
|
||||
self,
|
||||
organization_id: int,
|
||||
created_by: Optional[str] = None,
|
||||
) -> Optional[dict]:
|
||||
"""Get an existing MPS v2 billing account without creating one."""
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
response = await client.get(
|
||||
f"{self.base_url}/api/v1/billing/accounts/{organization_id}/status",
|
||||
headers=self._get_headers(
|
||||
organization_id=organization_id,
|
||||
created_by=created_by,
|
||||
),
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
|
||||
logger.error(
|
||||
"Failed to get MPS billing account status: "
|
||||
f"{response.status_code} - {response.text}"
|
||||
)
|
||||
raise httpx.HTTPStatusError(
|
||||
f"Failed to get MPS billing account status: {response.text}",
|
||||
request=response.request,
|
||||
response=response,
|
||||
)
|
||||
|
||||
async def create_correlation_id(
|
||||
self,
|
||||
*,
|
||||
|
|
@ -454,6 +483,76 @@ class MPSServiceKeyClient:
|
|||
response=response,
|
||||
)
|
||||
|
||||
async def report_platform_usage(
|
||||
self,
|
||||
*,
|
||||
organization_id: int,
|
||||
correlation_id: Optional[str] = None,
|
||||
duration_seconds: Optional[float] = None,
|
||||
workflow_run_id: int | None = None,
|
||||
metadata: Optional[dict] = None,
|
||||
max_attempts: int = 3,
|
||||
) -> dict:
|
||||
"""Report hosted Dograh platform usage for a completed workflow run."""
|
||||
if DEPLOYMENT_MODE == "oss":
|
||||
raise ValueError("OSS deployments must not report platform usage to MPS")
|
||||
if not correlation_id and duration_seconds is None:
|
||||
raise ValueError(
|
||||
"Platform usage reports require correlation_id or duration_seconds"
|
||||
)
|
||||
|
||||
payload: dict = {
|
||||
"metadata": metadata or {},
|
||||
}
|
||||
if correlation_id:
|
||||
payload["correlation_id"] = correlation_id
|
||||
if duration_seconds is not None:
|
||||
payload["duration_seconds"] = duration_seconds
|
||||
if workflow_run_id is not None:
|
||||
payload["workflow_run_id"] = workflow_run_id
|
||||
|
||||
max_attempts = max(1, max_attempts)
|
||||
last_response: httpx.Response | None = None
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
for attempt in range(1, max_attempts + 1):
|
||||
response = await client.post(
|
||||
(
|
||||
f"{self.base_url}/api/v1/billing/accounts/"
|
||||
f"{organization_id}/platform-usage"
|
||||
),
|
||||
json=payload,
|
||||
headers=self._get_headers(organization_id=organization_id),
|
||||
)
|
||||
last_response = response
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
|
||||
should_retry = (
|
||||
response.status_code == 409
|
||||
and "usage_not_ready" in response.text
|
||||
and attempt < max_attempts
|
||||
)
|
||||
if should_retry:
|
||||
await asyncio.sleep(attempt)
|
||||
continue
|
||||
|
||||
logger.error(
|
||||
"Failed to report platform usage: "
|
||||
f"{response.status_code} - {response.text}"
|
||||
)
|
||||
raise httpx.HTTPStatusError(
|
||||
f"Failed to report platform usage: {response.text}",
|
||||
request=response.request,
|
||||
response=response,
|
||||
)
|
||||
|
||||
raise httpx.HTTPStatusError(
|
||||
"Failed to report platform usage",
|
||||
request=last_response.request,
|
||||
response=last_response,
|
||||
)
|
||||
|
||||
async def transcribe_audio(
|
||||
self,
|
||||
audio_data: bytes,
|
||||
|
|
|
|||
|
|
@ -162,15 +162,13 @@ async def run_pipeline_telephony(
|
|||
workflow_id: Workflow being executed.
|
||||
workflow_run_id: Workflow run row.
|
||||
user_id: Owner of the workflow.
|
||||
call_id: Provider call identifier (stored in cost_info for billing).
|
||||
call_id: Provider call identifier.
|
||||
transport_kwargs: Provider-specific kwargs forwarded to the transport
|
||||
factory (e.g. stream_sid + call_sid for Twilio).
|
||||
"""
|
||||
logger.debug(f"Running {provider_name} pipeline for workflow_run {workflow_run_id}")
|
||||
set_current_run_id(workflow_run_id)
|
||||
|
||||
await db_client.update_workflow_run(workflow_run_id, cost_info={"call_id": call_id})
|
||||
|
||||
workflow = await db_client.get_workflow(workflow_id, user_id)
|
||||
if workflow:
|
||||
set_current_org_id(workflow.organization_id)
|
||||
|
|
|
|||
|
|
@ -1,76 +0,0 @@
|
|||
# Pricing Module
|
||||
|
||||
This module contains pricing models and registries for different AI services used in workflow cost calculations.
|
||||
|
||||
## Structure
|
||||
|
||||
```
|
||||
pricing/
|
||||
├── __init__.py # Main module exports
|
||||
├── models.py # Base pricing model classes
|
||||
├── llm.py # LLM pricing configurations
|
||||
├── tts.py # TTS pricing configurations
|
||||
├── stt.py # STT pricing configurations
|
||||
├── registry.py # Combined pricing registry
|
||||
└── README.md # This file
|
||||
```
|
||||
|
||||
## Pricing Models
|
||||
|
||||
### TokenPricingModel
|
||||
Used for LLM services that charge based on tokens:
|
||||
- `prompt_token_price`: Cost per prompt token
|
||||
- `completion_token_price`: Cost per completion token
|
||||
- `cache_read_discount`: Discount for cache read tokens (default 50%)
|
||||
- `cache_creation_multiplier`: Premium for cache creation tokens (default 25%)
|
||||
|
||||
### CharacterPricingModel
|
||||
Used for TTS services that charge based on character count:
|
||||
- `character_price`: Cost per character
|
||||
|
||||
### TimePricingModel
|
||||
Used for STT services that charge based on time:
|
||||
- `second_price`: Cost per second
|
||||
|
||||
## Adding New Pricing
|
||||
|
||||
### Adding a New LLM Model
|
||||
Edit `llm.py` and add the model to the appropriate provider:
|
||||
|
||||
```python
|
||||
ServiceProviders.OPENAI: {
|
||||
"new-model": TokenPricingModel(
|
||||
prompt_token_price=Decimal("2.00") / 1000000,
|
||||
completion_token_price=Decimal("8.00") / 1000000,
|
||||
),
|
||||
# ... existing models
|
||||
}
|
||||
```
|
||||
|
||||
### Adding a New Provider
|
||||
1. Add pricing configurations to the appropriate service file (llm.py, tts.py, stt.py)
|
||||
2. The registry will automatically include them
|
||||
|
||||
### Adding a New Service Type
|
||||
1. Create a new pricing file (e.g., `image.py`)
|
||||
2. Define the pricing models
|
||||
3. Import and add to `registry.py`
|
||||
|
||||
## Usage
|
||||
|
||||
The pricing registry is automatically imported and used by the cost calculator:
|
||||
|
||||
```python
|
||||
from api.services.pricing import PRICING_REGISTRY
|
||||
from api.services.workflow.cost_calculator import cost_calculator
|
||||
|
||||
# The cost calculator uses the pricing registry automatically
|
||||
result = cost_calculator.calculate_total_cost(usage_info)
|
||||
```
|
||||
|
||||
## Maintenance
|
||||
|
||||
- Update pricing when providers change their rates
|
||||
- All prices should use `Decimal` for precision
|
||||
- Include comments with current pricing from provider documentation
|
||||
- Test changes with existing test suite
|
||||
|
|
@ -1,9 +0,0 @@
|
|||
"""
|
||||
Pricing module for workflow cost calculation.
|
||||
|
||||
This module contains pricing models and registries for different AI services.
|
||||
"""
|
||||
|
||||
from .registry import PRICING_REGISTRY
|
||||
|
||||
__all__ = ["PRICING_REGISTRY"]
|
||||
|
|
@ -1,228 +0,0 @@
|
|||
"""
|
||||
Cost Calculator for Workflow Runs
|
||||
|
||||
This module provides a comprehensive cost calculation system for workflow runs based on usage metrics
|
||||
from different AI service providers (OpenAI, Groq, Deepgram, etc.).
|
||||
|
||||
Features:
|
||||
- Token-based pricing for LLM services with cache optimization support
|
||||
- Character-based pricing for TTS services
|
||||
- Time-based pricing for STT services
|
||||
- Configurable pricing models that can be updated
|
||||
- Support for multiple providers and models
|
||||
- Automatic provider inference from model names
|
||||
- JSON serialization support for database storage
|
||||
|
||||
Usage:
|
||||
from api.tasks.cost_calculator import cost_calculator
|
||||
|
||||
usage_info = {
|
||||
"llm": {
|
||||
"processor_name|||gpt-4o": {
|
||||
"prompt_tokens": 1000,
|
||||
"completion_tokens": 500,
|
||||
"total_tokens": 1500,
|
||||
"cache_read_input_tokens": 0,
|
||||
"cache_creation_input_tokens": 0
|
||||
}
|
||||
},
|
||||
"tts": {
|
||||
"processor_name|||aura-2-helena-en": 2000 # character count
|
||||
}
|
||||
}
|
||||
|
||||
cost_breakdown = cost_calculator.calculate_total_cost(usage_info)
|
||||
print(f"Total cost: ${cost_breakdown['total']:.6f}")
|
||||
"""
|
||||
|
||||
from decimal import Decimal
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from api.services.pricing import PRICING_REGISTRY
|
||||
from api.services.pricing.models import (
|
||||
PricingModel,
|
||||
)
|
||||
|
||||
|
||||
class CostCalculator:
|
||||
"""Main cost calculator class"""
|
||||
|
||||
def __init__(self, pricing_registry: Dict = None):
|
||||
self.pricing_registry = pricing_registry or PRICING_REGISTRY
|
||||
|
||||
def get_pricing_model(
|
||||
self, service_type: str, provider: str, model: str
|
||||
) -> Optional[PricingModel]:
|
||||
"""Get pricing model for a specific service, provider, and model"""
|
||||
try:
|
||||
service_pricing = self.pricing_registry.get(service_type, {})
|
||||
|
||||
# Try to get pricing for the specific provider
|
||||
provider_pricing = service_pricing.get(provider, {})
|
||||
pricing_model = provider_pricing.get(model) or provider_pricing.get(
|
||||
"default"
|
||||
)
|
||||
|
||||
if pricing_model:
|
||||
return pricing_model
|
||||
|
||||
# If not found, try the "default" provider for this service type
|
||||
default_provider_pricing = service_pricing.get("default", {})
|
||||
return default_provider_pricing.get(model) or default_provider_pricing.get(
|
||||
"default"
|
||||
)
|
||||
|
||||
except (KeyError, AttributeError):
|
||||
return None
|
||||
|
||||
def calculate_llm_cost(
|
||||
self, provider: str, model: str, usage: Dict[str, int]
|
||||
) -> Decimal:
|
||||
"""Calculate cost for LLM usage"""
|
||||
pricing_model = self.get_pricing_model("llm", provider, model)
|
||||
if not pricing_model:
|
||||
return Decimal("0")
|
||||
return pricing_model.calculate_cost(usage)
|
||||
|
||||
def calculate_tts_cost(
|
||||
self, provider: str, model: str, character_count: int
|
||||
) -> Decimal:
|
||||
"""Calculate cost for TTS usage"""
|
||||
pricing_model = self.get_pricing_model("tts", provider, model)
|
||||
if not pricing_model:
|
||||
return Decimal("0")
|
||||
return pricing_model.calculate_cost(character_count)
|
||||
|
||||
def calculate_stt_cost(self, provider: str, model: str, seconds: float) -> Decimal:
|
||||
"""Calculate cost for STT usage"""
|
||||
pricing_model = self.get_pricing_model("stt", provider, model)
|
||||
if not pricing_model:
|
||||
return Decimal("0")
|
||||
return pricing_model.calculate_cost(seconds)
|
||||
|
||||
def calculate_total_cost(self, usage_info: Dict) -> Dict[str, Any]:
|
||||
llm_cost_total = Decimal("0")
|
||||
tts_cost_total = Decimal("0")
|
||||
stt_cost_total = Decimal("0")
|
||||
|
||||
# Calculate LLM costs
|
||||
llm_usage = usage_info.get("llm", {})
|
||||
for key, usage in llm_usage.items():
|
||||
processor, model = self._parse_key(key)
|
||||
# Try to determine provider from processor name or model
|
||||
provider = self._infer_provider_from_model(model, "llm")
|
||||
cost = self.calculate_llm_cost(provider, model, usage)
|
||||
llm_cost_total += cost
|
||||
|
||||
# Calculate TTS costs
|
||||
tts_usage = usage_info.get("tts", {})
|
||||
for key, character_count in tts_usage.items():
|
||||
processor, model = self._parse_key(key)
|
||||
# Handle the case where model is "None" - infer from processor
|
||||
if model.lower() in ["none", "null", ""]:
|
||||
provider = self._infer_provider_from_processor(processor, "tts")
|
||||
model = "default" # Use default model for the provider
|
||||
else:
|
||||
provider = self._infer_provider_from_model(model, "tts")
|
||||
cost = self.calculate_tts_cost(provider, model, character_count)
|
||||
tts_cost_total += cost
|
||||
|
||||
# Calculate STT costs from explicit stt usage
|
||||
stt_usage = usage_info.get("stt", {})
|
||||
for key, seconds in stt_usage.items():
|
||||
processor, model = self._parse_key(key)
|
||||
provider = self._infer_provider_from_model(model, "stt")
|
||||
cost = self.calculate_stt_cost(provider, model, seconds)
|
||||
stt_cost_total += cost
|
||||
|
||||
total_cost = llm_cost_total + tts_cost_total + stt_cost_total
|
||||
|
||||
return {
|
||||
"llm_cost": float(llm_cost_total),
|
||||
"tts_cost": float(tts_cost_total),
|
||||
"stt_cost": float(stt_cost_total),
|
||||
"total": float(total_cost),
|
||||
}
|
||||
|
||||
def _parse_key(self, key) -> Tuple[str, str]:
|
||||
"""Parse key which is in format 'processor|||model'"""
|
||||
if isinstance(key, str) and "|||" in key:
|
||||
parts = key.split("|||", 1)
|
||||
return parts[0], parts[1]
|
||||
else:
|
||||
# Fallback for backwards compatibility or malformed keys
|
||||
return str(key), "unknown"
|
||||
|
||||
def _infer_provider_from_model(self, model: str, service_type: str) -> str:
|
||||
"""Infer provider from model name"""
|
||||
if not model:
|
||||
return "unknown"
|
||||
|
||||
model_lower = model.lower()
|
||||
|
||||
# OpenAI models
|
||||
if any(keyword in model_lower for keyword in ["gpt", "whisper", "openai"]):
|
||||
return ServiceProviders.OPENAI
|
||||
|
||||
# Groq models
|
||||
if any(keyword in model_lower for keyword in ["groq"]):
|
||||
return ServiceProviders.GROQ
|
||||
|
||||
# Elevenlabs models
|
||||
if any(keyword in model_lower for keyword in ["eleven"]):
|
||||
return ServiceProviders.ELEVENLABS
|
||||
|
||||
# Deepgram models
|
||||
if any(
|
||||
keyword in model_lower
|
||||
for keyword in ["deepgram", "nova", "phonecall", "general"]
|
||||
):
|
||||
return ServiceProviders.DEEPGRAM
|
||||
|
||||
# Default to first available provider for the service type
|
||||
service_providers = self.pricing_registry.get(service_type, {})
|
||||
if service_providers:
|
||||
return list(service_providers.keys())[0]
|
||||
|
||||
return "unknown"
|
||||
|
||||
def _infer_provider_from_processor(self, processor: str, service_type: str) -> str:
|
||||
"""Infer provider from processor name"""
|
||||
if not processor:
|
||||
return "unknown"
|
||||
|
||||
processor_lower = processor.lower()
|
||||
|
||||
# OpenAI processors
|
||||
if any(keyword in processor_lower for keyword in ["openai", "gpt"]):
|
||||
return ServiceProviders.OPENAI
|
||||
|
||||
# Groq processors
|
||||
if any(keyword in processor_lower for keyword in ["groq"]):
|
||||
return ServiceProviders.GROQ
|
||||
|
||||
# Deepgram processors
|
||||
if any(keyword in processor_lower for keyword in ["deepgram"]):
|
||||
return ServiceProviders.DEEPGRAM
|
||||
|
||||
# Default to first available provider for the service type
|
||||
service_providers = self.pricing_registry.get(service_type, {})
|
||||
if service_providers:
|
||||
return list(service_providers.keys())[0]
|
||||
|
||||
return "unknown"
|
||||
|
||||
def update_pricing(
|
||||
self, service_type: str, provider: str, model: str, pricing_model: PricingModel
|
||||
):
|
||||
"""Update pricing for a specific service/provider/model combination"""
|
||||
if service_type not in self.pricing_registry:
|
||||
self.pricing_registry[service_type] = {}
|
||||
if provider not in self.pricing_registry[service_type]:
|
||||
self.pricing_registry[service_type][provider] = {}
|
||||
self.pricing_registry[service_type][provider][model] = pricing_model
|
||||
|
||||
|
||||
# Global cost calculator instance
|
||||
cost_calculator = CostCalculator()
|
||||
|
|
@ -1,44 +0,0 @@
|
|||
"""
|
||||
Embeddings pricing models for different providers.
|
||||
|
||||
Prices are per token for embedding models.
|
||||
"""
|
||||
|
||||
from decimal import Decimal
|
||||
from typing import Dict
|
||||
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
|
||||
from .models import PricingModel
|
||||
|
||||
|
||||
class EmbeddingPricingModel(PricingModel):
|
||||
"""Pricing model for token-based embedding services."""
|
||||
|
||||
def __init__(self, token_price: Decimal):
|
||||
"""Initialize with price per token.
|
||||
|
||||
Args:
|
||||
token_price: Cost per token for embedding
|
||||
"""
|
||||
self.token_price = token_price
|
||||
|
||||
def calculate_cost(self, token_count: int) -> Decimal:
|
||||
"""Calculate cost for embedding token usage."""
|
||||
return Decimal(token_count) * self.token_price
|
||||
|
||||
|
||||
# Embeddings pricing registry
|
||||
EMBEDDINGS_PRICING: Dict[str, Dict[str, EmbeddingPricingModel]] = {
|
||||
ServiceProviders.OPENAI: {
|
||||
"text-embedding-3-small": EmbeddingPricingModel(
|
||||
token_price=Decimal("0.02") / 1_000_000, # $0.02 per 1M tokens
|
||||
),
|
||||
"text-embedding-3-large": EmbeddingPricingModel(
|
||||
token_price=Decimal("0.13") / 1_000_000, # $0.13 per 1M tokens
|
||||
),
|
||||
"text-embedding-ada-002": EmbeddingPricingModel(
|
||||
token_price=Decimal("0.10") / 1_000_000, # $0.10 per 1M tokens (legacy)
|
||||
),
|
||||
},
|
||||
}
|
||||
|
|
@ -1,143 +0,0 @@
|
|||
"""
|
||||
LLM pricing models for different providers.
|
||||
|
||||
Prices are per 1000 tokens for most models, with some newer models priced per million tokens.
|
||||
"""
|
||||
|
||||
from decimal import Decimal
|
||||
from typing import Dict
|
||||
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
|
||||
from .models import TokenPricingModel
|
||||
|
||||
# LLM pricing registry
|
||||
LLM_PRICING: Dict[str, Dict[str, TokenPricingModel]] = {
|
||||
ServiceProviders.OPENAI: {
|
||||
"gpt-3.5-turbo": TokenPricingModel(
|
||||
prompt_token_price=Decimal("0.0015") / 1000, # $0.0015 per 1K tokens
|
||||
completion_token_price=Decimal("0.002") / 1000, # $0.002 per 1K tokens
|
||||
),
|
||||
"gpt-4": TokenPricingModel(
|
||||
prompt_token_price=Decimal("0.03") / 1000, # $0.03 per 1K tokens
|
||||
completion_token_price=Decimal("0.06") / 1000, # $0.06 per 1K tokens
|
||||
),
|
||||
"gpt-4.1": TokenPricingModel(
|
||||
prompt_token_price=Decimal("2.00") / 1000000, # $2.00 per 1M tokens
|
||||
completion_token_price=Decimal("8.00") / 1000000, # $8.00 per 1M tokens
|
||||
),
|
||||
"gpt-4.1-mini": TokenPricingModel(
|
||||
prompt_token_price=Decimal("0.40") / 1000000, # $0.40 per 1M tokens
|
||||
completion_token_price=Decimal("1.60") / 1000000, # $1.60 per 1M tokens
|
||||
),
|
||||
"gpt-4.1-nano": TokenPricingModel(
|
||||
prompt_token_price=Decimal("0.10") / 1000000, # $0.10 per 1M tokens
|
||||
completion_token_price=Decimal("0.40") / 1000000, # $0.40 per 1M tokens
|
||||
),
|
||||
"gpt-4.5-preview": TokenPricingModel(
|
||||
prompt_token_price=Decimal("75.00") / 1000000, # $75.00 per 1M tokens
|
||||
completion_token_price=Decimal("150.00") / 1000000, # $150.00 per 1M tokens
|
||||
),
|
||||
"gpt-4o": TokenPricingModel(
|
||||
prompt_token_price=Decimal("2.50") / 1000000, # $2.50 per 1M tokens - FIXED
|
||||
completion_token_price=Decimal("10.00")
|
||||
/ 1000000, # $10.00 per 1M tokens - FIXED
|
||||
),
|
||||
"gpt-4o-audio-preview": TokenPricingModel(
|
||||
prompt_token_price=Decimal("2.50") / 1000000, # $2.50 per 1M tokens
|
||||
completion_token_price=Decimal("10.00") / 1000000, # $10.00 per 1M tokens
|
||||
),
|
||||
"gpt-4o-realtime-preview": TokenPricingModel(
|
||||
prompt_token_price=Decimal("5.00") / 1000000, # $5.00 per 1M tokens
|
||||
completion_token_price=Decimal("20.00") / 1000000, # $20.00 per 1M tokens
|
||||
),
|
||||
"gpt-4o-mini": TokenPricingModel(
|
||||
prompt_token_price=Decimal("0.15") / 1000000, # $0.15 per 1M tokens
|
||||
completion_token_price=Decimal("0.60") / 1000000, # $0.60 per 1M tokens
|
||||
),
|
||||
"gpt-4o-mini-audio-preview": TokenPricingModel(
|
||||
prompt_token_price=Decimal("0.15") / 1000000, # $0.15 per 1M tokens
|
||||
completion_token_price=Decimal("0.60") / 1000000, # $0.60 per 1M tokens
|
||||
),
|
||||
"gpt-4o-mini-realtime-preview": TokenPricingModel(
|
||||
prompt_token_price=Decimal("0.60") / 1000000, # $0.60 per 1M tokens
|
||||
completion_token_price=Decimal("2.40") / 1000000, # $2.40 per 1M tokens
|
||||
),
|
||||
"gpt-4o-search-preview": TokenPricingModel(
|
||||
prompt_token_price=Decimal("2.50") / 1000000, # $2.50 per 1M tokens
|
||||
completion_token_price=Decimal("10.00") / 1000000, # $10.00 per 1M tokens
|
||||
),
|
||||
"gpt-4o-mini-search-preview": TokenPricingModel(
|
||||
prompt_token_price=Decimal("0.15") / 1000000, # $0.15 per 1M tokens
|
||||
completion_token_price=Decimal("0.60") / 1000000, # $0.60 per 1M tokens
|
||||
),
|
||||
"o1": TokenPricingModel(
|
||||
prompt_token_price=Decimal("15.00") / 1000000, # $15.00 per 1M tokens
|
||||
completion_token_price=Decimal("60.00") / 1000000, # $60.00 per 1M tokens
|
||||
),
|
||||
"o1-pro": TokenPricingModel(
|
||||
prompt_token_price=Decimal("150.00") / 1000000, # $150.00 per 1M tokens
|
||||
completion_token_price=Decimal("600.00") / 1000000, # $600.00 per 1M tokens
|
||||
),
|
||||
"o1-mini": TokenPricingModel(
|
||||
prompt_token_price=Decimal("1.10") / 1000000, # $1.10 per 1M tokens
|
||||
completion_token_price=Decimal("4.40") / 1000000, # $4.40 per 1M tokens
|
||||
),
|
||||
"o3": TokenPricingModel(
|
||||
prompt_token_price=Decimal("10.00") / 1000000, # $10.00 per 1M tokens
|
||||
completion_token_price=Decimal("40.00") / 1000000, # $40.00 per 1M tokens
|
||||
),
|
||||
"o3-mini": TokenPricingModel(
|
||||
prompt_token_price=Decimal("1.10") / 1000000, # $1.10 per 1M tokens
|
||||
completion_token_price=Decimal("4.40") / 1000000, # $4.40 per 1M tokens
|
||||
),
|
||||
"o4-mini": TokenPricingModel(
|
||||
prompt_token_price=Decimal("1.10") / 1000000, # $1.10 per 1M tokens
|
||||
completion_token_price=Decimal("4.40") / 1000000, # $4.40 per 1M tokens
|
||||
),
|
||||
"computer-use-preview": TokenPricingModel(
|
||||
prompt_token_price=Decimal("3.00") / 1000000, # $3.00 per 1M tokens
|
||||
completion_token_price=Decimal("12.00") / 1000000, # $12.00 per 1M tokens
|
||||
),
|
||||
"gpt-image-1": TokenPricingModel(
|
||||
prompt_token_price=Decimal("5.00") / 1000000, # $5.00 per 1M tokens
|
||||
completion_token_price=Decimal("0") / 1000000, # No output pricing shown
|
||||
),
|
||||
"codex-mini-latest": TokenPricingModel(
|
||||
prompt_token_price=Decimal("1.50") / 1000000, # $1.50 per 1M tokens
|
||||
completion_token_price=Decimal("6.00") / 1000000, # $6.00 per 1M tokens
|
||||
),
|
||||
# Transcription models
|
||||
"gpt-4o-transcribe": TokenPricingModel(
|
||||
prompt_token_price=Decimal("2.50") / 1000000, # $2.50 per 1M tokens
|
||||
completion_token_price=Decimal("10.00") / 1000000, # $10.00 per 1M tokens
|
||||
),
|
||||
"gpt-4o-mini-transcribe": TokenPricingModel(
|
||||
prompt_token_price=Decimal("1.25") / 1000000, # $1.25 per 1M tokens
|
||||
completion_token_price=Decimal("5.00") / 1000000, # $5.00 per 1M tokens
|
||||
),
|
||||
# TTS models with token-based pricing
|
||||
"gpt-4o-mini-tts": TokenPricingModel(
|
||||
prompt_token_price=Decimal("0.60") / 1000000, # $0.60 per 1M tokens
|
||||
completion_token_price=Decimal("0")
|
||||
/ 1000000, # No completion tokens for TTS
|
||||
),
|
||||
},
|
||||
ServiceProviders.GROQ: {
|
||||
"llama-3.3-70b-versatile": TokenPricingModel(
|
||||
prompt_token_price=Decimal("0.00059") / 1000, # $0.00059 per 1K tokens
|
||||
completion_token_price=Decimal("0.00079") / 1000, # $0.00079 per 1K tokens
|
||||
),
|
||||
"deepseek-r1-distill-llama-70b": TokenPricingModel(
|
||||
prompt_token_price=Decimal("0.00059") / 1000, # Assuming similar pricing
|
||||
completion_token_price=Decimal("0.00079") / 1000,
|
||||
),
|
||||
},
|
||||
ServiceProviders.AZURE: {
|
||||
"gpt-4.1-mini": TokenPricingModel(
|
||||
prompt_token_price=Decimal("0.44") / 1000000, # $0.40 per 1M tokens
|
||||
completion_token_price=Decimal("8.80")
|
||||
/ 1000000, # $1.60 per 1M tokens if using data zone
|
||||
)
|
||||
},
|
||||
}
|
||||
|
|
@ -1,89 +0,0 @@
|
|||
"""
|
||||
Base pricing models for different service types.
|
||||
"""
|
||||
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
class CostType(Enum):
|
||||
LLM_TOKENS = "llm_tokens"
|
||||
TTS_CHARACTERS = "tts_characters"
|
||||
STT_SECONDS = "stt_seconds"
|
||||
|
||||
|
||||
class PricingModel:
|
||||
"""Base class for pricing models"""
|
||||
|
||||
def calculate_cost(self, usage: Any) -> Decimal:
|
||||
"""Calculate cost based on usage"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class TokenPricingModel(PricingModel):
|
||||
"""Pricing model for token-based services (LLM)"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prompt_token_price: Decimal,
|
||||
completion_token_price: Decimal,
|
||||
cache_read_discount: Decimal = Decimal("0.5"), # 50% discount for cache reads
|
||||
cache_creation_multiplier: Decimal = Decimal(
|
||||
"1.25"
|
||||
), # 25% premium for cache creation
|
||||
):
|
||||
self.prompt_token_price = prompt_token_price
|
||||
self.completion_token_price = completion_token_price
|
||||
self.cache_read_discount = cache_read_discount
|
||||
self.cache_creation_multiplier = cache_creation_multiplier
|
||||
|
||||
def calculate_cost(self, usage: Dict[str, int]) -> Decimal:
|
||||
"""Calculate cost for LLM token usage"""
|
||||
prompt_tokens = usage.get("prompt_tokens", 0)
|
||||
completion_tokens = usage.get("completion_tokens", 0)
|
||||
cache_read_tokens = usage.get("cache_read_input_tokens") or 0
|
||||
cache_creation_tokens = usage.get("cache_creation_input_tokens") or 0
|
||||
|
||||
# Base cost
|
||||
prompt_cost = Decimal(prompt_tokens) * self.prompt_token_price
|
||||
completion_cost = Decimal(completion_tokens) * self.completion_token_price
|
||||
|
||||
# Cache adjustments
|
||||
cache_read_savings = (
|
||||
Decimal(cache_read_tokens)
|
||||
* self.prompt_token_price
|
||||
* self.cache_read_discount
|
||||
)
|
||||
cache_creation_premium = (
|
||||
Decimal(cache_creation_tokens)
|
||||
* self.prompt_token_price
|
||||
* (self.cache_creation_multiplier - 1)
|
||||
)
|
||||
|
||||
total_cost = (
|
||||
prompt_cost + completion_cost - cache_read_savings + cache_creation_premium
|
||||
)
|
||||
return max(total_cost, Decimal("0")) # Ensure non-negative
|
||||
|
||||
|
||||
class CharacterPricingModel(PricingModel):
|
||||
"""Pricing model for character-based services (TTS)"""
|
||||
|
||||
def __init__(self, character_price: Decimal):
|
||||
self.character_price = character_price
|
||||
|
||||
def calculate_cost(self, character_count: int) -> Decimal:
|
||||
"""Calculate cost for TTS character usage"""
|
||||
return Decimal(character_count) * self.character_price
|
||||
|
||||
|
||||
class TimePricingModel(PricingModel):
|
||||
"""Pricing model for time-based services (STT)"""
|
||||
|
||||
def __init__(self, second_price: Decimal):
|
||||
self.second_price = second_price
|
||||
|
||||
def calculate_cost(self, seconds: float) -> Decimal:
|
||||
"""Calculate cost for STT time usage"""
|
||||
return Decimal(str(seconds)) * self.second_price
|
||||
|
|
@ -1,18 +0,0 @@
|
|||
"""
|
||||
Main pricing registry that combines all service type pricing models.
|
||||
"""
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from .embeddings import EMBEDDINGS_PRICING
|
||||
from .llm import LLM_PRICING
|
||||
from .stt import STT_PRICING
|
||||
from .tts import TTS_PRICING
|
||||
|
||||
# Combined pricing registry for all service types
|
||||
PRICING_REGISTRY: Dict = {
|
||||
"llm": LLM_PRICING,
|
||||
"tts": TTS_PRICING,
|
||||
"stt": STT_PRICING,
|
||||
"embeddings": EMBEDDINGS_PRICING,
|
||||
}
|
||||
|
|
@ -1,13 +0,0 @@
|
|||
"""Format workflow run usage for public API responses."""
|
||||
|
||||
|
||||
def format_public_usage_info(usage_info: dict | None) -> dict | None:
|
||||
if not usage_info:
|
||||
return None
|
||||
|
||||
return {
|
||||
"llm": usage_info.get("llm") or {},
|
||||
"tts": usage_info.get("tts") or {},
|
||||
"stt": usage_info.get("stt") or {},
|
||||
"call_duration_seconds": usage_info.get("call_duration_seconds"),
|
||||
}
|
||||
|
|
@ -1,26 +0,0 @@
|
|||
"""
|
||||
STT (Speech-to-Text) pricing models for different providers.
|
||||
|
||||
Prices are per second for STT services.
|
||||
"""
|
||||
|
||||
from decimal import Decimal
|
||||
from typing import Dict
|
||||
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
|
||||
from .models import TimePricingModel
|
||||
|
||||
# STT pricing registry
|
||||
STT_PRICING: Dict[str, Dict[str, TimePricingModel]] = {
|
||||
ServiceProviders.DEEPGRAM: {
|
||||
"nova-3-general": TimePricingModel(Decimal("0.0077") / 60),
|
||||
"nova-2": TimePricingModel(Decimal("0.0058") / 60),
|
||||
"default": TimePricingModel(Decimal("0.0077") / 60),
|
||||
},
|
||||
ServiceProviders.OPENAI: {
|
||||
"gpt-4o-transcribe": TimePricingModel(Decimal("0.015") / 60),
|
||||
"default": TimePricingModel(Decimal("0.015") / 60),
|
||||
},
|
||||
"default": {"default": TimePricingModel(Decimal("0.0077") / 60)},
|
||||
}
|
||||
|
|
@ -1,30 +0,0 @@
|
|||
"""
|
||||
TTS (Text-to-Speech) pricing models for different providers.
|
||||
|
||||
Prices are per character for TTS services.
|
||||
"""
|
||||
|
||||
from decimal import Decimal
|
||||
from typing import Dict
|
||||
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
|
||||
from .models import CharacterPricingModel
|
||||
|
||||
# TTS pricing registry
|
||||
TTS_PRICING: Dict[str, Dict[str, CharacterPricingModel]] = {
|
||||
ServiceProviders.OPENAI: {
|
||||
"gpt-4o-mini-tts": CharacterPricingModel(Decimal("0.6") / 1_00_00_000),
|
||||
"default": CharacterPricingModel(Decimal("0.6") / 1_00_00_000),
|
||||
},
|
||||
ServiceProviders.DEEPGRAM: {
|
||||
"aura-2": CharacterPricingModel(Decimal("0.030") / 1_000),
|
||||
"aura-1": CharacterPricingModel(Decimal("0.015") / 1_000),
|
||||
"default": CharacterPricingModel(Decimal("0.030") / 1_000),
|
||||
},
|
||||
ServiceProviders.ELEVENLABS: {
|
||||
# 6400 usd per 250*1e6 characters
|
||||
"default": CharacterPricingModel(Decimal("0.0256") / 1_000)
|
||||
},
|
||||
"default": {"default": CharacterPricingModel(Decimal("0.030") / 1_000)},
|
||||
}
|
||||
|
|
@ -1,230 +0,0 @@
|
|||
from decimal import Decimal
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from api.db import db_client
|
||||
from api.enums import WorkflowRunMode
|
||||
from api.services.pricing.cost_calculator import cost_calculator
|
||||
from api.services.telephony.factory import get_telephony_provider_for_run
|
||||
|
||||
|
||||
async def _fetch_telephony_cost(workflow_run) -> dict | None:
|
||||
"""Fetch telephony call cost. Returns a dict with cost_usd and provider_name, or None."""
|
||||
if (
|
||||
workflow_run.mode
|
||||
not in [WorkflowRunMode.TWILIO.value, WorkflowRunMode.VONAGE.value]
|
||||
or not workflow_run.cost_info
|
||||
):
|
||||
return None
|
||||
|
||||
call_id = workflow_run.cost_info.get("call_id")
|
||||
if not call_id:
|
||||
logger.warning(f"call_id not found in cost_info")
|
||||
return None
|
||||
|
||||
provider_name = workflow_run.mode.lower() if workflow_run.mode else ""
|
||||
|
||||
workflow = await db_client.get_workflow_by_id(workflow_run.workflow_id)
|
||||
if not workflow:
|
||||
logger.warning("Workflow not found for workflow run")
|
||||
raise Exception("Workflow not found")
|
||||
|
||||
provider = await get_telephony_provider_for_run(
|
||||
workflow_run, workflow.organization_id
|
||||
)
|
||||
call_cost_info = await provider.get_call_cost(call_id)
|
||||
|
||||
if call_cost_info.get("status") == "error":
|
||||
logger.error(
|
||||
f"Failed to fetch {provider_name} call cost: {call_cost_info.get('error')}"
|
||||
)
|
||||
return None
|
||||
|
||||
cost_usd = call_cost_info.get("cost_usd", 0.0)
|
||||
logger.info(
|
||||
f"{provider_name.title()} call cost: ${cost_usd:.6f} USD for call {call_id}"
|
||||
)
|
||||
return {"cost_usd": cost_usd, "provider_name": provider_name}
|
||||
|
||||
|
||||
async def _update_organization_usage(
|
||||
org, dograh_tokens: float, duration_seconds: float, charge_usd: float | None
|
||||
) -> None:
|
||||
"""Update organization usage after a workflow run."""
|
||||
org_id = org.id
|
||||
await db_client.update_usage_after_run(
|
||||
org_id, dograh_tokens, duration_seconds, charge_usd
|
||||
)
|
||||
if charge_usd is not None:
|
||||
logger.info(
|
||||
f"Updated organization usage with ${charge_usd:.2f} USD ({dograh_tokens} Dograh Tokens) and {duration_seconds}s duration for org {org_id}"
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"Updated organization usage with {dograh_tokens} Dograh Tokens and {duration_seconds}s duration for org {org_id}"
|
||||
)
|
||||
|
||||
|
||||
async def _get_pricing_organization(workflow_run):
|
||||
workflow = getattr(workflow_run, "workflow", None)
|
||||
organization_id = getattr(workflow, "organization_id", None)
|
||||
if organization_id is None and workflow and workflow.user:
|
||||
organization_id = workflow.user.selected_organization_id
|
||||
if organization_id is None:
|
||||
return None
|
||||
return await db_client.get_organization_by_id(organization_id)
|
||||
|
||||
|
||||
async def _build_usage_cost_snapshot(
|
||||
usage_info: dict | None,
|
||||
*,
|
||||
workflow_run=None,
|
||||
include_telephony_cost: bool = False,
|
||||
organization=None,
|
||||
calculated_at: str | None = None,
|
||||
) -> dict | None:
|
||||
if not usage_info:
|
||||
logger.warning("No usage info available for workflow run")
|
||||
return None
|
||||
|
||||
cost_breakdown = cost_calculator.calculate_total_cost(usage_info)
|
||||
|
||||
if include_telephony_cost and workflow_run is not None:
|
||||
try:
|
||||
telephony_cost = await _fetch_telephony_cost(workflow_run)
|
||||
if telephony_cost:
|
||||
telephony_cost_usd = telephony_cost["cost_usd"]
|
||||
provider_name = telephony_cost["provider_name"]
|
||||
cost_breakdown["telephony_call"] = telephony_cost_usd
|
||||
cost_breakdown[f"{provider_name}_call"] = telephony_cost_usd
|
||||
cost_breakdown["total"] = (
|
||||
float(cost_breakdown["total"]) + telephony_cost_usd
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to fetch telephony call cost: {e}")
|
||||
# Don't fail the whole cost calculation if telephony API fails
|
||||
|
||||
total_cost_usd = Decimal(str(cost_breakdown["total"]))
|
||||
dograh_tokens = float(total_cost_usd * Decimal("100"))
|
||||
|
||||
if organization is None and workflow_run is not None:
|
||||
organization = await _get_pricing_organization(workflow_run)
|
||||
|
||||
charge_usd = None
|
||||
if organization and organization.price_per_second_usd:
|
||||
duration_seconds = usage_info.get("call_duration_seconds", 0)
|
||||
charge_usd = float(
|
||||
Decimal(str(duration_seconds))
|
||||
* Decimal(str(organization.price_per_second_usd))
|
||||
)
|
||||
|
||||
cost_info = {
|
||||
"cost_breakdown": cost_breakdown,
|
||||
"total_cost_usd": float(total_cost_usd),
|
||||
"dograh_token_usage": dograh_tokens,
|
||||
"calculated_at": calculated_at
|
||||
or (workflow_run.created_at.isoformat() if workflow_run is not None else None),
|
||||
"call_duration_seconds": usage_info.get("call_duration_seconds", 0),
|
||||
}
|
||||
|
||||
if charge_usd is not None:
|
||||
cost_info["charge_usd"] = charge_usd
|
||||
cost_info["price_per_second_usd"] = organization.price_per_second_usd
|
||||
|
||||
return cost_info
|
||||
|
||||
|
||||
async def build_workflow_run_cost_info(workflow_run) -> dict | None:
|
||||
cost_info = await _build_usage_cost_snapshot(
|
||||
workflow_run.usage_info,
|
||||
workflow_run=workflow_run,
|
||||
include_telephony_cost=True,
|
||||
calculated_at=workflow_run.created_at.isoformat(),
|
||||
)
|
||||
if cost_info is None:
|
||||
return None
|
||||
return {
|
||||
**(workflow_run.cost_info or {}),
|
||||
**cost_info,
|
||||
}
|
||||
|
||||
|
||||
async def save_workflow_run_cost_info(
|
||||
workflow_run_id: int, cost_info: dict | None
|
||||
) -> None:
|
||||
if cost_info is None:
|
||||
return
|
||||
await db_client.update_workflow_run(run_id=workflow_run_id, cost_info=cost_info)
|
||||
|
||||
|
||||
async def apply_workflow_run_usage_to_organization(
|
||||
workflow_run, cost_info: dict | None
|
||||
) -> None:
|
||||
if cost_info is None:
|
||||
return
|
||||
|
||||
org = await _get_pricing_organization(workflow_run)
|
||||
if not org:
|
||||
return
|
||||
|
||||
await _update_organization_usage(
|
||||
org,
|
||||
float(cost_info.get("dograh_token_usage") or 0),
|
||||
float(cost_info.get("call_duration_seconds") or 0),
|
||||
cost_info.get("charge_usd"),
|
||||
)
|
||||
|
||||
|
||||
async def apply_usage_delta_to_organization(
|
||||
workflow_run, usage_info: dict | None
|
||||
) -> dict | None:
|
||||
org = await _get_pricing_organization(workflow_run)
|
||||
if not org:
|
||||
return None
|
||||
|
||||
cost_info = await _build_usage_cost_snapshot(usage_info, organization=org)
|
||||
if cost_info is None:
|
||||
return None
|
||||
|
||||
await _update_organization_usage(
|
||||
org,
|
||||
float(cost_info.get("dograh_token_usage") or 0),
|
||||
float(cost_info.get("call_duration_seconds") or 0),
|
||||
cost_info.get("charge_usd"),
|
||||
)
|
||||
return cost_info
|
||||
|
||||
|
||||
async def calculate_workflow_run_cost(workflow_run_id: int):
|
||||
logger.debug("Calculating cost for workflow run")
|
||||
|
||||
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
|
||||
if not workflow_run:
|
||||
logger.warning("Workflow run not found")
|
||||
return
|
||||
|
||||
try:
|
||||
cost_info = await build_workflow_run_cost_info(workflow_run)
|
||||
if cost_info is None:
|
||||
return
|
||||
|
||||
await save_workflow_run_cost_info(workflow_run_id, cost_info)
|
||||
|
||||
try:
|
||||
await apply_workflow_run_usage_to_organization(workflow_run, cost_info)
|
||||
except Exception as e:
|
||||
org = await _get_pricing_organization(workflow_run)
|
||||
if org:
|
||||
logger.error(
|
||||
f"Failed to update organization usage for org {org.id}: {e}"
|
||||
)
|
||||
else:
|
||||
logger.error(f"Failed to update organization usage: {e}")
|
||||
# Don't fail the whole cost calculation if usage update fails
|
||||
|
||||
logger.info(
|
||||
f"Calculated cost for workflow run: ${cost_info['total_cost_usd']:.6f} USD ({cost_info['dograh_token_usage']} Dograh Tokens)"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating cost for workflow run: {e}")
|
||||
raise
|
||||
|
|
@ -53,7 +53,7 @@ def build_run_report_csv(runs: List[Any]) -> io.StringIO:
|
|||
for run in runs:
|
||||
initial = run.initial_context or {}
|
||||
gathered = run.gathered_context or {}
|
||||
cost = run.cost_info or {}
|
||||
usage = run.usage_info or {}
|
||||
|
||||
call_tags = gathered.get("call_tags", [])
|
||||
if isinstance(call_tags, list):
|
||||
|
|
@ -67,7 +67,7 @@ def build_run_report_csv(runs: List[Any]) -> io.StringIO:
|
|||
run.created_at.isoformat() if run.created_at else "",
|
||||
initial.get("phone_number", ""),
|
||||
gathered.get("mapped_call_disposition", ""),
|
||||
cost.get("call_duration_seconds", ""),
|
||||
usage.get("call_duration_seconds", ""),
|
||||
]
|
||||
|
||||
extracted = gathered.get("extracted_variables", {})
|
||||
|
|
|
|||
|
|
@ -66,34 +66,6 @@ async def handle_vonage_events(
|
|||
logger.error(f"[run {workflow_run_id}] Workflow run not found")
|
||||
return {"status": "error", "message": "Workflow run not found"}
|
||||
|
||||
# For a completed call that includes cost info, capture it immediately
|
||||
if event_data.get("status") == "completed":
|
||||
# Vonage sometimes includes price info in the webhook
|
||||
if "price" in event_data or "rate" in event_data:
|
||||
try:
|
||||
if workflow_run.cost_info:
|
||||
# Store immediate cost info if available
|
||||
cost_info = workflow_run.cost_info.copy()
|
||||
if "price" in event_data:
|
||||
cost_info["vonage_webhook_price"] = float(event_data["price"])
|
||||
if "rate" in event_data:
|
||||
cost_info["vonage_webhook_rate"] = float(event_data["rate"])
|
||||
if "duration" in event_data:
|
||||
cost_info["vonage_webhook_duration"] = int(
|
||||
event_data["duration"]
|
||||
)
|
||||
|
||||
await db_client.update_workflow_run(
|
||||
run_id=workflow_run_id, cost_info=cost_info
|
||||
)
|
||||
logger.info(
|
||||
f"[run {workflow_run_id}] Captured Vonage cost info from webhook"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[run {workflow_run_id}] Failed to capture Vonage cost from webhook: {e}"
|
||||
)
|
||||
|
||||
# Get workflow and provider
|
||||
workflow = await db_client.get_workflow_by_id(workflow_run.workflow_id)
|
||||
if not workflow:
|
||||
|
|
|
|||
41
api/services/workflow/run_usage_response.py
Normal file
41
api/services/workflow/run_usage_response.py
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
"""Format workflow run usage for public API responses."""
|
||||
|
||||
|
||||
def format_public_usage_info(usage_info: dict | None) -> dict | None:
|
||||
if not usage_info:
|
||||
return None
|
||||
|
||||
return {
|
||||
"llm": usage_info.get("llm") or {},
|
||||
"tts": usage_info.get("tts") or {},
|
||||
"stt": usage_info.get("stt") or {},
|
||||
"call_duration_seconds": usage_info.get("call_duration_seconds"),
|
||||
}
|
||||
|
||||
|
||||
def format_public_cost_info(
|
||||
cost_info: dict | None, usage_info: dict | None
|
||||
) -> dict | None:
|
||||
"""Return the legacy response shape without doing local cost accounting."""
|
||||
duration = None
|
||||
if usage_info and usage_info.get("call_duration_seconds") is not None:
|
||||
duration = int(round(usage_info.get("call_duration_seconds") or 0))
|
||||
elif cost_info and cost_info.get("call_duration_seconds") is not None:
|
||||
duration = int(round(cost_info.get("call_duration_seconds") or 0))
|
||||
|
||||
dograh_token_usage = 0
|
||||
if cost_info:
|
||||
if "dograh_token_usage" in cost_info:
|
||||
dograh_token_usage = cost_info.get("dograh_token_usage") or 0
|
||||
elif "total_cost_usd" in cost_info:
|
||||
dograh_token_usage = round(
|
||||
float(cost_info.get("total_cost_usd", 0)) * 100, 2
|
||||
)
|
||||
|
||||
if duration is None and dograh_token_usage == 0:
|
||||
return None
|
||||
|
||||
return {
|
||||
"dograh_token_usage": dograh_token_usage,
|
||||
"call_duration_seconds": duration,
|
||||
}
|
||||
|
|
@ -4,17 +4,11 @@ from datetime import UTC, datetime
|
|||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from api.db import db_client
|
||||
from api.db.models import WorkflowRunTextSessionModel
|
||||
from api.db.workflow_run_text_session_client import (
|
||||
WorkflowRunTextSessionRevisionConflictError,
|
||||
)
|
||||
from api.services.pricing.workflow_run_cost import (
|
||||
apply_usage_delta_to_organization,
|
||||
build_workflow_run_cost_info,
|
||||
)
|
||||
from api.services.workflow.text_chat_logs import (
|
||||
build_text_chat_realtime_feedback_events,
|
||||
)
|
||||
|
|
@ -261,20 +255,6 @@ async def execute_pending_text_chat_turn(
|
|||
state=execution.state,
|
||||
is_completed=execution.is_completed,
|
||||
)
|
||||
workflow_run = await db_client.get_workflow_run_by_id(run_id)
|
||||
if workflow_run:
|
||||
try:
|
||||
# Apply the per-turn delta so org usage tracks cumulative run cost
|
||||
# without replaying the full session totals on every turn.
|
||||
await apply_usage_delta_to_organization(workflow_run, execution.usage)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to update organization usage for text chat run {run_id}: {e}"
|
||||
)
|
||||
|
||||
cost_info = await build_workflow_run_cost_info(workflow_run)
|
||||
if cost_info is not None:
|
||||
await db_client.update_workflow_run(run_id, cost_info=cost_info)
|
||||
|
||||
return await _reload_text_chat_session(run_id)
|
||||
|
||||
|
|
|
|||
111
api/services/workflow_run_billing.py
Normal file
111
api/services/workflow_run_billing.py
Normal file
|
|
@ -0,0 +1,111 @@
|
|||
"""Workflow-run billing hooks.
|
||||
|
||||
Dograh does not rate or deduct credits locally. MPS owns credit accounting.
|
||||
For hosted deployments, Dograh reports completed platform usage to MPS.
|
||||
When a server-minted MPS correlation id exists, MPS uses model-service usage
|
||||
as the canonical duration. Otherwise Dograh reports the completed run duration.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from api.constants import DEPLOYMENT_MODE
|
||||
from api.db import db_client
|
||||
from api.services.managed_model_services import get_mps_correlation_id
|
||||
from api.services.mps_service_key_client import mps_service_key_client
|
||||
|
||||
|
||||
def _workflow_run_organization_id(workflow_run) -> int | None:
|
||||
workflow = getattr(workflow_run, "workflow", None)
|
||||
return getattr(workflow, "organization_id", None)
|
||||
|
||||
|
||||
def _duration_seconds_from_usage_info(workflow_run) -> float | None:
|
||||
usage_info: dict[str, Any] = getattr(workflow_run, "usage_info", None) or {}
|
||||
duration = usage_info.get("call_duration_seconds")
|
||||
try:
|
||||
duration_seconds = float(duration)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
return duration_seconds if duration_seconds > 0 else None
|
||||
|
||||
|
||||
async def _organization_uses_mps_billing_v2(organization_id: int) -> bool:
|
||||
account = await mps_service_key_client.get_billing_account_status(
|
||||
organization_id=organization_id
|
||||
)
|
||||
return bool(account and account.get("billing_mode") == "v2")
|
||||
|
||||
|
||||
async def report_workflow_run_platform_usage(workflow_run) -> None:
|
||||
"""Report hosted platform usage for a completed workflow run to MPS."""
|
||||
if DEPLOYMENT_MODE == "oss":
|
||||
return
|
||||
|
||||
if not getattr(workflow_run, "is_completed", False):
|
||||
return
|
||||
|
||||
organization_id = _workflow_run_organization_id(workflow_run)
|
||||
if organization_id is None:
|
||||
logger.warning(
|
||||
"Skipping platform usage report for workflow run {}: no organization_id",
|
||||
workflow_run.id,
|
||||
)
|
||||
return
|
||||
|
||||
correlation_id = get_mps_correlation_id(
|
||||
getattr(workflow_run, "initial_context", None)
|
||||
)
|
||||
duration_seconds = (
|
||||
None if correlation_id else _duration_seconds_from_usage_info(workflow_run)
|
||||
)
|
||||
if not correlation_id and duration_seconds is None:
|
||||
logger.warning(
|
||||
"Skipping platform usage report for workflow run {}: no billable duration",
|
||||
workflow_run.id,
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
if not await _organization_uses_mps_billing_v2(organization_id):
|
||||
return
|
||||
|
||||
result = await mps_service_key_client.report_platform_usage(
|
||||
organization_id=organization_id,
|
||||
correlation_id=correlation_id,
|
||||
duration_seconds=duration_seconds,
|
||||
workflow_run_id=workflow_run.id,
|
||||
metadata={
|
||||
"source": "workflow_run_completion",
|
||||
"workflow_id": getattr(workflow_run, "workflow_id", None),
|
||||
"duration_source": (
|
||||
"mps_correlation" if correlation_id else "dograh_usage_info"
|
||||
),
|
||||
},
|
||||
)
|
||||
logger.info(
|
||||
"Reported platform usage for workflow run {} to MPS: {}",
|
||||
workflow_run.id,
|
||||
result,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to report platform usage for workflow run {}: {}",
|
||||
workflow_run.id,
|
||||
e,
|
||||
)
|
||||
|
||||
|
||||
async def report_completed_workflow_run_platform_usage(workflow_run_id: int) -> None:
|
||||
"""Load a completed workflow run and report platform usage to MPS."""
|
||||
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
|
||||
if not workflow_run:
|
||||
logger.warning(
|
||||
"Skipping platform usage report: workflow run {} not found",
|
||||
workflow_run_id,
|
||||
)
|
||||
return
|
||||
|
||||
await report_workflow_run_platform_usage(workflow_run)
|
||||
Loading…
Add table
Add a link
Reference in a new issue