"""Workflow-run billing hooks. Dograh does not rate or deduct credits locally. MPS owns credit accounting. For hosted deployments, Dograh reports completed platform usage to MPS. When a server-minted MPS correlation id exists, MPS uses model-service usage as the canonical duration. Otherwise Dograh reports the completed run duration. """ from typing import Any from loguru import logger from api.constants import DEPLOYMENT_MODE from api.db import db_client from api.services.managed_model_services import get_mps_correlation_id from api.services.mps_service_key_client import mps_service_key_client def _workflow_run_organization_id(workflow_run) -> int | None: workflow = getattr(workflow_run, "workflow", None) return getattr(workflow, "organization_id", None) def _duration_seconds_from_usage_info(workflow_run) -> float | None: usage_info: dict[str, Any] = getattr(workflow_run, "usage_info", None) or {} duration = usage_info.get("call_duration_seconds") try: duration_seconds = float(duration) except (TypeError, ValueError): return None return duration_seconds if duration_seconds > 0 else None async def _organization_uses_mps_billing_v2(organization_id: int) -> bool: account = await mps_service_key_client.get_billing_account_status( organization_id=organization_id ) return bool(account and account.get("billing_mode") == "v2") async def report_workflow_run_platform_usage(workflow_run) -> None: """Report hosted platform usage for a completed workflow run to MPS.""" if DEPLOYMENT_MODE == "oss": return if not getattr(workflow_run, "is_completed", False): return organization_id = _workflow_run_organization_id(workflow_run) if organization_id is None: logger.warning( "Skipping platform usage report for workflow run {}: no organization_id", workflow_run.id, ) return correlation_id = get_mps_correlation_id( getattr(workflow_run, "initial_context", None) ) duration_seconds = ( None if correlation_id else _duration_seconds_from_usage_info(workflow_run) ) if not correlation_id and duration_seconds is None: logger.warning( "Skipping platform usage report for workflow run {}: no billable duration", workflow_run.id, ) return try: if not await _organization_uses_mps_billing_v2(organization_id): return result = await mps_service_key_client.report_platform_usage( organization_id=organization_id, correlation_id=correlation_id, duration_seconds=duration_seconds, workflow_run_id=workflow_run.id, metadata={ "source": "workflow_run_completion", "workflow_id": getattr(workflow_run, "workflow_id", None), "duration_source": ( "mps_correlation" if correlation_id else "dograh_usage_info" ), }, ) logger.info( "Reported platform usage for workflow run {} to MPS: {}", workflow_run.id, result, ) except Exception as e: logger.error( "Failed to report platform usage for workflow run {}: {}", workflow_run.id, e, ) async def report_completed_workflow_run_platform_usage(workflow_run_id: int) -> None: """Load a completed workflow run and report platform usage to MPS.""" workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id) if not workflow_run: logger.warning( "Skipping platform usage report: workflow run {} not found", workflow_run_id, ) return await report_workflow_run_platform_usage(workflow_run)