From 8f241b89d237555a9c17ddb3634229f274853b5d Mon Sep 17 00:00:00 2001 From: Abhishek Kumar Date: Fri, 12 Jun 2026 13:26:33 +0530 Subject: [PATCH] fix: remove cost calculation from dograh codebase --- api/db/campaign_client.py | 27 +- api/db/filters.py | 6 +- api/db/organization_usage_client.py | 59 +---- api/db/workflow_run_client.py | 24 +- api/routes/organization_usage.py | 82 +++++-- api/routes/workflow.py | 22 +- api/services/configuration/check_validity.py | 18 +- api/services/mps_service_key_client.py | 99 ++++++++ api/services/pipecat/run_pipeline.py | 4 +- api/services/pricing/README.md | 76 ------ api/services/pricing/__init__.py | 9 - api/services/pricing/cost_calculator.py | 228 ----------------- api/services/pricing/embeddings.py | 44 ---- api/services/pricing/llm.py | 143 ----------- api/services/pricing/models.py | 89 ------- api/services/pricing/registry.py | 18 -- api/services/pricing/run_usage_response.py | 13 - api/services/pricing/stt.py | 26 -- api/services/pricing/tts.py | 30 --- api/services/pricing/workflow_run_cost.py | 230 ------------------ api/services/reports/run_report.py | 4 +- .../telephony/providers/vonage/routes.py | 28 --- api/services/workflow/run_usage_response.py | 41 ++++ .../workflow/text_chat_session_service.py | 20 -- api/services/workflow_run_billing.py | 111 +++++++++ api/tasks/arq.py | 6 +- api/tasks/s3_upload.py | 113 +-------- api/tasks/workflow_completion.py | 121 +++++++++ api/tests/test_ai_model_configuration_v2.py | 64 +++++ api/tests/test_cost_calculator.py | 31 --- api/tests/test_mps_service_key_client.py | 149 ++++++++++++ api/tests/test_organization_usage_billing.py | 33 +++ api/tests/test_run_usage_response.py | 2 +- api/tests/test_workflow_run_billing.py | 212 ++++++++++++++++ api/tests/test_workflow_run_cost.py | 181 -------------- api/tests/test_workflow_text_chat.py | 12 +- ui/src/app/billing/page.tsx | 93 ++++++- ui/src/client/types.gen.ts | 32 +++ ui/src/lib/apiError.ts | 27 +- 39 files changed, 1067 insertions(+), 1460 deletions(-) delete mode 100644 api/services/pricing/README.md delete mode 100644 api/services/pricing/__init__.py delete mode 100644 api/services/pricing/cost_calculator.py delete mode 100644 api/services/pricing/embeddings.py delete mode 100644 api/services/pricing/llm.py delete mode 100644 api/services/pricing/models.py delete mode 100644 api/services/pricing/registry.py delete mode 100644 api/services/pricing/run_usage_response.py delete mode 100644 api/services/pricing/stt.py delete mode 100644 api/services/pricing/tts.py delete mode 100644 api/services/pricing/workflow_run_cost.py create mode 100644 api/services/workflow/run_usage_response.py create mode 100644 api/services/workflow_run_billing.py create mode 100644 api/tasks/workflow_completion.py delete mode 100644 api/tests/test_cost_calculator.py create mode 100644 api/tests/test_organization_usage_billing.py create mode 100644 api/tests/test_workflow_run_billing.py delete mode 100644 api/tests/test_workflow_run_cost.py diff --git a/api/db/campaign_client.py b/api/db/campaign_client.py index d9ff2bae..7729a2ea 100644 --- a/api/db/campaign_client.py +++ b/api/db/campaign_client.py @@ -9,6 +9,7 @@ from api.db.base_client import BaseDBClient from api.db.filters import apply_workflow_run_filters, get_workflow_run_order_clause from api.db.models import CampaignModel, QueuedRunModel, WorkflowRunModel from api.schemas.workflow import WorkflowRunResponseSchema +from api.services.workflow.run_usage_response import format_public_cost_info class CampaignClient(BaseDBClient): @@ -215,26 +216,9 @@ class CampaignClient(BaseDBClient): "is_completed": run.is_completed, "recording_url": run.recording_url, "transcript_url": run.transcript_url, - "cost_info": { - "dograh_token_usage": ( - run.cost_info.get("dograh_token_usage") - if run.cost_info - and "dograh_token_usage" in run.cost_info - else round( - float(run.cost_info.get("total_cost_usd", 0)) * 100, - 2, - ) - if run.cost_info and "total_cost_usd" in run.cost_info - else 0 - ), - "call_duration_seconds": int( - round(run.cost_info.get("call_duration_seconds") or 0) - ) - if run.cost_info - else None, - } - if run.cost_info - else None, + "cost_info": format_public_cost_info( + run.cost_info, run.usage_info + ), "definition_id": run.definition_id, "initial_context": run.initial_context, "gathered_context": run.gathered_context, @@ -662,7 +646,7 @@ class CampaignClient(BaseDBClient): async with self.async_session() as session: conditions = [ WorkflowRunModel.is_completed.is_(True), - WorkflowRunModel.cost_info["call_duration_seconds"] + WorkflowRunModel.usage_info["call_duration_seconds"] .as_string() .isnot(None), ] @@ -685,6 +669,7 @@ class CampaignClient(BaseDBClient): WorkflowRunModel.initial_context, WorkflowRunModel.gathered_context, WorkflowRunModel.cost_info, + WorkflowRunModel.usage_info, WorkflowRunModel.public_access_token, ) .where(*conditions) diff --git a/api/db/filters.py b/api/db/filters.py index e960d724..cd30b144 100644 --- a/api/db/filters.py +++ b/api/db/filters.py @@ -25,7 +25,7 @@ def get_workflow_run_order_clause( """ # Determine sort column if sort_by == "duration": - sort_column = WorkflowRunModel.cost_info.op("->>")( + sort_column = WorkflowRunModel.usage_info.op("->>")( "call_duration_seconds" ).cast(Float) else: @@ -43,7 +43,7 @@ def get_workflow_run_order_clause( ATTRIBUTE_FIELD_MAPPING = { "dateRange": "created_at", "dispositionCode": "gathered_context.mapped_call_disposition", - "duration": "cost_info.call_duration_seconds", + "duration": "usage_info.call_duration_seconds", "status": "is_completed", "tokenUsage": "cost_info.total_cost_usd", "runId": "id", @@ -208,7 +208,7 @@ def apply_workflow_run_filters( min_val = value.get("min") max_val = value.get("max") - if field == "cost_info.call_duration_seconds": + if field == "usage_info.call_duration_seconds": # Use ->> operator for compatibility with all PostgreSQL versions # (subscript [] only works in PostgreSQL 14+) duration_text = cast(WorkflowRunModel.usage_info, JSONB).op("->>")( diff --git a/api/db/organization_usage_client.py b/api/db/organization_usage_client.py index b147a2d6..dfca0538 100644 --- a/api/db/organization_usage_client.py +++ b/api/db/organization_usage_client.py @@ -96,43 +96,6 @@ class OrganizationUsageClient(BaseDBClient): ) return cycle_result.scalar_one() - async def update_usage_after_run( - self, - organization_id: int, - actual_tokens: float, - duration_seconds: float = 0, - charge_usd: float | None = None, - ) -> None: - """Update usage after a workflow run completes with actual token count and duration. - - This method is fully atomic and safe for concurrent access from multiple processes. - """ - async with self.async_session() as session: - # Get or create current cycle within the same session/transaction - cycle = await self._get_or_create_current_cycle_impl( - organization_id, session, commit=False - ) - - # Acquire a row-level lock for atomic update - result = await session.execute( - select(OrganizationUsageCycleModel) - .where(OrganizationUsageCycleModel.id == cycle.id) - .with_for_update(skip_locked=False) - ) - cycle_locked = result.scalar_one() - - # Update usage atomically - cycle_locked.used_dograh_tokens += actual_tokens - cycle_locked.total_duration_seconds += int(round(duration_seconds)) - - # Update USD amount if provided - if charge_usd is not None: - if cycle_locked.used_amount_usd is None: - cycle_locked.used_amount_usd = 0 - cycle_locked.used_amount_usd += charge_usd - - await session.commit() - async def get_current_usage(self, organization_id: int) -> dict: """Get current reporting-period usage information.""" async with self.async_session() as session: @@ -178,7 +141,7 @@ class OrganizationUsageClient(BaseDBClient): .join(UserModel, WorkflowModel.user_id == UserModel.id) .where( UserModel.selected_organization_id == organization_id, - WorkflowRunModel.cost_info.isnot(None), + WorkflowRunModel.usage_info.isnot(None), ) .order_by(WorkflowRunModel.created_at.desc()) ) @@ -231,19 +194,8 @@ class OrganizationUsageClient(BaseDBClient): total_tokens = 0 total_duration_seconds = 0 for run in runs: - if run.cost_info: - # Try to get dograh_token_usage first (new format) - dograh_tokens = run.cost_info.get("dograh_token_usage", 0) - # If not present, calculate from total_cost_usd (old format) - if dograh_tokens == 0 and "total_cost_usd" in run.cost_info: - dograh_tokens = round( - float(run.cost_info["total_cost_usd"]) * 100, 2 - ) - # Get call duration - call_duration = run.cost_info.get("call_duration_seconds", 0) - else: - dograh_tokens = 0 - call_duration = 0 + dograh_tokens = 0 + call_duration = (run.usage_info or {}).get("call_duration_seconds", 0) total_tokens += dograh_tokens total_duration_seconds += int(round(call_duration)) @@ -317,13 +269,14 @@ class OrganizationUsageClient(BaseDBClient): WorkflowRunModel.initial_context, WorkflowRunModel.gathered_context, WorkflowRunModel.cost_info, + WorkflowRunModel.usage_info, WorkflowRunModel.public_access_token, ) .join(WorkflowModel, WorkflowRunModel.workflow_id == WorkflowModel.id) .join(UserModel, WorkflowModel.user_id == UserModel.id) .where( UserModel.selected_organization_id == organization_id, - WorkflowRunModel.cost_info.isnot(None), + WorkflowRunModel.usage_info.isnot(None), ) .order_by(WorkflowRunModel.created_at.desc()) ) @@ -418,7 +371,7 @@ class OrganizationUsageClient(BaseDBClient): select( date_expr.label("date"), func.sum( - WorkflowRunModel.cost_info["call_duration_seconds"].as_float() + WorkflowRunModel.usage_info["call_duration_seconds"].as_float() ).label("total_seconds"), func.count(WorkflowRunModel.id).label("call_count"), ) diff --git a/api/db/workflow_run_client.py b/api/db/workflow_run_client.py index 57c3e02b..497230ad 100644 --- a/api/db/workflow_run_client.py +++ b/api/db/workflow_run_client.py @@ -16,6 +16,7 @@ from api.db.models import ( ) from api.enums import CallType, StorageBackend from api.schemas.workflow import WorkflowRunResponseSchema +from api.services.workflow.run_usage_response import format_public_cost_info class WorkflowRunClient(BaseDBClient): @@ -312,26 +313,9 @@ class WorkflowRunClient(BaseDBClient): "is_completed": run.is_completed, "recording_url": run.recording_url, "transcript_url": run.transcript_url, - "cost_info": { - "dograh_token_usage": ( - run.cost_info.get("dograh_token_usage") - if run.cost_info - and "dograh_token_usage" in run.cost_info - else round( - float(run.cost_info.get("total_cost_usd", 0)) * 100, - 2, - ) - if run.cost_info and "total_cost_usd" in run.cost_info - else 0 - ), - "call_duration_seconds": int( - round(run.cost_info.get("call_duration_seconds") or 0) - ) - if run.cost_info - else None, - } - if run.cost_info - else None, + "cost_info": format_public_cost_info( + run.cost_info, run.usage_info + ), "definition_id": run.definition_id, "initial_context": run.initial_context, "gathered_context": run.gathered_context, diff --git a/api/routes/organization_usage.py b/api/routes/organization_usage.py index 1ce76acf..b8e75c6d 100644 --- a/api/routes/organization_usage.py +++ b/api/routes/organization_usage.py @@ -11,9 +11,6 @@ from api.constants import DEPLOYMENT_MODE, UI_APP_URL from api.db import db_client from api.db.models import UserModel from api.services.auth.depends import get_user, get_user_with_selected_organization -from api.services.configuration.ai_model_configuration import ( - get_resolved_ai_model_configuration, -) from api.services.mps_service_key_client import mps_service_key_client from api.services.reports import generate_usage_runs_report_csv from api.utils.artifacts import artifact_url @@ -58,6 +55,14 @@ class MPSCreditLedgerEntryResponse(BaseModel): amount_minor: Optional[int] = None amount_currency: Optional[str] = None payment_order_id: Optional[int] = None + metric_code: Optional[str] = None + correlation_id: Optional[str] = None + aggregation_key: Optional[str] = None + usage_event_id: Optional[int] = None + workflow_run_id: Optional[int] = None + workflow_id: Optional[int] = None + billable_quantity: Optional[float] = None + quantity_unit: Optional[str] = None metadata: Dict[str, Any] = Field(default_factory=dict) created_at: str @@ -71,6 +76,15 @@ class MPSBillingCreditsResponse(BaseModel): ledger_entries: List[MPSCreditLedgerEntryResponse] = Field(default_factory=list) +def _optional_int(value: Any) -> Optional[int]: + if value is None: + return None + try: + return int(value) + except (TypeError, ValueError): + return None + + class WorkflowRunUsageResponse(BaseModel): id: int workflow_id: int @@ -173,15 +187,17 @@ async def get_mps_credits(user: UserModel = Depends(get_user)): raise HTTPException(status_code=500, detail=str(e)) -async def _uses_mps_billing_v2(user: UserModel, organization_id: int) -> bool: - resolved = await get_resolved_ai_model_configuration( - user_id=user.id, +async def _get_mps_billing_account_status( + user: UserModel, organization_id: int +) -> Optional[dict]: + return await mps_service_key_client.get_billing_account_status( organization_id=organization_id, + created_by=str(user.provider_id), ) - return ( - resolved.source == "organization_v2" - and resolved.effective.managed_service_version == 2 - ) + + +def _is_mps_billing_v2(account: Optional[dict]) -> bool: + return bool(account and account.get("billing_mode") == "v2") async def _legacy_mps_credits_response(user: UserModel) -> MPSBillingCreditsResponse: @@ -217,7 +233,8 @@ async def get_billing_credits( return await _legacy_mps_credits_response(user) organization_id = user.selected_organization_id - if not await _uses_mps_billing_v2(user, organization_id): + account_status = await _get_mps_billing_account_status(user, organization_id) + if not _is_mps_billing_v2(account_status): return await _legacy_mps_credits_response(user) ledger = await mps_service_key_client.get_credit_ledger( @@ -227,6 +244,22 @@ async def get_billing_credits( ) account = ledger.get("account") or {} ledger_entries = ledger.get("ledger_entries") or [] + workflow_ids_by_run_id: dict[int, int] = {} + workflow_run_ids = { + workflow_run_id + for entry in ledger_entries + if (workflow_run_id := _optional_int(entry.get("workflow_run_id"))) + is not None + } + for workflow_run_id in workflow_run_ids: + workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id) + if ( + workflow_run + and workflow_run.workflow + and workflow_run.workflow.organization_id == organization_id + ): + workflow_ids_by_run_id[workflow_run_id] = workflow_run.workflow_id + balance = float(account.get("cached_balance_credits") or 0.0) total_debits = sum( abs(float(entry.get("credits_delta") or 0.0)) @@ -256,6 +289,20 @@ async def get_billing_credits( amount_minor=entry.get("amount_minor"), amount_currency=entry.get("amount_currency"), payment_order_id=entry.get("payment_order_id"), + metric_code=entry.get("metric_code"), + correlation_id=entry.get("correlation_id"), + aggregation_key=entry.get("aggregation_key"), + usage_event_id=_optional_int(entry.get("usage_event_id")), + workflow_run_id=_optional_int(entry.get("workflow_run_id")), + workflow_id=workflow_ids_by_run_id.get( + _optional_int(entry.get("workflow_run_id")) + ) + if entry.get("workflow_run_id") is not None + else None, + billable_quantity=float(entry["billable_quantity"]) + if entry.get("billable_quantity") is not None + else None, + quantity_unit=entry.get("quantity_unit"), metadata=entry.get("metadata") or {}, created_at=str(entry["created_at"]), ) @@ -285,19 +332,12 @@ async def create_mps_credit_purchase_url( organization_id = user.selected_organization_id assert organization_id is not None - resolved = await get_resolved_ai_model_configuration( - user_id=user.id, - organization_id=organization_id, - ) - if ( - resolved.source != "organization_v2" - or resolved.effective.managed_service_version != 2 - ): + account_status = await _get_mps_billing_account_status(user, organization_id) + if not _is_mps_billing_v2(account_status): raise HTTPException( status_code=403, detail=( - "Credit purchases are available only for organizations using " - "Dograh managed model configuration v2" + "Credit purchases are available only for organizations using billing v2" ), ) diff --git a/api/routes/workflow.py b/api/routes/workflow.py index ab6f5da7..06e5fdf9 100644 --- a/api/routes/workflow.py +++ b/api/routes/workflow.py @@ -41,12 +41,15 @@ from api.services.configuration.resolve import ( ) from api.services.mps_service_key_client import mps_service_key_client from api.services.posthog_client import capture_event -from api.services.pricing.run_usage_response import format_public_usage_info from api.services.reports import generate_workflow_report_csv from api.services.storage import storage_fs from api.services.workflow.dto import ReactFlowDTO, sanitize_workflow_definition from api.services.workflow.duplicate import duplicate_workflow from api.services.workflow.errors import ItemKind, WorkflowError +from api.services.workflow.run_usage_response import ( + format_public_cost_info, + format_public_usage_info, +) from api.services.workflow.trigger_paths import ( TriggerPathIssue, ensure_trigger_paths, @@ -1266,22 +1269,7 @@ async def get_workflow_run( "transcript_public_url": artifact_url(public_access_token, "transcript"), "recording_public_url": artifact_url(public_access_token, "recording"), "public_access_token": public_access_token, - "cost_info": { - "dograh_token_usage": ( - run.cost_info.get("dograh_token_usage") - if run.cost_info and "dograh_token_usage" in run.cost_info - else round(float(run.cost_info.get("total_cost_usd", 0)) * 100, 2) - if run.cost_info and "total_cost_usd" in run.cost_info - else 0 - ), - "call_duration_seconds": int( - round(run.cost_info.get("call_duration_seconds")) - ) - if run.cost_info and run.cost_info.get("call_duration_seconds") is not None - else None, - } - if run.cost_info - else None, + "cost_info": format_public_cost_info(run.cost_info, run.usage_info), "usage_info": format_public_usage_info(run.usage_info), "created_at": run.created_at, "definition_id": run.definition_id, diff --git a/api/services/configuration/check_validity.py b/api/services/configuration/check_validity.py index cc17481f..b1996879 100644 --- a/api/services/configuration/check_validity.py +++ b/api/services/configuration/check_validity.py @@ -75,21 +75,21 @@ class UserConfigurationValidator: status_list = [] status_list.extend(self._validate_service(configuration.llm, "llm")) - status_list.extend(self._validate_service(configuration.stt, "stt")) - status_list.extend(self._validate_service(configuration.tts, "tts")) - # Embeddings is optional - only validate if configured - status_list.extend( - self._validate_service( - configuration.embeddings, "embeddings", required=False - ) - ) - # Realtime is optional - only validate if is_realtime is enabled if configuration.is_realtime: status_list.extend( self._validate_service( configuration.realtime, "realtime", required=True ) ) + else: + status_list.extend(self._validate_service(configuration.stt, "stt")) + status_list.extend(self._validate_service(configuration.tts, "tts")) + # Embeddings is optional - only validate if configured + status_list.extend( + self._validate_service( + configuration.embeddings, "embeddings", required=False + ) + ) if status_list: raise ValueError(status_list) diff --git a/api/services/mps_service_key_client.py b/api/services/mps_service_key_client.py index ded89b29..84c788a0 100644 --- a/api/services/mps_service_key_client.py +++ b/api/services/mps_service_key_client.py @@ -4,6 +4,7 @@ This client communicates with the Model Proxy Service (MPS) for service key mana Service keys are stored and managed entirely in MPS, not in the local database. """ +import asyncio from typing import List, Optional import httpx @@ -420,6 +421,34 @@ class MPSServiceKeyClient: response=response, ) + async def get_billing_account_status( + self, + organization_id: int, + created_by: Optional[str] = None, + ) -> Optional[dict]: + """Get an existing MPS v2 billing account without creating one.""" + async with httpx.AsyncClient(timeout=self.timeout) as client: + response = await client.get( + f"{self.base_url}/api/v1/billing/accounts/{organization_id}/status", + headers=self._get_headers( + organization_id=organization_id, + created_by=created_by, + ), + ) + + if response.status_code == 200: + return response.json() + + logger.error( + "Failed to get MPS billing account status: " + f"{response.status_code} - {response.text}" + ) + raise httpx.HTTPStatusError( + f"Failed to get MPS billing account status: {response.text}", + request=response.request, + response=response, + ) + async def create_correlation_id( self, *, @@ -454,6 +483,76 @@ class MPSServiceKeyClient: response=response, ) + async def report_platform_usage( + self, + *, + organization_id: int, + correlation_id: Optional[str] = None, + duration_seconds: Optional[float] = None, + workflow_run_id: int | None = None, + metadata: Optional[dict] = None, + max_attempts: int = 3, + ) -> dict: + """Report hosted Dograh platform usage for a completed workflow run.""" + if DEPLOYMENT_MODE == "oss": + raise ValueError("OSS deployments must not report platform usage to MPS") + if not correlation_id and duration_seconds is None: + raise ValueError( + "Platform usage reports require correlation_id or duration_seconds" + ) + + payload: dict = { + "metadata": metadata or {}, + } + if correlation_id: + payload["correlation_id"] = correlation_id + if duration_seconds is not None: + payload["duration_seconds"] = duration_seconds + if workflow_run_id is not None: + payload["workflow_run_id"] = workflow_run_id + + max_attempts = max(1, max_attempts) + last_response: httpx.Response | None = None + async with httpx.AsyncClient(timeout=self.timeout) as client: + for attempt in range(1, max_attempts + 1): + response = await client.post( + ( + f"{self.base_url}/api/v1/billing/accounts/" + f"{organization_id}/platform-usage" + ), + json=payload, + headers=self._get_headers(organization_id=organization_id), + ) + last_response = response + + if response.status_code == 200: + return response.json() + + should_retry = ( + response.status_code == 409 + and "usage_not_ready" in response.text + and attempt < max_attempts + ) + if should_retry: + await asyncio.sleep(attempt) + continue + + logger.error( + "Failed to report platform usage: " + f"{response.status_code} - {response.text}" + ) + raise httpx.HTTPStatusError( + f"Failed to report platform usage: {response.text}", + request=response.request, + response=response, + ) + + raise httpx.HTTPStatusError( + "Failed to report platform usage", + request=last_response.request, + response=last_response, + ) + async def transcribe_audio( self, audio_data: bytes, diff --git a/api/services/pipecat/run_pipeline.py b/api/services/pipecat/run_pipeline.py index a5f2d077..07286901 100644 --- a/api/services/pipecat/run_pipeline.py +++ b/api/services/pipecat/run_pipeline.py @@ -162,15 +162,13 @@ async def run_pipeline_telephony( workflow_id: Workflow being executed. workflow_run_id: Workflow run row. user_id: Owner of the workflow. - call_id: Provider call identifier (stored in cost_info for billing). + call_id: Provider call identifier. transport_kwargs: Provider-specific kwargs forwarded to the transport factory (e.g. stream_sid + call_sid for Twilio). """ logger.debug(f"Running {provider_name} pipeline for workflow_run {workflow_run_id}") set_current_run_id(workflow_run_id) - await db_client.update_workflow_run(workflow_run_id, cost_info={"call_id": call_id}) - workflow = await db_client.get_workflow(workflow_id, user_id) if workflow: set_current_org_id(workflow.organization_id) diff --git a/api/services/pricing/README.md b/api/services/pricing/README.md deleted file mode 100644 index 4f834c28..00000000 --- a/api/services/pricing/README.md +++ /dev/null @@ -1,76 +0,0 @@ -# Pricing Module - -This module contains pricing models and registries for different AI services used in workflow cost calculations. - -## Structure - -``` -pricing/ -├── __init__.py # Main module exports -├── models.py # Base pricing model classes -├── llm.py # LLM pricing configurations -├── tts.py # TTS pricing configurations -├── stt.py # STT pricing configurations -├── registry.py # Combined pricing registry -└── README.md # This file -``` - -## Pricing Models - -### TokenPricingModel -Used for LLM services that charge based on tokens: -- `prompt_token_price`: Cost per prompt token -- `completion_token_price`: Cost per completion token -- `cache_read_discount`: Discount for cache read tokens (default 50%) -- `cache_creation_multiplier`: Premium for cache creation tokens (default 25%) - -### CharacterPricingModel -Used for TTS services that charge based on character count: -- `character_price`: Cost per character - -### TimePricingModel -Used for STT services that charge based on time: -- `second_price`: Cost per second - -## Adding New Pricing - -### Adding a New LLM Model -Edit `llm.py` and add the model to the appropriate provider: - -```python -ServiceProviders.OPENAI: { - "new-model": TokenPricingModel( - prompt_token_price=Decimal("2.00") / 1000000, - completion_token_price=Decimal("8.00") / 1000000, - ), - # ... existing models -} -``` - -### Adding a New Provider -1. Add pricing configurations to the appropriate service file (llm.py, tts.py, stt.py) -2. The registry will automatically include them - -### Adding a New Service Type -1. Create a new pricing file (e.g., `image.py`) -2. Define the pricing models -3. Import and add to `registry.py` - -## Usage - -The pricing registry is automatically imported and used by the cost calculator: - -```python -from api.services.pricing import PRICING_REGISTRY -from api.services.workflow.cost_calculator import cost_calculator - -# The cost calculator uses the pricing registry automatically -result = cost_calculator.calculate_total_cost(usage_info) -``` - -## Maintenance - -- Update pricing when providers change their rates -- All prices should use `Decimal` for precision -- Include comments with current pricing from provider documentation -- Test changes with existing test suite \ No newline at end of file diff --git a/api/services/pricing/__init__.py b/api/services/pricing/__init__.py deleted file mode 100644 index 1fa0eedf..00000000 --- a/api/services/pricing/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -""" -Pricing module for workflow cost calculation. - -This module contains pricing models and registries for different AI services. -""" - -from .registry import PRICING_REGISTRY - -__all__ = ["PRICING_REGISTRY"] diff --git a/api/services/pricing/cost_calculator.py b/api/services/pricing/cost_calculator.py deleted file mode 100644 index 14344752..00000000 --- a/api/services/pricing/cost_calculator.py +++ /dev/null @@ -1,228 +0,0 @@ -""" -Cost Calculator for Workflow Runs - -This module provides a comprehensive cost calculation system for workflow runs based on usage metrics -from different AI service providers (OpenAI, Groq, Deepgram, etc.). - -Features: -- Token-based pricing for LLM services with cache optimization support -- Character-based pricing for TTS services -- Time-based pricing for STT services -- Configurable pricing models that can be updated -- Support for multiple providers and models -- Automatic provider inference from model names -- JSON serialization support for database storage - -Usage: - from api.tasks.cost_calculator import cost_calculator - - usage_info = { - "llm": { - "processor_name|||gpt-4o": { - "prompt_tokens": 1000, - "completion_tokens": 500, - "total_tokens": 1500, - "cache_read_input_tokens": 0, - "cache_creation_input_tokens": 0 - } - }, - "tts": { - "processor_name|||aura-2-helena-en": 2000 # character count - } - } - - cost_breakdown = cost_calculator.calculate_total_cost(usage_info) - print(f"Total cost: ${cost_breakdown['total']:.6f}") -""" - -from decimal import Decimal -from typing import Any, Dict, Optional, Tuple - -from api.services.configuration.registry import ServiceProviders -from api.services.pricing import PRICING_REGISTRY -from api.services.pricing.models import ( - PricingModel, -) - - -class CostCalculator: - """Main cost calculator class""" - - def __init__(self, pricing_registry: Dict = None): - self.pricing_registry = pricing_registry or PRICING_REGISTRY - - def get_pricing_model( - self, service_type: str, provider: str, model: str - ) -> Optional[PricingModel]: - """Get pricing model for a specific service, provider, and model""" - try: - service_pricing = self.pricing_registry.get(service_type, {}) - - # Try to get pricing for the specific provider - provider_pricing = service_pricing.get(provider, {}) - pricing_model = provider_pricing.get(model) or provider_pricing.get( - "default" - ) - - if pricing_model: - return pricing_model - - # If not found, try the "default" provider for this service type - default_provider_pricing = service_pricing.get("default", {}) - return default_provider_pricing.get(model) or default_provider_pricing.get( - "default" - ) - - except (KeyError, AttributeError): - return None - - def calculate_llm_cost( - self, provider: str, model: str, usage: Dict[str, int] - ) -> Decimal: - """Calculate cost for LLM usage""" - pricing_model = self.get_pricing_model("llm", provider, model) - if not pricing_model: - return Decimal("0") - return pricing_model.calculate_cost(usage) - - def calculate_tts_cost( - self, provider: str, model: str, character_count: int - ) -> Decimal: - """Calculate cost for TTS usage""" - pricing_model = self.get_pricing_model("tts", provider, model) - if not pricing_model: - return Decimal("0") - return pricing_model.calculate_cost(character_count) - - def calculate_stt_cost(self, provider: str, model: str, seconds: float) -> Decimal: - """Calculate cost for STT usage""" - pricing_model = self.get_pricing_model("stt", provider, model) - if not pricing_model: - return Decimal("0") - return pricing_model.calculate_cost(seconds) - - def calculate_total_cost(self, usage_info: Dict) -> Dict[str, Any]: - llm_cost_total = Decimal("0") - tts_cost_total = Decimal("0") - stt_cost_total = Decimal("0") - - # Calculate LLM costs - llm_usage = usage_info.get("llm", {}) - for key, usage in llm_usage.items(): - processor, model = self._parse_key(key) - # Try to determine provider from processor name or model - provider = self._infer_provider_from_model(model, "llm") - cost = self.calculate_llm_cost(provider, model, usage) - llm_cost_total += cost - - # Calculate TTS costs - tts_usage = usage_info.get("tts", {}) - for key, character_count in tts_usage.items(): - processor, model = self._parse_key(key) - # Handle the case where model is "None" - infer from processor - if model.lower() in ["none", "null", ""]: - provider = self._infer_provider_from_processor(processor, "tts") - model = "default" # Use default model for the provider - else: - provider = self._infer_provider_from_model(model, "tts") - cost = self.calculate_tts_cost(provider, model, character_count) - tts_cost_total += cost - - # Calculate STT costs from explicit stt usage - stt_usage = usage_info.get("stt", {}) - for key, seconds in stt_usage.items(): - processor, model = self._parse_key(key) - provider = self._infer_provider_from_model(model, "stt") - cost = self.calculate_stt_cost(provider, model, seconds) - stt_cost_total += cost - - total_cost = llm_cost_total + tts_cost_total + stt_cost_total - - return { - "llm_cost": float(llm_cost_total), - "tts_cost": float(tts_cost_total), - "stt_cost": float(stt_cost_total), - "total": float(total_cost), - } - - def _parse_key(self, key) -> Tuple[str, str]: - """Parse key which is in format 'processor|||model'""" - if isinstance(key, str) and "|||" in key: - parts = key.split("|||", 1) - return parts[0], parts[1] - else: - # Fallback for backwards compatibility or malformed keys - return str(key), "unknown" - - def _infer_provider_from_model(self, model: str, service_type: str) -> str: - """Infer provider from model name""" - if not model: - return "unknown" - - model_lower = model.lower() - - # OpenAI models - if any(keyword in model_lower for keyword in ["gpt", "whisper", "openai"]): - return ServiceProviders.OPENAI - - # Groq models - if any(keyword in model_lower for keyword in ["groq"]): - return ServiceProviders.GROQ - - # Elevenlabs models - if any(keyword in model_lower for keyword in ["eleven"]): - return ServiceProviders.ELEVENLABS - - # Deepgram models - if any( - keyword in model_lower - for keyword in ["deepgram", "nova", "phonecall", "general"] - ): - return ServiceProviders.DEEPGRAM - - # Default to first available provider for the service type - service_providers = self.pricing_registry.get(service_type, {}) - if service_providers: - return list(service_providers.keys())[0] - - return "unknown" - - def _infer_provider_from_processor(self, processor: str, service_type: str) -> str: - """Infer provider from processor name""" - if not processor: - return "unknown" - - processor_lower = processor.lower() - - # OpenAI processors - if any(keyword in processor_lower for keyword in ["openai", "gpt"]): - return ServiceProviders.OPENAI - - # Groq processors - if any(keyword in processor_lower for keyword in ["groq"]): - return ServiceProviders.GROQ - - # Deepgram processors - if any(keyword in processor_lower for keyword in ["deepgram"]): - return ServiceProviders.DEEPGRAM - - # Default to first available provider for the service type - service_providers = self.pricing_registry.get(service_type, {}) - if service_providers: - return list(service_providers.keys())[0] - - return "unknown" - - def update_pricing( - self, service_type: str, provider: str, model: str, pricing_model: PricingModel - ): - """Update pricing for a specific service/provider/model combination""" - if service_type not in self.pricing_registry: - self.pricing_registry[service_type] = {} - if provider not in self.pricing_registry[service_type]: - self.pricing_registry[service_type][provider] = {} - self.pricing_registry[service_type][provider][model] = pricing_model - - -# Global cost calculator instance -cost_calculator = CostCalculator() diff --git a/api/services/pricing/embeddings.py b/api/services/pricing/embeddings.py deleted file mode 100644 index a58a8caa..00000000 --- a/api/services/pricing/embeddings.py +++ /dev/null @@ -1,44 +0,0 @@ -""" -Embeddings pricing models for different providers. - -Prices are per token for embedding models. -""" - -from decimal import Decimal -from typing import Dict - -from api.services.configuration.registry import ServiceProviders - -from .models import PricingModel - - -class EmbeddingPricingModel(PricingModel): - """Pricing model for token-based embedding services.""" - - def __init__(self, token_price: Decimal): - """Initialize with price per token. - - Args: - token_price: Cost per token for embedding - """ - self.token_price = token_price - - def calculate_cost(self, token_count: int) -> Decimal: - """Calculate cost for embedding token usage.""" - return Decimal(token_count) * self.token_price - - -# Embeddings pricing registry -EMBEDDINGS_PRICING: Dict[str, Dict[str, EmbeddingPricingModel]] = { - ServiceProviders.OPENAI: { - "text-embedding-3-small": EmbeddingPricingModel( - token_price=Decimal("0.02") / 1_000_000, # $0.02 per 1M tokens - ), - "text-embedding-3-large": EmbeddingPricingModel( - token_price=Decimal("0.13") / 1_000_000, # $0.13 per 1M tokens - ), - "text-embedding-ada-002": EmbeddingPricingModel( - token_price=Decimal("0.10") / 1_000_000, # $0.10 per 1M tokens (legacy) - ), - }, -} diff --git a/api/services/pricing/llm.py b/api/services/pricing/llm.py deleted file mode 100644 index addb59bc..00000000 --- a/api/services/pricing/llm.py +++ /dev/null @@ -1,143 +0,0 @@ -""" -LLM pricing models for different providers. - -Prices are per 1000 tokens for most models, with some newer models priced per million tokens. -""" - -from decimal import Decimal -from typing import Dict - -from api.services.configuration.registry import ServiceProviders - -from .models import TokenPricingModel - -# LLM pricing registry -LLM_PRICING: Dict[str, Dict[str, TokenPricingModel]] = { - ServiceProviders.OPENAI: { - "gpt-3.5-turbo": TokenPricingModel( - prompt_token_price=Decimal("0.0015") / 1000, # $0.0015 per 1K tokens - completion_token_price=Decimal("0.002") / 1000, # $0.002 per 1K tokens - ), - "gpt-4": TokenPricingModel( - prompt_token_price=Decimal("0.03") / 1000, # $0.03 per 1K tokens - completion_token_price=Decimal("0.06") / 1000, # $0.06 per 1K tokens - ), - "gpt-4.1": TokenPricingModel( - prompt_token_price=Decimal("2.00") / 1000000, # $2.00 per 1M tokens - completion_token_price=Decimal("8.00") / 1000000, # $8.00 per 1M tokens - ), - "gpt-4.1-mini": TokenPricingModel( - prompt_token_price=Decimal("0.40") / 1000000, # $0.40 per 1M tokens - completion_token_price=Decimal("1.60") / 1000000, # $1.60 per 1M tokens - ), - "gpt-4.1-nano": TokenPricingModel( - prompt_token_price=Decimal("0.10") / 1000000, # $0.10 per 1M tokens - completion_token_price=Decimal("0.40") / 1000000, # $0.40 per 1M tokens - ), - "gpt-4.5-preview": TokenPricingModel( - prompt_token_price=Decimal("75.00") / 1000000, # $75.00 per 1M tokens - completion_token_price=Decimal("150.00") / 1000000, # $150.00 per 1M tokens - ), - "gpt-4o": TokenPricingModel( - prompt_token_price=Decimal("2.50") / 1000000, # $2.50 per 1M tokens - FIXED - completion_token_price=Decimal("10.00") - / 1000000, # $10.00 per 1M tokens - FIXED - ), - "gpt-4o-audio-preview": TokenPricingModel( - prompt_token_price=Decimal("2.50") / 1000000, # $2.50 per 1M tokens - completion_token_price=Decimal("10.00") / 1000000, # $10.00 per 1M tokens - ), - "gpt-4o-realtime-preview": TokenPricingModel( - prompt_token_price=Decimal("5.00") / 1000000, # $5.00 per 1M tokens - completion_token_price=Decimal("20.00") / 1000000, # $20.00 per 1M tokens - ), - "gpt-4o-mini": TokenPricingModel( - prompt_token_price=Decimal("0.15") / 1000000, # $0.15 per 1M tokens - completion_token_price=Decimal("0.60") / 1000000, # $0.60 per 1M tokens - ), - "gpt-4o-mini-audio-preview": TokenPricingModel( - prompt_token_price=Decimal("0.15") / 1000000, # $0.15 per 1M tokens - completion_token_price=Decimal("0.60") / 1000000, # $0.60 per 1M tokens - ), - "gpt-4o-mini-realtime-preview": TokenPricingModel( - prompt_token_price=Decimal("0.60") / 1000000, # $0.60 per 1M tokens - completion_token_price=Decimal("2.40") / 1000000, # $2.40 per 1M tokens - ), - "gpt-4o-search-preview": TokenPricingModel( - prompt_token_price=Decimal("2.50") / 1000000, # $2.50 per 1M tokens - completion_token_price=Decimal("10.00") / 1000000, # $10.00 per 1M tokens - ), - "gpt-4o-mini-search-preview": TokenPricingModel( - prompt_token_price=Decimal("0.15") / 1000000, # $0.15 per 1M tokens - completion_token_price=Decimal("0.60") / 1000000, # $0.60 per 1M tokens - ), - "o1": TokenPricingModel( - prompt_token_price=Decimal("15.00") / 1000000, # $15.00 per 1M tokens - completion_token_price=Decimal("60.00") / 1000000, # $60.00 per 1M tokens - ), - "o1-pro": TokenPricingModel( - prompt_token_price=Decimal("150.00") / 1000000, # $150.00 per 1M tokens - completion_token_price=Decimal("600.00") / 1000000, # $600.00 per 1M tokens - ), - "o1-mini": TokenPricingModel( - prompt_token_price=Decimal("1.10") / 1000000, # $1.10 per 1M tokens - completion_token_price=Decimal("4.40") / 1000000, # $4.40 per 1M tokens - ), - "o3": TokenPricingModel( - prompt_token_price=Decimal("10.00") / 1000000, # $10.00 per 1M tokens - completion_token_price=Decimal("40.00") / 1000000, # $40.00 per 1M tokens - ), - "o3-mini": TokenPricingModel( - prompt_token_price=Decimal("1.10") / 1000000, # $1.10 per 1M tokens - completion_token_price=Decimal("4.40") / 1000000, # $4.40 per 1M tokens - ), - "o4-mini": TokenPricingModel( - prompt_token_price=Decimal("1.10") / 1000000, # $1.10 per 1M tokens - completion_token_price=Decimal("4.40") / 1000000, # $4.40 per 1M tokens - ), - "computer-use-preview": TokenPricingModel( - prompt_token_price=Decimal("3.00") / 1000000, # $3.00 per 1M tokens - completion_token_price=Decimal("12.00") / 1000000, # $12.00 per 1M tokens - ), - "gpt-image-1": TokenPricingModel( - prompt_token_price=Decimal("5.00") / 1000000, # $5.00 per 1M tokens - completion_token_price=Decimal("0") / 1000000, # No output pricing shown - ), - "codex-mini-latest": TokenPricingModel( - prompt_token_price=Decimal("1.50") / 1000000, # $1.50 per 1M tokens - completion_token_price=Decimal("6.00") / 1000000, # $6.00 per 1M tokens - ), - # Transcription models - "gpt-4o-transcribe": TokenPricingModel( - prompt_token_price=Decimal("2.50") / 1000000, # $2.50 per 1M tokens - completion_token_price=Decimal("10.00") / 1000000, # $10.00 per 1M tokens - ), - "gpt-4o-mini-transcribe": TokenPricingModel( - prompt_token_price=Decimal("1.25") / 1000000, # $1.25 per 1M tokens - completion_token_price=Decimal("5.00") / 1000000, # $5.00 per 1M tokens - ), - # TTS models with token-based pricing - "gpt-4o-mini-tts": TokenPricingModel( - prompt_token_price=Decimal("0.60") / 1000000, # $0.60 per 1M tokens - completion_token_price=Decimal("0") - / 1000000, # No completion tokens for TTS - ), - }, - ServiceProviders.GROQ: { - "llama-3.3-70b-versatile": TokenPricingModel( - prompt_token_price=Decimal("0.00059") / 1000, # $0.00059 per 1K tokens - completion_token_price=Decimal("0.00079") / 1000, # $0.00079 per 1K tokens - ), - "deepseek-r1-distill-llama-70b": TokenPricingModel( - prompt_token_price=Decimal("0.00059") / 1000, # Assuming similar pricing - completion_token_price=Decimal("0.00079") / 1000, - ), - }, - ServiceProviders.AZURE: { - "gpt-4.1-mini": TokenPricingModel( - prompt_token_price=Decimal("0.44") / 1000000, # $0.40 per 1M tokens - completion_token_price=Decimal("8.80") - / 1000000, # $1.60 per 1M tokens if using data zone - ) - }, -} diff --git a/api/services/pricing/models.py b/api/services/pricing/models.py deleted file mode 100644 index 58e197ac..00000000 --- a/api/services/pricing/models.py +++ /dev/null @@ -1,89 +0,0 @@ -""" -Base pricing models for different service types. -""" - -from decimal import Decimal -from enum import Enum -from typing import Any, Dict - - -class CostType(Enum): - LLM_TOKENS = "llm_tokens" - TTS_CHARACTERS = "tts_characters" - STT_SECONDS = "stt_seconds" - - -class PricingModel: - """Base class for pricing models""" - - def calculate_cost(self, usage: Any) -> Decimal: - """Calculate cost based on usage""" - raise NotImplementedError - - -class TokenPricingModel(PricingModel): - """Pricing model for token-based services (LLM)""" - - def __init__( - self, - prompt_token_price: Decimal, - completion_token_price: Decimal, - cache_read_discount: Decimal = Decimal("0.5"), # 50% discount for cache reads - cache_creation_multiplier: Decimal = Decimal( - "1.25" - ), # 25% premium for cache creation - ): - self.prompt_token_price = prompt_token_price - self.completion_token_price = completion_token_price - self.cache_read_discount = cache_read_discount - self.cache_creation_multiplier = cache_creation_multiplier - - def calculate_cost(self, usage: Dict[str, int]) -> Decimal: - """Calculate cost for LLM token usage""" - prompt_tokens = usage.get("prompt_tokens", 0) - completion_tokens = usage.get("completion_tokens", 0) - cache_read_tokens = usage.get("cache_read_input_tokens") or 0 - cache_creation_tokens = usage.get("cache_creation_input_tokens") or 0 - - # Base cost - prompt_cost = Decimal(prompt_tokens) * self.prompt_token_price - completion_cost = Decimal(completion_tokens) * self.completion_token_price - - # Cache adjustments - cache_read_savings = ( - Decimal(cache_read_tokens) - * self.prompt_token_price - * self.cache_read_discount - ) - cache_creation_premium = ( - Decimal(cache_creation_tokens) - * self.prompt_token_price - * (self.cache_creation_multiplier - 1) - ) - - total_cost = ( - prompt_cost + completion_cost - cache_read_savings + cache_creation_premium - ) - return max(total_cost, Decimal("0")) # Ensure non-negative - - -class CharacterPricingModel(PricingModel): - """Pricing model for character-based services (TTS)""" - - def __init__(self, character_price: Decimal): - self.character_price = character_price - - def calculate_cost(self, character_count: int) -> Decimal: - """Calculate cost for TTS character usage""" - return Decimal(character_count) * self.character_price - - -class TimePricingModel(PricingModel): - """Pricing model for time-based services (STT)""" - - def __init__(self, second_price: Decimal): - self.second_price = second_price - - def calculate_cost(self, seconds: float) -> Decimal: - """Calculate cost for STT time usage""" - return Decimal(str(seconds)) * self.second_price diff --git a/api/services/pricing/registry.py b/api/services/pricing/registry.py deleted file mode 100644 index 294a94a2..00000000 --- a/api/services/pricing/registry.py +++ /dev/null @@ -1,18 +0,0 @@ -""" -Main pricing registry that combines all service type pricing models. -""" - -from typing import Dict - -from .embeddings import EMBEDDINGS_PRICING -from .llm import LLM_PRICING -from .stt import STT_PRICING -from .tts import TTS_PRICING - -# Combined pricing registry for all service types -PRICING_REGISTRY: Dict = { - "llm": LLM_PRICING, - "tts": TTS_PRICING, - "stt": STT_PRICING, - "embeddings": EMBEDDINGS_PRICING, -} diff --git a/api/services/pricing/run_usage_response.py b/api/services/pricing/run_usage_response.py deleted file mode 100644 index a1f85a47..00000000 --- a/api/services/pricing/run_usage_response.py +++ /dev/null @@ -1,13 +0,0 @@ -"""Format workflow run usage for public API responses.""" - - -def format_public_usage_info(usage_info: dict | None) -> dict | None: - if not usage_info: - return None - - return { - "llm": usage_info.get("llm") or {}, - "tts": usage_info.get("tts") or {}, - "stt": usage_info.get("stt") or {}, - "call_duration_seconds": usage_info.get("call_duration_seconds"), - } diff --git a/api/services/pricing/stt.py b/api/services/pricing/stt.py deleted file mode 100644 index ca00ff4c..00000000 --- a/api/services/pricing/stt.py +++ /dev/null @@ -1,26 +0,0 @@ -""" -STT (Speech-to-Text) pricing models for different providers. - -Prices are per second for STT services. -""" - -from decimal import Decimal -from typing import Dict - -from api.services.configuration.registry import ServiceProviders - -from .models import TimePricingModel - -# STT pricing registry -STT_PRICING: Dict[str, Dict[str, TimePricingModel]] = { - ServiceProviders.DEEPGRAM: { - "nova-3-general": TimePricingModel(Decimal("0.0077") / 60), - "nova-2": TimePricingModel(Decimal("0.0058") / 60), - "default": TimePricingModel(Decimal("0.0077") / 60), - }, - ServiceProviders.OPENAI: { - "gpt-4o-transcribe": TimePricingModel(Decimal("0.015") / 60), - "default": TimePricingModel(Decimal("0.015") / 60), - }, - "default": {"default": TimePricingModel(Decimal("0.0077") / 60)}, -} diff --git a/api/services/pricing/tts.py b/api/services/pricing/tts.py deleted file mode 100644 index 7485cc7f..00000000 --- a/api/services/pricing/tts.py +++ /dev/null @@ -1,30 +0,0 @@ -""" -TTS (Text-to-Speech) pricing models for different providers. - -Prices are per character for TTS services. -""" - -from decimal import Decimal -from typing import Dict - -from api.services.configuration.registry import ServiceProviders - -from .models import CharacterPricingModel - -# TTS pricing registry -TTS_PRICING: Dict[str, Dict[str, CharacterPricingModel]] = { - ServiceProviders.OPENAI: { - "gpt-4o-mini-tts": CharacterPricingModel(Decimal("0.6") / 1_00_00_000), - "default": CharacterPricingModel(Decimal("0.6") / 1_00_00_000), - }, - ServiceProviders.DEEPGRAM: { - "aura-2": CharacterPricingModel(Decimal("0.030") / 1_000), - "aura-1": CharacterPricingModel(Decimal("0.015") / 1_000), - "default": CharacterPricingModel(Decimal("0.030") / 1_000), - }, - ServiceProviders.ELEVENLABS: { - # 6400 usd per 250*1e6 characters - "default": CharacterPricingModel(Decimal("0.0256") / 1_000) - }, - "default": {"default": CharacterPricingModel(Decimal("0.030") / 1_000)}, -} diff --git a/api/services/pricing/workflow_run_cost.py b/api/services/pricing/workflow_run_cost.py deleted file mode 100644 index 6d6010c3..00000000 --- a/api/services/pricing/workflow_run_cost.py +++ /dev/null @@ -1,230 +0,0 @@ -from decimal import Decimal - -from loguru import logger - -from api.db import db_client -from api.enums import WorkflowRunMode -from api.services.pricing.cost_calculator import cost_calculator -from api.services.telephony.factory import get_telephony_provider_for_run - - -async def _fetch_telephony_cost(workflow_run) -> dict | None: - """Fetch telephony call cost. Returns a dict with cost_usd and provider_name, or None.""" - if ( - workflow_run.mode - not in [WorkflowRunMode.TWILIO.value, WorkflowRunMode.VONAGE.value] - or not workflow_run.cost_info - ): - return None - - call_id = workflow_run.cost_info.get("call_id") - if not call_id: - logger.warning(f"call_id not found in cost_info") - return None - - provider_name = workflow_run.mode.lower() if workflow_run.mode else "" - - workflow = await db_client.get_workflow_by_id(workflow_run.workflow_id) - if not workflow: - logger.warning("Workflow not found for workflow run") - raise Exception("Workflow not found") - - provider = await get_telephony_provider_for_run( - workflow_run, workflow.organization_id - ) - call_cost_info = await provider.get_call_cost(call_id) - - if call_cost_info.get("status") == "error": - logger.error( - f"Failed to fetch {provider_name} call cost: {call_cost_info.get('error')}" - ) - return None - - cost_usd = call_cost_info.get("cost_usd", 0.0) - logger.info( - f"{provider_name.title()} call cost: ${cost_usd:.6f} USD for call {call_id}" - ) - return {"cost_usd": cost_usd, "provider_name": provider_name} - - -async def _update_organization_usage( - org, dograh_tokens: float, duration_seconds: float, charge_usd: float | None -) -> None: - """Update organization usage after a workflow run.""" - org_id = org.id - await db_client.update_usage_after_run( - org_id, dograh_tokens, duration_seconds, charge_usd - ) - if charge_usd is not None: - logger.info( - f"Updated organization usage with ${charge_usd:.2f} USD ({dograh_tokens} Dograh Tokens) and {duration_seconds}s duration for org {org_id}" - ) - else: - logger.info( - f"Updated organization usage with {dograh_tokens} Dograh Tokens and {duration_seconds}s duration for org {org_id}" - ) - - -async def _get_pricing_organization(workflow_run): - workflow = getattr(workflow_run, "workflow", None) - organization_id = getattr(workflow, "organization_id", None) - if organization_id is None and workflow and workflow.user: - organization_id = workflow.user.selected_organization_id - if organization_id is None: - return None - return await db_client.get_organization_by_id(organization_id) - - -async def _build_usage_cost_snapshot( - usage_info: dict | None, - *, - workflow_run=None, - include_telephony_cost: bool = False, - organization=None, - calculated_at: str | None = None, -) -> dict | None: - if not usage_info: - logger.warning("No usage info available for workflow run") - return None - - cost_breakdown = cost_calculator.calculate_total_cost(usage_info) - - if include_telephony_cost and workflow_run is not None: - try: - telephony_cost = await _fetch_telephony_cost(workflow_run) - if telephony_cost: - telephony_cost_usd = telephony_cost["cost_usd"] - provider_name = telephony_cost["provider_name"] - cost_breakdown["telephony_call"] = telephony_cost_usd - cost_breakdown[f"{provider_name}_call"] = telephony_cost_usd - cost_breakdown["total"] = ( - float(cost_breakdown["total"]) + telephony_cost_usd - ) - except Exception as e: - logger.error(f"Failed to fetch telephony call cost: {e}") - # Don't fail the whole cost calculation if telephony API fails - - total_cost_usd = Decimal(str(cost_breakdown["total"])) - dograh_tokens = float(total_cost_usd * Decimal("100")) - - if organization is None and workflow_run is not None: - organization = await _get_pricing_organization(workflow_run) - - charge_usd = None - if organization and organization.price_per_second_usd: - duration_seconds = usage_info.get("call_duration_seconds", 0) - charge_usd = float( - Decimal(str(duration_seconds)) - * Decimal(str(organization.price_per_second_usd)) - ) - - cost_info = { - "cost_breakdown": cost_breakdown, - "total_cost_usd": float(total_cost_usd), - "dograh_token_usage": dograh_tokens, - "calculated_at": calculated_at - or (workflow_run.created_at.isoformat() if workflow_run is not None else None), - "call_duration_seconds": usage_info.get("call_duration_seconds", 0), - } - - if charge_usd is not None: - cost_info["charge_usd"] = charge_usd - cost_info["price_per_second_usd"] = organization.price_per_second_usd - - return cost_info - - -async def build_workflow_run_cost_info(workflow_run) -> dict | None: - cost_info = await _build_usage_cost_snapshot( - workflow_run.usage_info, - workflow_run=workflow_run, - include_telephony_cost=True, - calculated_at=workflow_run.created_at.isoformat(), - ) - if cost_info is None: - return None - return { - **(workflow_run.cost_info or {}), - **cost_info, - } - - -async def save_workflow_run_cost_info( - workflow_run_id: int, cost_info: dict | None -) -> None: - if cost_info is None: - return - await db_client.update_workflow_run(run_id=workflow_run_id, cost_info=cost_info) - - -async def apply_workflow_run_usage_to_organization( - workflow_run, cost_info: dict | None -) -> None: - if cost_info is None: - return - - org = await _get_pricing_organization(workflow_run) - if not org: - return - - await _update_organization_usage( - org, - float(cost_info.get("dograh_token_usage") or 0), - float(cost_info.get("call_duration_seconds") or 0), - cost_info.get("charge_usd"), - ) - - -async def apply_usage_delta_to_organization( - workflow_run, usage_info: dict | None -) -> dict | None: - org = await _get_pricing_organization(workflow_run) - if not org: - return None - - cost_info = await _build_usage_cost_snapshot(usage_info, organization=org) - if cost_info is None: - return None - - await _update_organization_usage( - org, - float(cost_info.get("dograh_token_usage") or 0), - float(cost_info.get("call_duration_seconds") or 0), - cost_info.get("charge_usd"), - ) - return cost_info - - -async def calculate_workflow_run_cost(workflow_run_id: int): - logger.debug("Calculating cost for workflow run") - - workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id) - if not workflow_run: - logger.warning("Workflow run not found") - return - - try: - cost_info = await build_workflow_run_cost_info(workflow_run) - if cost_info is None: - return - - await save_workflow_run_cost_info(workflow_run_id, cost_info) - - try: - await apply_workflow_run_usage_to_organization(workflow_run, cost_info) - except Exception as e: - org = await _get_pricing_organization(workflow_run) - if org: - logger.error( - f"Failed to update organization usage for org {org.id}: {e}" - ) - else: - logger.error(f"Failed to update organization usage: {e}") - # Don't fail the whole cost calculation if usage update fails - - logger.info( - f"Calculated cost for workflow run: ${cost_info['total_cost_usd']:.6f} USD ({cost_info['dograh_token_usage']} Dograh Tokens)" - ) - except Exception as e: - logger.error(f"Error calculating cost for workflow run: {e}") - raise diff --git a/api/services/reports/run_report.py b/api/services/reports/run_report.py index b84a6f96..a5e64819 100644 --- a/api/services/reports/run_report.py +++ b/api/services/reports/run_report.py @@ -53,7 +53,7 @@ def build_run_report_csv(runs: List[Any]) -> io.StringIO: for run in runs: initial = run.initial_context or {} gathered = run.gathered_context or {} - cost = run.cost_info or {} + usage = run.usage_info or {} call_tags = gathered.get("call_tags", []) if isinstance(call_tags, list): @@ -67,7 +67,7 @@ def build_run_report_csv(runs: List[Any]) -> io.StringIO: run.created_at.isoformat() if run.created_at else "", initial.get("phone_number", ""), gathered.get("mapped_call_disposition", ""), - cost.get("call_duration_seconds", ""), + usage.get("call_duration_seconds", ""), ] extracted = gathered.get("extracted_variables", {}) diff --git a/api/services/telephony/providers/vonage/routes.py b/api/services/telephony/providers/vonage/routes.py index a4cca35d..c862e745 100644 --- a/api/services/telephony/providers/vonage/routes.py +++ b/api/services/telephony/providers/vonage/routes.py @@ -66,34 +66,6 @@ async def handle_vonage_events( logger.error(f"[run {workflow_run_id}] Workflow run not found") return {"status": "error", "message": "Workflow run not found"} - # For a completed call that includes cost info, capture it immediately - if event_data.get("status") == "completed": - # Vonage sometimes includes price info in the webhook - if "price" in event_data or "rate" in event_data: - try: - if workflow_run.cost_info: - # Store immediate cost info if available - cost_info = workflow_run.cost_info.copy() - if "price" in event_data: - cost_info["vonage_webhook_price"] = float(event_data["price"]) - if "rate" in event_data: - cost_info["vonage_webhook_rate"] = float(event_data["rate"]) - if "duration" in event_data: - cost_info["vonage_webhook_duration"] = int( - event_data["duration"] - ) - - await db_client.update_workflow_run( - run_id=workflow_run_id, cost_info=cost_info - ) - logger.info( - f"[run {workflow_run_id}] Captured Vonage cost info from webhook" - ) - except Exception as e: - logger.error( - f"[run {workflow_run_id}] Failed to capture Vonage cost from webhook: {e}" - ) - # Get workflow and provider workflow = await db_client.get_workflow_by_id(workflow_run.workflow_id) if not workflow: diff --git a/api/services/workflow/run_usage_response.py b/api/services/workflow/run_usage_response.py new file mode 100644 index 00000000..c289e565 --- /dev/null +++ b/api/services/workflow/run_usage_response.py @@ -0,0 +1,41 @@ +"""Format workflow run usage for public API responses.""" + + +def format_public_usage_info(usage_info: dict | None) -> dict | None: + if not usage_info: + return None + + return { + "llm": usage_info.get("llm") or {}, + "tts": usage_info.get("tts") or {}, + "stt": usage_info.get("stt") or {}, + "call_duration_seconds": usage_info.get("call_duration_seconds"), + } + + +def format_public_cost_info( + cost_info: dict | None, usage_info: dict | None +) -> dict | None: + """Return the legacy response shape without doing local cost accounting.""" + duration = None + if usage_info and usage_info.get("call_duration_seconds") is not None: + duration = int(round(usage_info.get("call_duration_seconds") or 0)) + elif cost_info and cost_info.get("call_duration_seconds") is not None: + duration = int(round(cost_info.get("call_duration_seconds") or 0)) + + dograh_token_usage = 0 + if cost_info: + if "dograh_token_usage" in cost_info: + dograh_token_usage = cost_info.get("dograh_token_usage") or 0 + elif "total_cost_usd" in cost_info: + dograh_token_usage = round( + float(cost_info.get("total_cost_usd", 0)) * 100, 2 + ) + + if duration is None and dograh_token_usage == 0: + return None + + return { + "dograh_token_usage": dograh_token_usage, + "call_duration_seconds": duration, + } diff --git a/api/services/workflow/text_chat_session_service.py b/api/services/workflow/text_chat_session_service.py index 53354d5f..81749960 100644 --- a/api/services/workflow/text_chat_session_service.py +++ b/api/services/workflow/text_chat_session_service.py @@ -4,17 +4,11 @@ from datetime import UTC, datetime from typing import Any from uuid import uuid4 -from loguru import logger - from api.db import db_client from api.db.models import WorkflowRunTextSessionModel from api.db.workflow_run_text_session_client import ( WorkflowRunTextSessionRevisionConflictError, ) -from api.services.pricing.workflow_run_cost import ( - apply_usage_delta_to_organization, - build_workflow_run_cost_info, -) from api.services.workflow.text_chat_logs import ( build_text_chat_realtime_feedback_events, ) @@ -261,20 +255,6 @@ async def execute_pending_text_chat_turn( state=execution.state, is_completed=execution.is_completed, ) - workflow_run = await db_client.get_workflow_run_by_id(run_id) - if workflow_run: - try: - # Apply the per-turn delta so org usage tracks cumulative run cost - # without replaying the full session totals on every turn. - await apply_usage_delta_to_organization(workflow_run, execution.usage) - except Exception as e: - logger.error( - f"Failed to update organization usage for text chat run {run_id}: {e}" - ) - - cost_info = await build_workflow_run_cost_info(workflow_run) - if cost_info is not None: - await db_client.update_workflow_run(run_id, cost_info=cost_info) return await _reload_text_chat_session(run_id) diff --git a/api/services/workflow_run_billing.py b/api/services/workflow_run_billing.py new file mode 100644 index 00000000..ab8a3121 --- /dev/null +++ b/api/services/workflow_run_billing.py @@ -0,0 +1,111 @@ +"""Workflow-run billing hooks. + +Dograh does not rate or deduct credits locally. MPS owns credit accounting. +For hosted deployments, Dograh reports completed platform usage to MPS. +When a server-minted MPS correlation id exists, MPS uses model-service usage +as the canonical duration. Otherwise Dograh reports the completed run duration. +""" + +from typing import Any + +from loguru import logger + +from api.constants import DEPLOYMENT_MODE +from api.db import db_client +from api.services.managed_model_services import get_mps_correlation_id +from api.services.mps_service_key_client import mps_service_key_client + + +def _workflow_run_organization_id(workflow_run) -> int | None: + workflow = getattr(workflow_run, "workflow", None) + return getattr(workflow, "organization_id", None) + + +def _duration_seconds_from_usage_info(workflow_run) -> float | None: + usage_info: dict[str, Any] = getattr(workflow_run, "usage_info", None) or {} + duration = usage_info.get("call_duration_seconds") + try: + duration_seconds = float(duration) + except (TypeError, ValueError): + return None + + return duration_seconds if duration_seconds > 0 else None + + +async def _organization_uses_mps_billing_v2(organization_id: int) -> bool: + account = await mps_service_key_client.get_billing_account_status( + organization_id=organization_id + ) + return bool(account and account.get("billing_mode") == "v2") + + +async def report_workflow_run_platform_usage(workflow_run) -> None: + """Report hosted platform usage for a completed workflow run to MPS.""" + if DEPLOYMENT_MODE == "oss": + return + + if not getattr(workflow_run, "is_completed", False): + return + + organization_id = _workflow_run_organization_id(workflow_run) + if organization_id is None: + logger.warning( + "Skipping platform usage report for workflow run {}: no organization_id", + workflow_run.id, + ) + return + + correlation_id = get_mps_correlation_id( + getattr(workflow_run, "initial_context", None) + ) + duration_seconds = ( + None if correlation_id else _duration_seconds_from_usage_info(workflow_run) + ) + if not correlation_id and duration_seconds is None: + logger.warning( + "Skipping platform usage report for workflow run {}: no billable duration", + workflow_run.id, + ) + return + + try: + if not await _organization_uses_mps_billing_v2(organization_id): + return + + result = await mps_service_key_client.report_platform_usage( + organization_id=organization_id, + correlation_id=correlation_id, + duration_seconds=duration_seconds, + workflow_run_id=workflow_run.id, + metadata={ + "source": "workflow_run_completion", + "workflow_id": getattr(workflow_run, "workflow_id", None), + "duration_source": ( + "mps_correlation" if correlation_id else "dograh_usage_info" + ), + }, + ) + logger.info( + "Reported platform usage for workflow run {} to MPS: {}", + workflow_run.id, + result, + ) + except Exception as e: + logger.error( + "Failed to report platform usage for workflow run {}: {}", + workflow_run.id, + e, + ) + + +async def report_completed_workflow_run_platform_usage(workflow_run_id: int) -> None: + """Load a completed workflow run and report platform usage to MPS.""" + workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id) + if not workflow_run: + logger.warning( + "Skipping platform usage report: workflow run {} not found", + workflow_run_id, + ) + return + + await report_workflow_run_platform_usage(workflow_run) diff --git a/api/tasks/arq.py b/api/tasks/arq.py index a948a578..442114e6 100644 --- a/api/tasks/arq.py +++ b/api/tasks/arq.py @@ -45,10 +45,8 @@ from api.tasks.campaign_tasks import ( ) from api.tasks.knowledge_base_processing import process_knowledge_base_document from api.tasks.run_integrations import run_integrations_post_workflow_run -from api.tasks.s3_upload import ( - process_workflow_completion, - upload_voicemail_audio_to_s3, -) +from api.tasks.s3_upload import upload_voicemail_audio_to_s3 +from api.tasks.workflow_completion import process_workflow_completion class WorkerSettings: diff --git a/api/tasks/s3_upload.py b/api/tasks/s3_upload.py index b2086c09..bbbc8bf4 100644 --- a/api/tasks/s3_upload.py +++ b/api/tasks/s3_upload.py @@ -1,13 +1,9 @@ import os -from typing import Optional from loguru import logger from pipecat.utils.run_context import set_current_run_id -from api.db import db_client -from api.services.pricing.workflow_run_cost import calculate_workflow_run_cost -from api.services.storage import get_current_storage_backend, storage_fs -from api.tasks.run_integrations import run_integrations_post_workflow_run +from api.services.storage import storage_fs async def upload_voicemail_audio_to_s3( @@ -69,110 +65,3 @@ async def upload_voicemail_audio_to_s3( logger.warning( f"Failed to clean up temp voicemail audio file {temp_file_path}: {e}" ) - - -async def process_workflow_completion( - _ctx, - workflow_run_id: int, - audio_temp_path: Optional[str] = None, - transcript_temp_path: Optional[str] = None, -): - """Process workflow completion: upload artifacts and run integrations. - - This task combines audio upload, transcript upload, and webhook integrations - into a single sequential task to ensure integrations run after uploads complete. - - Args: - _ctx: ARQ context (unused) - workflow_run_id: The workflow run ID - audio_temp_path: Optional path to temp audio file - transcript_temp_path: Optional path to temp transcript file - """ - run_id = str(workflow_run_id) - set_current_run_id(run_id) - - logger.info(f"Processing workflow completion for run {workflow_run_id}") - - storage_backend = get_current_storage_backend() - - # Step 1: Upload audio if provided - if audio_temp_path: - try: - if os.path.exists(audio_temp_path): - file_size = os.path.getsize(audio_temp_path) - logger.debug(f"Audio file size: {file_size} bytes") - - recording_url = f"recordings/{workflow_run_id}.wav" - logger.info( - f"Uploading audio to {storage_backend.name} - workflow_run_id: {workflow_run_id}" - ) - - await storage_fs.aupload_file(audio_temp_path, recording_url) - await db_client.update_workflow_run( - run_id=workflow_run_id, - recording_url=recording_url, - storage_backend=storage_backend.value, - ) - logger.info(f"Successfully uploaded audio: {recording_url}") - else: - logger.warning(f"Audio temp file not found: {audio_temp_path}") - except Exception as e: - logger.error(f"Error uploading audio for workflow {workflow_run_id}: {e}") - finally: - if audio_temp_path and os.path.exists(audio_temp_path): - try: - os.remove(audio_temp_path) - logger.debug(f"Cleaned up temp audio file: {audio_temp_path}") - except Exception as e: - logger.warning(f"Failed to clean up temp audio file: {e}") - - # Step 2: Upload transcript if provided - if transcript_temp_path: - try: - if os.path.exists(transcript_temp_path): - file_size = os.path.getsize(transcript_temp_path) - logger.debug(f"Transcript file size: {file_size} bytes") - - transcript_url = f"transcripts/{workflow_run_id}.txt" - logger.info( - f"Uploading transcript to {storage_backend.name} - workflow_run_id: {workflow_run_id}" - ) - - await storage_fs.aupload_file(transcript_temp_path, transcript_url) - await db_client.update_workflow_run( - run_id=workflow_run_id, - transcript_url=transcript_url, - storage_backend=storage_backend.value, - ) - logger.info(f"Successfully uploaded transcript: {transcript_url}") - else: - logger.warning( - f"Transcript temp file not found: {transcript_temp_path}" - ) - except Exception as e: - logger.error( - f"Error uploading transcript for workflow {workflow_run_id}: {e}" - ) - finally: - if transcript_temp_path and os.path.exists(transcript_temp_path): - try: - os.remove(transcript_temp_path) - logger.debug( - f"Cleaned up temp transcript file: {transcript_temp_path}" - ) - except Exception as e: - logger.warning(f"Failed to clean up temp transcript file: {e}") - - # Step 3: Run integrations including QA analysis (after uploads are complete) - try: - await run_integrations_post_workflow_run(_ctx, workflow_run_id) - except Exception as e: - logger.error(f"Error running integrations for workflow {workflow_run_id}: {e}") - - # Step 4: Calculate cost after integrations (so QA token usage is included) - try: - await calculate_workflow_run_cost(workflow_run_id) - except Exception as e: - logger.error(f"Error calculating cost for workflow {workflow_run_id}: {e}") - - logger.info(f"Completed workflow completion processing for run {workflow_run_id}") diff --git a/api/tasks/workflow_completion.py b/api/tasks/workflow_completion.py new file mode 100644 index 00000000..ff0482d2 --- /dev/null +++ b/api/tasks/workflow_completion.py @@ -0,0 +1,121 @@ +import os +from typing import Optional + +from loguru import logger +from pipecat.utils.run_context import set_current_run_id + +from api.db import db_client +from api.services.storage import get_current_storage_backend, storage_fs +from api.services.workflow_run_billing import ( + report_completed_workflow_run_platform_usage, +) +from api.tasks.run_integrations import run_integrations_post_workflow_run + + +async def process_workflow_completion( + _ctx, + workflow_run_id: int, + audio_temp_path: Optional[str] = None, + transcript_temp_path: Optional[str] = None, +): + """Process workflow completion: upload artifacts and run integrations. + + This task combines audio upload, transcript upload, and webhook integrations + into a single sequential task to ensure integrations run after uploads complete. + + Args: + _ctx: ARQ context (unused) + workflow_run_id: The workflow run ID + audio_temp_path: Optional path to temp audio file + transcript_temp_path: Optional path to temp transcript file + """ + run_id = str(workflow_run_id) + set_current_run_id(run_id) + + logger.info(f"Processing workflow completion for run {workflow_run_id}") + + storage_backend = get_current_storage_backend() + + # Step 1: Upload audio if provided + if audio_temp_path: + try: + if os.path.exists(audio_temp_path): + file_size = os.path.getsize(audio_temp_path) + logger.debug(f"Audio file size: {file_size} bytes") + + recording_url = f"recordings/{workflow_run_id}.wav" + logger.info( + f"Uploading audio to {storage_backend.name} - workflow_run_id: {workflow_run_id}" + ) + + await storage_fs.aupload_file(audio_temp_path, recording_url) + await db_client.update_workflow_run( + run_id=workflow_run_id, + recording_url=recording_url, + storage_backend=storage_backend.value, + ) + logger.info(f"Successfully uploaded audio: {recording_url}") + else: + logger.warning(f"Audio temp file not found: {audio_temp_path}") + except Exception as e: + logger.error(f"Error uploading audio for workflow {workflow_run_id}: {e}") + finally: + if audio_temp_path and os.path.exists(audio_temp_path): + try: + os.remove(audio_temp_path) + logger.debug(f"Cleaned up temp audio file: {audio_temp_path}") + except Exception as e: + logger.warning(f"Failed to clean up temp audio file: {e}") + + # Step 2: Upload transcript if provided + if transcript_temp_path: + try: + if os.path.exists(transcript_temp_path): + file_size = os.path.getsize(transcript_temp_path) + logger.debug(f"Transcript file size: {file_size} bytes") + + transcript_url = f"transcripts/{workflow_run_id}.txt" + logger.info( + f"Uploading transcript to {storage_backend.name} - workflow_run_id: {workflow_run_id}" + ) + + await storage_fs.aupload_file(transcript_temp_path, transcript_url) + await db_client.update_workflow_run( + run_id=workflow_run_id, + transcript_url=transcript_url, + storage_backend=storage_backend.value, + ) + logger.info(f"Successfully uploaded transcript: {transcript_url}") + else: + logger.warning( + f"Transcript temp file not found: {transcript_temp_path}" + ) + except Exception as e: + logger.error( + f"Error uploading transcript for workflow {workflow_run_id}: {e}" + ) + finally: + if transcript_temp_path and os.path.exists(transcript_temp_path): + try: + os.remove(transcript_temp_path) + logger.debug( + f"Cleaned up temp transcript file: {transcript_temp_path}" + ) + except Exception as e: + logger.warning(f"Failed to clean up temp transcript file: {e}") + + # Step 3: Run integrations including QA analysis (after uploads are complete) + try: + await run_integrations_post_workflow_run(_ctx, workflow_run_id) + except Exception as e: + logger.error(f"Error running integrations for workflow {workflow_run_id}: {e}") + + # Step 4: Notify MPS after completion. MPS owns credit accounting. + try: + await report_completed_workflow_run_platform_usage(workflow_run_id) + except Exception as e: + logger.error( + f"Error reporting platform usage for workflow {workflow_run_id}: {e}" + ) + + logger.info(f"Completed workflow completion processing for run {workflow_run_id}") diff --git a/api/tests/test_ai_model_configuration_v2.py b/api/tests/test_ai_model_configuration_v2.py index 71772b28..16056b57 100644 --- a/api/tests/test_ai_model_configuration_v2.py +++ b/api/tests/test_ai_model_configuration_v2.py @@ -15,6 +15,7 @@ from api.services.configuration.ai_model_configuration import ( merge_ai_model_configuration_v2_secrets, migrate_workflow_configuration_model_override_to_v2, ) +from api.services.configuration.check_validity import UserConfigurationValidator from api.services.configuration.masking import mask_key from api.services.configuration.registry import ( DeepgramSTTConfiguration, @@ -22,6 +23,8 @@ from api.services.configuration.registry import ( DograhSTTService, DograhTTSService, ElevenlabsTTSConfiguration, + GoogleLLMService, + GoogleRealtimeLLMConfiguration, OpenAIEmbeddingsConfiguration, OpenAILLMService, ) @@ -93,6 +96,67 @@ def test_byok_v2_rejects_dograh_provider(): ) +@pytest.mark.asyncio +async def test_byok_realtime_validator_does_not_require_stt_or_tts(): + config = OrganizationAIModelConfigurationV2.model_validate( + { + "mode": "byok", + "byok": { + "mode": "realtime", + "realtime": { + "realtime": { + "provider": "google_realtime", + "api_key": "google-realtime-key", + "model": "gemini-3.1-flash-live-preview", + "voice": "Puck", + "language": "en", + }, + "llm": { + "provider": "google", + "api_key": "google-llm-key", + "model": "gemini-2.0-flash", + }, + }, + }, + } + ) + effective = compile_ai_model_configuration_v2(config) + + assert effective.is_realtime is True + assert effective.stt is None + assert effective.tts is None + assert await UserConfigurationValidator().validate(effective) == { + "status": [{"model": "all", "message": "ok"}] + } + + +@pytest.mark.asyncio +async def test_pipeline_validator_requires_stt_and_tts_when_not_realtime(): + effective = EffectiveAIModelConfiguration( + llm=GoogleLLMService( + provider="google", + api_key="google-llm-key", + model="gemini-2.0-flash", + ), + realtime=GoogleRealtimeLLMConfiguration( + provider="google_realtime", + api_key="google-realtime-key", + model="gemini-3.1-flash-live-preview", + voice="Puck", + language="en", + ), + is_realtime=False, + ) + + with pytest.raises(ValueError) as exc_info: + await UserConfigurationValidator().validate(effective) + + assert exc_info.value.args[0] == [ + {"model": "stt", "message": "API key is missing"}, + {"model": "tts", "message": "API key is missing"}, + ] + + def test_masked_dograh_key_is_preserved_when_saving_same_mode(): existing = OrganizationAIModelConfigurationV2( mode="dograh", diff --git a/api/tests/test_cost_calculator.py b/api/tests/test_cost_calculator.py deleted file mode 100644 index 940ac582..00000000 --- a/api/tests/test_cost_calculator.py +++ /dev/null @@ -1,31 +0,0 @@ -from api.services.pricing.cost_calculator import cost_calculator - - -def test_cost_calculator(): - """Test function to verify cost calculation works""" - sample_usage = { - "llm": { - "OpenAILLMService#0|||gpt-4.1-mini": { - "prompt_tokens": 45380, - "completion_tokens": 496, - "total_tokens": 45876, - "cache_read_input_tokens": 0, - "cache_creation_input_tokens": 0, - } - }, - "tts": {"ElevenLabsTTSService#0|||eleven_flash_v2_5": 2399}, - "stt": {"DeepgramSTTService#0|||nova-3-general": 177.21536946296692}, - "call_duration_seconds": 179, - } - - result = cost_calculator.calculate_total_cost(sample_usage) - assert result["llm_cost"] == 45380 * 0.40 / 1_000_000 + 496 * 1.60 / 1_000_000 - assert result["tts_cost"] == 2399 * 0.0256 / 1_000 - assert result["stt_cost"] == 177.21536946296692 / 60 * 0.0077 - assert ( - abs( - result["total"] - - (result["llm_cost"] + result["tts_cost"] + result["stt_cost"]) - ) - < 1e-10 - ) diff --git a/api/tests/test_mps_service_key_client.py b/api/tests/test_mps_service_key_client.py index 7f42f13d..e44599d5 100644 --- a/api/tests/test_mps_service_key_client.py +++ b/api/tests/test_mps_service_key_client.py @@ -128,3 +128,152 @@ async def test_create_correlation_id_uses_bearer_auth(monkeypatch): }, ) ] + + +@pytest.mark.asyncio +async def test_get_billing_account_status_uses_hosted_org_auth(monkeypatch): + calls = [] + + class FakeAsyncClient: + def __init__(self, timeout): + self.timeout = timeout + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return None + + async def get(self, url, headers): + calls.append(("GET", url, headers)) + return _Response(200, {"organization_id": 42, "billing_mode": "v2"}) + + monkeypatch.setattr( + "api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient + ) + monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas") + monkeypatch.setattr( + "api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret" + ) + + client = MPSServiceKeyClient() + + assert await client.get_billing_account_status(organization_id=42) == { + "organization_id": 42, + "billing_mode": "v2", + } + assert calls == [ + ( + "GET", + f"{client.base_url}/api/v1/billing/accounts/42/status", + { + "Content-Type": "application/json", + "X-Secret-Key": "mps-secret", + "X-Organization-Id": "42", + }, + ) + ] + + +@pytest.mark.asyncio +async def test_report_platform_usage_uses_hosted_secret_auth(monkeypatch): + calls = [] + + class FakeAsyncClient: + def __init__(self, timeout): + self.timeout = timeout + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return None + + async def post(self, url, json, headers): + calls.append(("POST", url, json, headers)) + return _Response(200, {"metered": True}) + + monkeypatch.setattr( + "api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient + ) + monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas") + monkeypatch.setattr( + "api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret" + ) + + client = MPSServiceKeyClient() + + assert await client.report_platform_usage( + organization_id=42, + correlation_id="mps-corr-123", + workflow_run_id=123, + metadata={"source": "workflow_run_completion"}, + ) == {"metered": True} + assert calls == [ + ( + "POST", + f"{client.base_url}/api/v1/billing/accounts/42/platform-usage", + { + "correlation_id": "mps-corr-123", + "workflow_run_id": 123, + "metadata": {"source": "workflow_run_completion"}, + }, + { + "Content-Type": "application/json", + "X-Secret-Key": "mps-secret", + "X-Organization-Id": "42", + }, + ) + ] + + +@pytest.mark.asyncio +async def test_report_platform_usage_sends_duration_without_correlation(monkeypatch): + calls = [] + + class FakeAsyncClient: + def __init__(self, timeout): + self.timeout = timeout + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return None + + async def post(self, url, json, headers): + calls.append(("POST", url, json, headers)) + return _Response(200, {"metered": True}) + + monkeypatch.setattr( + "api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient + ) + monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas") + monkeypatch.setattr( + "api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret" + ) + + client = MPSServiceKeyClient() + + assert await client.report_platform_usage( + organization_id=42, + duration_seconds=87.0, + workflow_run_id=123, + metadata={"source": "workflow_run_completion"}, + ) == {"metered": True} + assert calls == [ + ( + "POST", + f"{client.base_url}/api/v1/billing/accounts/42/platform-usage", + { + "duration_seconds": 87.0, + "workflow_run_id": 123, + "metadata": {"source": "workflow_run_completion"}, + }, + { + "Content-Type": "application/json", + "X-Secret-Key": "mps-secret", + "X-Organization-Id": "42", + }, + ) + ] diff --git a/api/tests/test_organization_usage_billing.py b/api/tests/test_organization_usage_billing.py new file mode 100644 index 00000000..f1fb3819 --- /dev/null +++ b/api/tests/test_organization_usage_billing.py @@ -0,0 +1,33 @@ +from types import SimpleNamespace +from unittest.mock import AsyncMock + +import pytest + +from api.routes import organization_usage + + +def test_is_mps_billing_v2_depends_only_on_account_mode(): + assert organization_usage._is_mps_billing_v2({"billing_mode": "v2"}) is True + assert organization_usage._is_mps_billing_v2({"billing_mode": "v1"}) is False + assert organization_usage._is_mps_billing_v2({"billing_mode": "shadow"}) is False + assert organization_usage._is_mps_billing_v2(None) is False + + +@pytest.mark.asyncio +async def test_get_mps_billing_account_status_uses_user_provider_id(monkeypatch): + get_status = AsyncMock(return_value={"billing_mode": "v2"}) + monkeypatch.setattr( + organization_usage.mps_service_key_client, + "get_billing_account_status", + get_status, + ) + + user = SimpleNamespace(provider_id="provider-123") + + assert await organization_usage._get_mps_billing_account_status(user, 42) == { + "billing_mode": "v2" + } + get_status.assert_awaited_once_with( + organization_id=42, + created_by="provider-123", + ) diff --git a/api/tests/test_run_usage_response.py b/api/tests/test_run_usage_response.py index c17d4a9f..044c6563 100644 --- a/api/tests/test_run_usage_response.py +++ b/api/tests/test_run_usage_response.py @@ -1,4 +1,4 @@ -from api.services.pricing.run_usage_response import format_public_usage_info +from api.services.workflow.run_usage_response import format_public_usage_info def test_format_public_usage_info(): diff --git a/api/tests/test_workflow_run_billing.py b/api/tests/test_workflow_run_billing.py new file mode 100644 index 00000000..2837317f --- /dev/null +++ b/api/tests/test_workflow_run_billing.py @@ -0,0 +1,212 @@ +from types import SimpleNamespace +from unittest.mock import AsyncMock + +import pytest + +from api.services import workflow_run_billing as workflow_run_billing_mod +from api.services.workflow_run_billing import ( + report_completed_workflow_run_platform_usage, + report_workflow_run_platform_usage, +) + + +def _make_workflow_run(): + return SimpleNamespace( + id=123, + workflow_id=456, + is_completed=True, + initial_context={"mps_correlation_id": "mps-corr-123"}, + usage_info={"call_duration_seconds": 87}, + workflow=SimpleNamespace( + organization_id=42, + user=SimpleNamespace(selected_organization_id=42), + ), + ) + + +@pytest.mark.asyncio +async def test_report_workflow_run_platform_usage_reports_hosted_completion( + monkeypatch, +): + workflow_run = _make_workflow_run() + get_status = AsyncMock(return_value={"billing_mode": "v2"}) + report_usage = AsyncMock(return_value={"metered": True}) + + monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas") + monkeypatch.setattr( + workflow_run_billing_mod.mps_service_key_client, + "get_billing_account_status", + get_status, + ) + monkeypatch.setattr( + workflow_run_billing_mod.mps_service_key_client, + "report_platform_usage", + report_usage, + ) + + await report_workflow_run_platform_usage(workflow_run) + + report_usage.assert_awaited_once_with( + organization_id=42, + correlation_id="mps-corr-123", + duration_seconds=None, + workflow_run_id=workflow_run.id, + metadata={ + "source": "workflow_run_completion", + "workflow_id": workflow_run.workflow_id, + "duration_source": "mps_correlation", + }, + ) + + +@pytest.mark.asyncio +async def test_report_workflow_run_platform_usage_reports_duration_without_correlation( + monkeypatch, +): + workflow_run = _make_workflow_run() + workflow_run.initial_context = {} + get_status = AsyncMock(return_value={"billing_mode": "v2"}) + report_usage = AsyncMock(return_value={"metered": True}) + + monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas") + monkeypatch.setattr( + workflow_run_billing_mod.mps_service_key_client, + "get_billing_account_status", + get_status, + ) + monkeypatch.setattr( + workflow_run_billing_mod.mps_service_key_client, + "report_platform_usage", + report_usage, + ) + + await report_workflow_run_platform_usage(workflow_run) + + report_usage.assert_awaited_once_with( + organization_id=42, + correlation_id=None, + duration_seconds=87.0, + workflow_run_id=workflow_run.id, + metadata={ + "source": "workflow_run_completion", + "workflow_id": workflow_run.workflow_id, + "duration_source": "dograh_usage_info", + }, + ) + + +@pytest.mark.asyncio +async def test_report_workflow_run_platform_usage_skips_non_v2_account(monkeypatch): + workflow_run = _make_workflow_run() + get_status = AsyncMock(return_value={"billing_mode": "v1"}) + report_usage = AsyncMock() + + monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas") + monkeypatch.setattr( + workflow_run_billing_mod.mps_service_key_client, + "get_billing_account_status", + get_status, + ) + monkeypatch.setattr( + workflow_run_billing_mod.mps_service_key_client, + "report_platform_usage", + report_usage, + ) + + await report_workflow_run_platform_usage(workflow_run) + + get_status.assert_awaited_once_with(organization_id=42) + report_usage.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_report_workflow_run_platform_usage_skips_missing_duration_without_correlation( + monkeypatch, +): + workflow_run = _make_workflow_run() + workflow_run.initial_context = {} + workflow_run.usage_info = {} + get_status = AsyncMock(return_value={"billing_mode": "v2"}) + report_usage = AsyncMock() + + monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas") + monkeypatch.setattr( + workflow_run_billing_mod.mps_service_key_client, + "get_billing_account_status", + get_status, + ) + monkeypatch.setattr( + workflow_run_billing_mod.mps_service_key_client, + "report_platform_usage", + report_usage, + ) + + await report_workflow_run_platform_usage(workflow_run) + + get_status.assert_not_awaited() + report_usage.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_report_workflow_run_platform_usage_skips_oss(monkeypatch): + workflow_run = _make_workflow_run() + report_usage = AsyncMock() + + monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "oss") + monkeypatch.setattr( + workflow_run_billing_mod.mps_service_key_client, + "report_platform_usage", + report_usage, + ) + + await report_workflow_run_platform_usage(workflow_run) + + report_usage.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_report_workflow_run_platform_usage_skips_incomplete(monkeypatch): + workflow_run = _make_workflow_run() + workflow_run.is_completed = False + report_usage = AsyncMock() + + monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas") + monkeypatch.setattr( + workflow_run_billing_mod.mps_service_key_client, + "report_platform_usage", + report_usage, + ) + + await report_workflow_run_platform_usage(workflow_run) + + report_usage.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_report_completed_workflow_run_platform_usage_loads_run(monkeypatch): + workflow_run = _make_workflow_run() + get_run = AsyncMock(return_value=workflow_run) + get_status = AsyncMock(return_value={"billing_mode": "v2"}) + report_usage = AsyncMock(return_value={"metered": True}) + + monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas") + monkeypatch.setattr( + workflow_run_billing_mod.db_client, + "get_workflow_run_by_id", + get_run, + ) + monkeypatch.setattr( + workflow_run_billing_mod.mps_service_key_client, + "get_billing_account_status", + get_status, + ) + monkeypatch.setattr( + workflow_run_billing_mod.mps_service_key_client, + "report_platform_usage", + report_usage, + ) + + await report_completed_workflow_run_platform_usage(workflow_run.id) + + get_run.assert_awaited_once_with(workflow_run.id) + report_usage.assert_awaited_once() diff --git a/api/tests/test_workflow_run_cost.py b/api/tests/test_workflow_run_cost.py deleted file mode 100644 index c77424c8..00000000 --- a/api/tests/test_workflow_run_cost.py +++ /dev/null @@ -1,181 +0,0 @@ -from datetime import UTC, datetime -from types import SimpleNamespace -from unittest.mock import AsyncMock - -import pytest - -from api.services.pricing import workflow_run_cost as workflow_run_cost_mod -from api.services.pricing.workflow_run_cost import ( - apply_usage_delta_to_organization, - build_workflow_run_cost_info, - calculate_workflow_run_cost, -) - - -def _make_workflow_run(): - return SimpleNamespace( - id=123, - workflow_id=456, - mode="textchat", - created_at=datetime.now(UTC), - usage_info={ - "llm": {}, - "tts": {}, - "stt": {}, - "call_duration_seconds": 7, - }, - cost_info={}, - workflow=SimpleNamespace( - organization_id=42, - user=SimpleNamespace(selected_organization_id=42), - ), - ) - - -@pytest.mark.asyncio -async def test_build_workflow_run_cost_info_does_not_update_org_usage(monkeypatch): - workflow_run = _make_workflow_run() - get_org = AsyncMock(return_value=SimpleNamespace(id=42, price_per_second_usd=1.5)) - update_usage = AsyncMock() - - monkeypatch.setattr( - workflow_run_cost_mod.db_client, "get_organization_by_id", get_org - ) - monkeypatch.setattr( - workflow_run_cost_mod.db_client, "update_usage_after_run", update_usage - ) - - cost_info = await build_workflow_run_cost_info(workflow_run) - - assert cost_info is not None - assert cost_info["call_duration_seconds"] == 7 - assert "cost_breakdown" in cost_info - assert "dograh_token_usage" in cost_info - assert cost_info["charge_usd"] == 10.5 - update_usage.assert_not_called() - - -@pytest.mark.asyncio -async def test_calculate_workflow_run_cost_keeps_org_usage_side_effect_in_wrapper( - monkeypatch, -): - workflow_run = _make_workflow_run() - get_org = AsyncMock(return_value=SimpleNamespace(id=42, price_per_second_usd=None)) - update_run = AsyncMock() - update_usage = AsyncMock() - - monkeypatch.setattr( - workflow_run_cost_mod.db_client, - "get_workflow_run_by_id", - AsyncMock(return_value=workflow_run), - ) - monkeypatch.setattr( - workflow_run_cost_mod.db_client, "get_organization_by_id", get_org - ) - monkeypatch.setattr( - workflow_run_cost_mod.db_client, "update_workflow_run", update_run - ) - monkeypatch.setattr( - workflow_run_cost_mod.db_client, "update_usage_after_run", update_usage - ) - - await calculate_workflow_run_cost(workflow_run.id) - - update_run.assert_awaited_once() - saved_kwargs = update_run.await_args.kwargs - assert saved_kwargs["run_id"] == workflow_run.id - assert "cost_breakdown" in saved_kwargs["cost_info"] - update_usage.assert_awaited_once() - - -@pytest.mark.asyncio -async def test_apply_usage_delta_to_organization_uses_incremental_costs( - monkeypatch, -): - workflow_run = _make_workflow_run() - workflow_run.cost_info = {"call_id": "preserve-me"} - - usage_delta_one = { - "llm": { - "OpenAILLMService#0|||gpt-4.1-mini": { - "prompt_tokens": 1_000, - "completion_tokens": 100, - "total_tokens": 1_100, - "cache_read_input_tokens": 0, - "cache_creation_input_tokens": 0, - } - }, - "tts": {}, - "stt": {}, - "call_duration_seconds": 3, - } - usage_delta_two = { - "llm": { - "OpenAILLMService#0|||gpt-4.1-mini": { - "prompt_tokens": 2_000, - "completion_tokens": 50, - "total_tokens": 2_050, - "cache_read_input_tokens": 0, - "cache_creation_input_tokens": 0, - } - }, - "tts": {}, - "stt": {}, - "call_duration_seconds": 4, - } - merged_usage = { - "llm": { - "OpenAILLMService#0|||gpt-4.1-mini": { - "prompt_tokens": 3_000, - "completion_tokens": 150, - "total_tokens": 3_150, - "cache_read_input_tokens": 0, - "cache_creation_input_tokens": 0, - } - }, - "tts": {}, - "stt": {}, - "call_duration_seconds": 7, - } - - get_org = AsyncMock(return_value=SimpleNamespace(id=42, price_per_second_usd=1.5)) - update_usage = AsyncMock() - - monkeypatch.setattr( - workflow_run_cost_mod.db_client, "get_organization_by_id", get_org - ) - monkeypatch.setattr( - workflow_run_cost_mod.db_client, "update_usage_after_run", update_usage - ) - - first_delta = await apply_usage_delta_to_organization(workflow_run, usage_delta_one) - second_delta = await apply_usage_delta_to_organization( - workflow_run, usage_delta_two - ) - total_workflow_run = SimpleNamespace(**workflow_run.__dict__) - total_workflow_run.usage_info = merged_usage - total_cost = await build_workflow_run_cost_info(total_workflow_run) - - assert first_delta is not None - assert second_delta is not None - assert total_cost is not None - assert update_usage.await_count == 2 - assert update_usage.await_args_list[0].args == ( - 42, - first_delta["dograh_token_usage"], - 3.0, - first_delta["charge_usd"], - ) - assert update_usage.await_args_list[1].args == ( - 42, - second_delta["dograh_token_usage"], - 4.0, - second_delta["charge_usd"], - ) - assert ( - first_delta["dograh_token_usage"] + second_delta["dograh_token_usage"] - ) == pytest.approx(total_cost["dograh_token_usage"]) - assert ( - first_delta["charge_usd"] + second_delta["charge_usd"] - == total_cost["charge_usd"] - ) diff --git a/api/tests/test_workflow_text_chat.py b/api/tests/test_workflow_text_chat.py index b3fb0d86..3be8a613 100644 --- a/api/tests/test_workflow_text_chat.py +++ b/api/tests/test_workflow_text_chat.py @@ -176,11 +176,7 @@ async def test_text_chat_session_creation_executes_initial_assistant_turn( assert "Start" in (created["gathered_context"] or {}).get("nodes_visited", []) workflow_run = await db_session.get_workflow_run_by_id(created["workflow_run_id"]) assert workflow_run is not None - assert workflow_run.cost_info[ - "call_duration_seconds" - ] == workflow_run.usage_info.get("call_duration_seconds", 0) - assert "cost_breakdown" in workflow_run.cost_info - assert "dograh_token_usage" in workflow_run.cost_info + assert "call_duration_seconds" in workflow_run.usage_info assert _log_texts(run_payload["logs"], "rtf-bot-text") == [ "Hello from the workflow tester." ] @@ -296,11 +292,7 @@ async def test_text_chat_message_executes_assistant_turn( assert "Start" in (payload["gathered_context"] or {}).get("nodes_visited", []) workflow_run = await db_session.get_workflow_run_by_id(created["workflow_run_id"]) assert workflow_run is not None - assert workflow_run.cost_info[ - "call_duration_seconds" - ] == workflow_run.usage_info.get("call_duration_seconds", 0) - assert "cost_breakdown" in workflow_run.cost_info - assert "dograh_token_usage" in workflow_run.cost_info + assert "call_duration_seconds" in workflow_run.usage_info assert _log_texts(run_payload["logs"], "rtf-user-transcription") == ["Hi there"] assert _log_texts(run_payload["logs"], "rtf-bot-text") == [ "Welcome to the workflow tester.", diff --git a/ui/src/app/billing/page.tsx b/ui/src/app/billing/page.tsx index 5f1b4572..03f14d98 100644 --- a/ui/src/app/billing/page.tsx +++ b/ui/src/app/billing/page.tsx @@ -1,11 +1,13 @@ "use client"; import { CircleDollarSign, CreditCard, RefreshCw } from "lucide-react"; +import Link from "next/link"; import { useCallback, useEffect, useMemo, useState } from "react"; import { toast } from "sonner"; import { createMpsCreditPurchaseUrlApiV1OrganizationsUsageMpsCreditsPurchaseUrlPost, getBillingCreditsApiV1OrganizationsBillingCreditsGet } from "@/client/sdk.gen"; -import type { MpsBillingCreditsResponse } from "@/client/types.gen"; +import type { MpsBillingCreditsResponse, MpsCreditLedgerEntryResponse } from "@/client/types.gen"; +import { Badge } from "@/components/ui/badge"; import { Button } from "@/components/ui/button"; import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card"; import { Progress } from "@/components/ui/progress"; @@ -19,7 +21,6 @@ import { TableRow, } from "@/components/ui/table"; import { useAppConfig } from "@/context/AppConfigContext"; -import { useOrgConfig } from "@/context/OrgConfigContext"; import { useAuth } from "@/lib/auth"; const formatCredits = (value: number | null | undefined) => ( @@ -50,18 +51,58 @@ const formatDate = (value: string) => ( }) ); +const metricLabels: Record = { + voice_minutes: "Voice usage", + platform_usage: "Platform usage", +}; + +const formatTitleCase = (value: string | null | undefined) => ( + value ? value.replaceAll("_", " ").replace(/\b\w/g, (letter) => letter.toUpperCase()) : "-" +); + +const getLedgerEntryLabel = (entry: MpsCreditLedgerEntryResponse) => { + if (entry.metric_code) { + return metricLabels[entry.metric_code] ?? formatTitleCase(entry.metric_code); + } + + if (entry.entry_type === "grant") { + return "Credit grant"; + } + + if (entry.entry_type === "purchase") { + return "Credit purchase"; + } + + return formatTitleCase(entry.entry_type); +}; + +const formatBillableQuantity = (entry: MpsCreditLedgerEntryResponse) => { + if (entry.billable_quantity == null || !entry.quantity_unit) { + return null; + } + + const unit = entry.quantity_unit === "minute" ? "min" : entry.quantity_unit; + return `${formatCredits(entry.billable_quantity)} ${unit}`; +}; + +const getRunHref = (entry: MpsCreditLedgerEntryResponse) => { + if (!entry.workflow_id || !entry.workflow_run_id) { + return null; + } + + return `/workflow/${entry.workflow_id}/run/${entry.workflow_run_id}`; +}; + export default function BillingPage() { const auth = useAuth(); const { config } = useAppConfig(); - const { orgContext, loading: orgConfigLoading } = useOrgConfig(); const [credits, setCredits] = useState(null); const [loading, setLoading] = useState(true); const [refreshing, setRefreshing] = useState(false); const [purchasing, setPurchasing] = useState(false); - const isManagedServiceV2 = Boolean(orgContext?.model_services.uses_managed_service_v2); - const isBillingV2 = isManagedServiceV2 && credits?.billing_version === "v2"; - const canPurchaseCredits = isManagedServiceV2 && config?.deploymentMode !== "oss"; + const isBillingV2 = credits?.billing_version === "v2"; + const canPurchaseCredits = isBillingV2 && config?.deploymentMode !== "oss"; const totalQuota = credits?.total_quota ?? 0; const remainingCredits = credits?.remaining_credits ?? 0; const usedCredits = credits?.total_credits_used ?? 0; @@ -132,7 +173,7 @@ export default function BillingPage() { } }; - if (loading || orgConfigLoading) { + if (loading) { return (
@@ -206,13 +247,14 @@ export default function BillingPage() { {ledgerEntries.length > 0 ? ( -
+
Date - Type + Activity Origin + Run Delta Balance Amount @@ -221,11 +263,39 @@ export default function BillingPage() { {ledgerEntries.map((entry) => { const delta = entry.credits_delta ?? 0; + const runHref = getRunHref(entry); + const billableQuantity = formatBillableQuantity(entry); return ( {formatDate(entry.created_at)} - {entry.entry_type.replaceAll("_", " ")} - {entry.origin || "-"} + +
+ {getLedgerEntryLabel(entry)} + {billableQuantity && ( + {billableQuantity} + )} +
+
+ + {entry.origin ? ( + {formatTitleCase(entry.origin)} + ) : ( + "-" + )} + + + {entry.workflow_run_id ? ( + runHref ? ( + + #{entry.workflow_run_id} + + ) : ( + #{entry.workflow_run_id} + ) + ) : ( + "-" + )} + = 0 ? "text-green-600" : "text-destructive"}`}> {delta >= 0 ? "+" : ""} {formatCredits(delta)} @@ -251,7 +321,6 @@ export default function BillingPage() { Credit Usage - Current legacy MPS credit allocation. diff --git a/ui/src/client/types.gen.ts b/ui/src/client/types.gen.ts index a0dc5b72..b20a0715 100644 --- a/ui/src/client/types.gen.ts +++ b/ui/src/client/types.gen.ts @@ -3176,6 +3176,38 @@ export type MpsCreditLedgerEntryResponse = { * Payment Order Id */ payment_order_id?: number | null; + /** + * Metric Code + */ + metric_code?: string | null; + /** + * Correlation Id + */ + correlation_id?: string | null; + /** + * Aggregation Key + */ + aggregation_key?: string | null; + /** + * Usage Event Id + */ + usage_event_id?: number | null; + /** + * Workflow Run Id + */ + workflow_run_id?: number | null; + /** + * Workflow Id + */ + workflow_id?: number | null; + /** + * Billable Quantity + */ + billable_quantity?: number | null; + /** + * Quantity Unit + */ + quantity_unit?: string | null; /** * Metadata */ diff --git a/ui/src/lib/apiError.ts b/ui/src/lib/apiError.ts index aea5049b..6d4338c5 100644 --- a/ui/src/lib/apiError.ts +++ b/ui/src/lib/apiError.ts @@ -2,18 +2,33 @@ * Extract a human-readable message from a backend error response. * * The generated API client returns `{ error }` on failure (it does not throw), - * and FastAPI shapes that error as either `{ detail: string }` (HTTPException) - * or `{ detail: [{ msg, loc, ... }] }` (422 validation). This normalizes both - * to a single string so it can be rendered or thrown directly — never pass the - * raw `detail` to React, as the 422 array crashes rendering. + * and FastAPI shapes that error as `{ detail: string }`, `{ detail: + * [{ msg, loc, ... }] }`, or backend validation arrays like `{ detail: + * [{ model, message }] }`. This normalizes those to a single string so it can + * be rendered or thrown directly. */ export function detailFromError(err: unknown, fallback = "Request failed"): string { if (typeof err === "string") return err; const e = err as { detail?: unknown }; if (typeof e?.detail === "string") return e.detail; if (Array.isArray(e?.detail) && e.detail.length > 0) { - const first = e.detail[0] as { msg?: string }; - if (first?.msg) return first.msg; + const messages = e.detail + .map((item) => { + if (typeof item === "string") return item; + if (!item || typeof item !== "object") return null; + const detail = item as { message?: unknown; msg?: unknown; model?: unknown }; + const message = typeof detail.message === "string" + ? detail.message + : typeof detail.msg === "string" + ? detail.msg + : null; + if (!message) return null; + return typeof detail.model === "string" && detail.model + ? `${detail.model}: ${message}` + : message; + }) + .filter((message): message is string => Boolean(message)); + if (messages.length > 0) return messages.join("\n"); } return fallback; }