diff --git a/api/alembic/versions/2159d4ac431a_added_quota_tables.py b/api/alembic/versions/2159d4ac431a_added_quota_tables.py index 51efc4cc..24326e4b 100644 --- a/api/alembic/versions/2159d4ac431a_added_quota_tables.py +++ b/api/alembic/versions/2159d4ac431a_added_quota_tables.py @@ -18,6 +18,9 @@ branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None +DEPRECATED_QUOTA_COMMENT = "Deprecated. MPS owns quota and credit ledger state." + + def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### # 1) Create the `quota_type` enum *before* we add the column that references it. @@ -34,7 +37,12 @@ def upgrade() -> None: sa.Column("organization_id", sa.Integer(), nullable=False), sa.Column("period_start", sa.DateTime(), nullable=False), sa.Column("period_end", sa.DateTime(), nullable=False), - sa.Column("quota_dograh_tokens", sa.Integer(), nullable=False), + sa.Column( + "quota_dograh_tokens", + sa.Integer(), + nullable=False, + comment=DEPRECATED_QUOTA_COMMENT, + ), sa.Column("used_dograh_tokens", sa.Integer(), nullable=False), sa.Column("created_at", sa.DateTime(timezone=True), nullable=True), sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True), @@ -63,7 +71,11 @@ def upgrade() -> None: op.add_column( "organizations", sa.Column( - "quota_type", quota_type_enum, nullable=False, server_default="monthly" + "quota_type", + quota_type_enum, + nullable=False, + server_default="monthly", + comment=DEPRECATED_QUOTA_COMMENT, ), ) op.add_column( @@ -73,6 +85,7 @@ def upgrade() -> None: sa.Integer(), nullable=False, server_default=sa.text("0"), + comment=DEPRECATED_QUOTA_COMMENT, ), ) op.add_column( @@ -82,10 +95,17 @@ def upgrade() -> None: sa.Integer(), nullable=False, server_default=sa.text("LEAST(EXTRACT(DAY FROM CURRENT_DATE)::int, 28)"), + comment=DEPRECATED_QUOTA_COMMENT, ), ) op.add_column( - "organizations", sa.Column("quota_start_date", sa.DateTime(), nullable=True) + "organizations", + sa.Column( + "quota_start_date", + sa.DateTime(), + nullable=True, + comment=DEPRECATED_QUOTA_COMMENT, + ), ) op.add_column( "organizations", @@ -94,6 +114,7 @@ def upgrade() -> None: sa.Boolean(), nullable=False, server_default=sa.text("false"), + comment=DEPRECATED_QUOTA_COMMENT, ), ) # ### end Alembic commands ### diff --git a/api/alembic/versions/c425d3445750_add_columns_in_usage_table.py b/api/alembic/versions/c425d3445750_add_columns_in_usage_table.py index 998e7123..cbd9c654 100644 --- a/api/alembic/versions/c425d3445750_add_columns_in_usage_table.py +++ b/api/alembic/versions/c425d3445750_add_columns_in_usage_table.py @@ -18,6 +18,9 @@ branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None +DEPRECATED_QUOTA_COMMENT = "Deprecated. MPS owns quota and credit ledger state." + + def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### op.add_column( @@ -26,7 +29,12 @@ def upgrade() -> None: ) op.add_column( "organization_usage_cycles", - sa.Column("quota_amount_usd", sa.Float(), nullable=True), + sa.Column( + "quota_amount_usd", + sa.Float(), + nullable=True, + comment=DEPRECATED_QUOTA_COMMENT, + ), ) # ### end Alembic commands ### diff --git a/api/db/campaign_client.py b/api/db/campaign_client.py index d9ff2bae..7729a2ea 100644 --- a/api/db/campaign_client.py +++ b/api/db/campaign_client.py @@ -9,6 +9,7 @@ from api.db.base_client import BaseDBClient from api.db.filters import apply_workflow_run_filters, get_workflow_run_order_clause from api.db.models import CampaignModel, QueuedRunModel, WorkflowRunModel from api.schemas.workflow import WorkflowRunResponseSchema +from api.services.workflow.run_usage_response import format_public_cost_info class CampaignClient(BaseDBClient): @@ -215,26 +216,9 @@ class CampaignClient(BaseDBClient): "is_completed": run.is_completed, "recording_url": run.recording_url, "transcript_url": run.transcript_url, - "cost_info": { - "dograh_token_usage": ( - run.cost_info.get("dograh_token_usage") - if run.cost_info - and "dograh_token_usage" in run.cost_info - else round( - float(run.cost_info.get("total_cost_usd", 0)) * 100, - 2, - ) - if run.cost_info and "total_cost_usd" in run.cost_info - else 0 - ), - "call_duration_seconds": int( - round(run.cost_info.get("call_duration_seconds") or 0) - ) - if run.cost_info - else None, - } - if run.cost_info - else None, + "cost_info": format_public_cost_info( + run.cost_info, run.usage_info + ), "definition_id": run.definition_id, "initial_context": run.initial_context, "gathered_context": run.gathered_context, @@ -662,7 +646,7 @@ class CampaignClient(BaseDBClient): async with self.async_session() as session: conditions = [ WorkflowRunModel.is_completed.is_(True), - WorkflowRunModel.cost_info["call_duration_seconds"] + WorkflowRunModel.usage_info["call_duration_seconds"] .as_string() .isnot(None), ] @@ -685,6 +669,7 @@ class CampaignClient(BaseDBClient): WorkflowRunModel.initial_context, WorkflowRunModel.gathered_context, WorkflowRunModel.cost_info, + WorkflowRunModel.usage_info, WorkflowRunModel.public_access_token, ) .where(*conditions) diff --git a/api/db/db_client.py b/api/db/db_client.py index de98cf19..15d1c108 100644 --- a/api/db/db_client.py +++ b/api/db/db_client.py @@ -53,7 +53,7 @@ class DBClient( - UserClient: handles user and user configuration operations - OrganizationClient: handles organization operations - OrganizationConfigurationClient: handles organization configuration operations - - OrganizationUsageClient: handles organization usage and quota operations + - OrganizationUsageClient: handles organization usage reporting aggregates - IntegrationClient: handles integration operations - WorkflowTemplateClient: handles workflow template operations - CampaignClient: handles campaign operations diff --git a/api/db/filters.py b/api/db/filters.py index e960d724..cd30b144 100644 --- a/api/db/filters.py +++ b/api/db/filters.py @@ -25,7 +25,7 @@ def get_workflow_run_order_clause( """ # Determine sort column if sort_by == "duration": - sort_column = WorkflowRunModel.cost_info.op("->>")( + sort_column = WorkflowRunModel.usage_info.op("->>")( "call_duration_seconds" ).cast(Float) else: @@ -43,7 +43,7 @@ def get_workflow_run_order_clause( ATTRIBUTE_FIELD_MAPPING = { "dateRange": "created_at", "dispositionCode": "gathered_context.mapped_call_disposition", - "duration": "cost_info.call_duration_seconds", + "duration": "usage_info.call_duration_seconds", "status": "is_completed", "tokenUsage": "cost_info.total_cost_usd", "runId": "id", @@ -208,7 +208,7 @@ def apply_workflow_run_filters( min_val = value.get("min") max_val = value.get("max") - if field == "cost_info.call_duration_seconds": + if field == "usage_info.call_duration_seconds": # Use ->> operator for compatibility with all PostgreSQL versions # (subscript [] only works in PostgreSQL 14+) duration_text = cast(WorkflowRunModel.usage_info, JSONB).op("->>")( diff --git a/api/db/models.py b/api/db/models.py index c61cb03d..696cb6e6 100644 --- a/api/db/models.py +++ b/api/db/models.py @@ -97,22 +97,44 @@ class OrganizationModel(Base): provider_id = Column(String, unique=True, index=True, nullable=False) created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC)) - # Quota fields + # Deprecated: MPS owns quota and credit ledger state. quota_type = Column( Enum("monthly", "annual", name="quota_type"), nullable=False, default="monthly", server_default=text("'monthly'::quota_type"), + comment="Deprecated. MPS owns quota and credit ledger state.", + info={"deprecated": True}, ) quota_dograh_tokens = Column( - Integer, nullable=False, default=0, server_default=text("0") + Integer, + nullable=False, + default=0, + server_default=text("0"), + comment="Deprecated. MPS owns quota and credit ledger state.", + info={"deprecated": True}, ) quota_reset_day = Column( - Integer, nullable=False, default=1, server_default=text("1") - ) # 1-28, only for monthly - quota_start_date = Column(DateTime(timezone=True), nullable=True) # Only for annual + Integer, + nullable=False, + default=1, + server_default=text("1"), + comment="Deprecated. MPS owns quota and credit ledger state.", + info={"deprecated": True}, + ) + quota_start_date = Column( + DateTime(timezone=True), + nullable=True, + comment="Deprecated. MPS owns quota and credit ledger state.", + info={"deprecated": True}, + ) quota_enabled = Column( - Boolean, nullable=False, default=False, server_default=text("false") + Boolean, + nullable=False, + default=False, + server_default=text("false"), + comment="Deprecated. MPS owns quota and credit ledger state.", + info={"deprecated": True}, ) price_per_second_usd = Column(Float, nullable=True) @@ -593,8 +615,9 @@ class WorkflowRunTextSessionModel(Base): class OrganizationUsageCycleModel(Base): """ - This model is used to track the usage of Dograh tokens for an organization for a given usage - cycle. + This model is used to track reporting aggregates for an organization for a given + usage cycle. Quota fields on this model are deprecated; MPS owns quota and + credit ledger state. """ __tablename__ = "organization_usage_cycles" @@ -603,14 +626,24 @@ class OrganizationUsageCycleModel(Base): organization_id = Column(Integer, ForeignKey("organizations.id"), nullable=False) period_start = Column(DateTime(timezone=True), nullable=False) period_end = Column(DateTime(timezone=True), nullable=False) - quota_dograh_tokens = Column(Integer, nullable=False) + quota_dograh_tokens = Column( + Integer, + nullable=False, + comment="Deprecated. MPS owns quota and credit ledger state.", + info={"deprecated": True}, + ) used_dograh_tokens = Column(Float, nullable=False, default=0) total_duration_seconds = Column( Integer, nullable=False, default=0, server_default=text("0") ) # New USD tracking fields used_amount_usd = Column(Float, nullable=True, default=0) - quota_amount_usd = Column(Float, nullable=True) + quota_amount_usd = Column( + Float, + nullable=True, + comment="Deprecated. MPS owns quota and credit ledger state.", + info={"deprecated": True}, + ) created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC)) updated_at = Column( DateTime(timezone=True), diff --git a/api/db/organization_usage_client.py b/api/db/organization_usage_client.py index f845fc75..dfca0538 100644 --- a/api/db/organization_usage_client.py +++ b/api/db/organization_usage_client.py @@ -19,11 +19,11 @@ from api.db.models import ( WorkflowRunModel, ) from api.enums import OrganizationConfigurationKey -from api.schemas.user_configuration import EffectiveAIModelConfiguration +from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration class OrganizationUsageClient(BaseDBClient): - """Client for managing organization usage and quota operations.""" + """Client for managing organization usage reporting aggregates.""" async def get_or_create_current_cycle( self, organization_id: int, session=None @@ -49,14 +49,7 @@ class OrganizationUsageClient(BaseDBClient): self, organization_id: int, session, commit: bool ) -> OrganizationUsageCycleModel: """Internal implementation for get_or_create_current_cycle.""" - # Get organization to determine quota type - org_result = await session.execute( - select(OrganizationModel).where(OrganizationModel.id == organization_id) - ) - org = org_result.scalar_one() - - # Calculate current period - period_start, period_end = self._calculate_current_period(org) + period_start, period_end = self._calculate_current_period() # Try to get existing cycle cycle_result = await session.execute( @@ -78,7 +71,8 @@ class OrganizationUsageClient(BaseDBClient): organization_id=organization_id, period_start=period_start, period_end=period_end, - quota_dograh_tokens=org.quota_dograh_tokens, + # Deprecated non-null column retained for historical schema compatibility. + quota_dograh_tokens=0, ) # Handle concurrent inserts gracefully stmt = stmt.on_conflict_do_nothing( @@ -102,95 +96,9 @@ class OrganizationUsageClient(BaseDBClient): ) return cycle_result.scalar_one() - async def check_and_reserve_quota( - self, organization_id: int, estimated_tokens: int = 0 - ) -> bool: - """ - Check if organization has sufficient quota and optionally reserve tokens. - Returns True if quota is available, False otherwise. - - This method is fully atomic and safe for concurrent access from multiple processes. - """ - async with self.async_session() as session: - # Get organization - org_result = await session.execute( - select(OrganizationModel).where(OrganizationModel.id == organization_id) - ) - org = org_result.scalar_one_or_none() - - if not org or not org.quota_enabled: - # No quota enforcement if not enabled - return True - - # Get or create current cycle within the same session/transaction - cycle = await self._get_or_create_current_cycle_impl( - organization_id, session, commit=False - ) - - # Atomic check and update with row-level lock - result = await session.execute( - select(OrganizationUsageCycleModel) - .where( - and_( - OrganizationUsageCycleModel.id == cycle.id, - OrganizationUsageCycleModel.used_dograh_tokens - + estimated_tokens - <= OrganizationUsageCycleModel.quota_dograh_tokens, - ) - ) - .with_for_update(skip_locked=False) - ) - - cycle_locked = result.scalar_one_or_none() - if cycle_locked: - # Update the usage atomically - cycle_locked.used_dograh_tokens += estimated_tokens - await session.commit() - return True - - return False - - async def update_usage_after_run( - self, - organization_id: int, - actual_tokens: float, - duration_seconds: float = 0, - charge_usd: float | None = None, - ) -> None: - """Update usage after a workflow run completes with actual token count and duration. - - This method is fully atomic and safe for concurrent access from multiple processes. - """ - async with self.async_session() as session: - # Get or create current cycle within the same session/transaction - cycle = await self._get_or_create_current_cycle_impl( - organization_id, session, commit=False - ) - - # Acquire a row-level lock for atomic update - result = await session.execute( - select(OrganizationUsageCycleModel) - .where(OrganizationUsageCycleModel.id == cycle.id) - .with_for_update(skip_locked=False) - ) - cycle_locked = result.scalar_one() - - # Update usage atomically - cycle_locked.used_dograh_tokens += actual_tokens - cycle_locked.total_duration_seconds += int(round(duration_seconds)) - - # Update USD amount if provided - if charge_usd is not None: - if cycle_locked.used_amount_usd is None: - cycle_locked.used_amount_usd = 0 - cycle_locked.used_amount_usd += charge_usd - - await session.commit() - async def get_current_usage(self, organization_id: int) -> dict: - """Get current period usage information.""" + """Get current reporting-period usage information.""" async with self.async_session() as session: - # Get organization org_result = await session.execute( select(OrganizationModel).where(OrganizationModel.id == organization_id) ) @@ -201,42 +109,19 @@ class OrganizationUsageClient(BaseDBClient): organization_id, session, commit=False ) - # Calculate next refresh date - if org.quota_type == "monthly": - next_refresh = cycle.period_end + relativedelta(days=1) - else: # annual - next_refresh = cycle.period_end + relativedelta(days=1) - result = { "period_start": cycle.period_start.isoformat(), "period_end": cycle.period_end.isoformat(), "used_dograh_tokens": cycle.used_dograh_tokens, - "quota_dograh_tokens": cycle.quota_dograh_tokens, - "percentage_used": ( - round( - (cycle.used_dograh_tokens / cycle.quota_dograh_tokens) * 100, 2 - ) - if cycle.quota_dograh_tokens > 0 - else 0 - ), - "next_refresh_date": next_refresh.date().isoformat(), - "quota_enabled": org.quota_enabled, "total_duration_seconds": cycle.total_duration_seconds, } # Add USD fields if organization has pricing if org.price_per_second_usd is not None: result["used_amount_usd"] = cycle.used_amount_usd or 0 - result["quota_amount_usd"] = cycle.quota_amount_usd result["currency"] = "USD" result["price_per_second_usd"] = org.price_per_second_usd - # Calculate percentage based on USD if available - if cycle.quota_amount_usd and cycle.quota_amount_usd > 0: - result["percentage_used"] = round( - ((cycle.used_amount_usd or 0) / cycle.quota_amount_usd) * 100, 2 - ) - return result async def get_usage_history( @@ -256,7 +141,7 @@ class OrganizationUsageClient(BaseDBClient): .join(UserModel, WorkflowModel.user_id == UserModel.id) .where( UserModel.selected_organization_id == organization_id, - WorkflowRunModel.cost_info.isnot(None), + WorkflowRunModel.usage_info.isnot(None), ) .order_by(WorkflowRunModel.created_at.desc()) ) @@ -309,19 +194,8 @@ class OrganizationUsageClient(BaseDBClient): total_tokens = 0 total_duration_seconds = 0 for run in runs: - if run.cost_info: - # Try to get dograh_token_usage first (new format) - dograh_tokens = run.cost_info.get("dograh_token_usage", 0) - # If not present, calculate from total_cost_usd (old format) - if dograh_tokens == 0 and "total_cost_usd" in run.cost_info: - dograh_tokens = round( - float(run.cost_info["total_cost_usd"]) * 100, 2 - ) - # Get call duration - call_duration = run.cost_info.get("call_duration_seconds", 0) - else: - dograh_tokens = 0 - call_duration = 0 + dograh_tokens = 0 + call_duration = (run.usage_info or {}).get("call_duration_seconds", 0) total_tokens += dograh_tokens total_duration_seconds += int(round(call_duration)) @@ -395,13 +269,14 @@ class OrganizationUsageClient(BaseDBClient): WorkflowRunModel.initial_context, WorkflowRunModel.gathered_context, WorkflowRunModel.cost_info, + WorkflowRunModel.usage_info, WorkflowRunModel.public_access_token, ) .join(WorkflowModel, WorkflowRunModel.workflow_id == WorkflowModel.id) .join(UserModel, WorkflowModel.user_id == UserModel.id) .where( UserModel.selected_organization_id == organization_id, - WorkflowRunModel.cost_info.isnot(None), + WorkflowRunModel.usage_info.isnot(None), ) .order_by(WorkflowRunModel.created_at.desc()) ) @@ -473,11 +348,11 @@ class OrganizationUsageClient(BaseDBClient): ) config_obj = config_result.scalar_one_or_none() if config_obj and config_obj.configuration: - user_config = EffectiveAIModelConfiguration.model_validate( + effective_config = EffectiveAIModelConfiguration.model_validate( config_obj.configuration ) - if user_config.timezone and user_timezone == "UTC": - user_timezone = user_config.timezone + if effective_config.timezone and user_timezone == "UTC": + user_timezone = effective_config.timezone # Validate timezone string try: @@ -496,7 +371,7 @@ class OrganizationUsageClient(BaseDBClient): select( date_expr.label("date"), func.sum( - WorkflowRunModel.cost_info["call_duration_seconds"].as_float() + WorkflowRunModel.usage_info["call_duration_seconds"].as_float() ).label("total_seconds"), func.count(WorkflowRunModel.id).label("call_count"), ) @@ -545,83 +420,11 @@ class OrganizationUsageClient(BaseDBClient): "currency": "USD", } - async def update_organization_quota( - self, - organization_id: int, - quota_type: str, - quota_dograh_tokens: int, - quota_reset_day: Optional[int] = None, - quota_start_date: Optional[datetime] = None, - ) -> OrganizationModel: - """Update organization quota settings.""" - async with self.async_session() as session: - result = await session.execute( - select(OrganizationModel).where(OrganizationModel.id == organization_id) - ) - org = result.scalar_one() - - org.quota_type = quota_type - org.quota_dograh_tokens = quota_dograh_tokens - org.quota_enabled = True - - if quota_type == "monthly" and quota_reset_day: - org.quota_reset_day = quota_reset_day - elif quota_type == "annual" and quota_start_date: - org.quota_start_date = quota_start_date - - await session.commit() - await session.refresh(org) - return org - - def _calculate_current_period( - self, org: OrganizationModel - ) -> tuple[datetime, datetime]: - """Calculate the current billing period based on organization settings.""" + def _calculate_current_period(self) -> tuple[datetime, datetime]: + """Calculate the current calendar-month reporting period.""" now = datetime.now(timezone.utc) - if org.quota_type == "monthly": - # Find the start of the current billing month - reset_day = org.quota_reset_day - - # Handle month boundaries - if now.day >= reset_day: - period_start = now.replace( - day=reset_day, hour=0, minute=0, second=0, microsecond=0 - ) - else: - # Previous month - period_start = (now - relativedelta(months=1)).replace( - day=reset_day, hour=0, minute=0, second=0, microsecond=0 - ) - - # End is one month later minus 1 second - period_end = ( - period_start + relativedelta(months=1) - relativedelta(seconds=1) - ) - - else: # annual - if not org.quota_start_date: - # Default to calendar year - period_start = now.replace( - month=1, day=1, hour=0, minute=0, second=0, microsecond=0 - ) - period_end = ( - period_start + relativedelta(years=1) - relativedelta(seconds=1) - ) - else: - # Find current annual period - start_date = org.quota_start_date.replace(tzinfo=timezone.utc) - years_diff = now.year - start_date.year - - # Adjust for whether we've passed the anniversary - if now.month < start_date.month or ( - now.month == start_date.month and now.day < start_date.day - ): - years_diff -= 1 - - period_start = start_date + relativedelta(years=years_diff) - period_end = ( - period_start + relativedelta(years=1) - relativedelta(seconds=1) - ) + period_start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0) + period_end = period_start + relativedelta(months=1) - relativedelta(seconds=1) return period_start, period_end diff --git a/api/db/user_client.py b/api/db/user_client.py index 9c4476f2..4ea0bca9 100644 --- a/api/db/user_client.py +++ b/api/db/user_client.py @@ -8,7 +8,7 @@ from sqlalchemy.future import select from api.db.base_client import BaseDBClient from api.db.models import UserConfigurationModel, UserModel -from api.schemas.user_configuration import EffectiveAIModelConfiguration +from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration class UserClient(BaseDBClient): diff --git a/api/db/workflow_run_client.py b/api/db/workflow_run_client.py index 57c3e02b..497230ad 100644 --- a/api/db/workflow_run_client.py +++ b/api/db/workflow_run_client.py @@ -16,6 +16,7 @@ from api.db.models import ( ) from api.enums import CallType, StorageBackend from api.schemas.workflow import WorkflowRunResponseSchema +from api.services.workflow.run_usage_response import format_public_cost_info class WorkflowRunClient(BaseDBClient): @@ -312,26 +313,9 @@ class WorkflowRunClient(BaseDBClient): "is_completed": run.is_completed, "recording_url": run.recording_url, "transcript_url": run.transcript_url, - "cost_info": { - "dograh_token_usage": ( - run.cost_info.get("dograh_token_usage") - if run.cost_info - and "dograh_token_usage" in run.cost_info - else round( - float(run.cost_info.get("total_cost_usd", 0)) * 100, - 2, - ) - if run.cost_info and "total_cost_usd" in run.cost_info - else 0 - ), - "call_duration_seconds": int( - round(run.cost_info.get("call_duration_seconds") or 0) - ) - if run.cost_info - else None, - } - if run.cost_info - else None, + "cost_info": format_public_cost_info( + run.cost_info, run.usage_info + ), "definition_id": run.definition_id, "initial_context": run.initial_context, "gathered_context": run.gathered_context, diff --git a/api/routes/knowledge_base.py b/api/routes/knowledge_base.py index d9156871..bd0ba046 100644 --- a/api/routes/knowledge_base.py +++ b/api/routes/knowledge_base.py @@ -384,7 +384,7 @@ async def search_chunks( user_id=user.id, organization_id=user.selected_organization_id, ) - user_config = resolved_config.effective + effective_config = resolved_config.effective embeddings_api_key = None embeddings_model = None embeddings_provider = None @@ -392,17 +392,17 @@ async def search_chunks( embeddings_endpoint = None embeddings_api_version = None - if user_config.embeddings: - embeddings_api_key = user_config.embeddings.api_key - embeddings_model = user_config.embeddings.model - embeddings_provider = getattr(user_config.embeddings, "provider", None) - embeddings_endpoint = getattr(user_config.embeddings, "endpoint", None) + if effective_config.embeddings: + embeddings_api_key = effective_config.embeddings.api_key + embeddings_model = effective_config.embeddings.model + embeddings_provider = getattr(effective_config.embeddings, "provider", None) + embeddings_endpoint = getattr(effective_config.embeddings, "endpoint", None) embeddings_base_url = apply_managed_embeddings_base_url( provider=embeddings_provider, - base_url=getattr(user_config.embeddings, "base_url", None), + base_url=getattr(effective_config.embeddings, "base_url", None), ) embeddings_api_version = getattr( - user_config.embeddings, "api_version", None + effective_config.embeddings, "api_version", None ) # Initialize embedding service based on provider diff --git a/api/routes/organization.py b/api/routes/organization.py index 4fb8e850..8f8e4cbe 100644 --- a/api/routes/organization.py +++ b/api/routes/organization.py @@ -5,7 +5,11 @@ from loguru import logger from pydantic import BaseModel from sqlalchemy.exc import IntegrityError -from api.constants import DEFAULT_CAMPAIGN_RETRY_CONFIG, DEFAULT_ORG_CONCURRENCY_LIMIT +from api.constants import ( + DEFAULT_CAMPAIGN_RETRY_CONFIG, + DEFAULT_ORG_CONCURRENCY_LIMIT, + DEPLOYMENT_MODE, +) from api.db import db_client from api.db.models import UserModel from api.db.telephony_configuration_client import TelephonyConfigurationInUseError @@ -55,6 +59,11 @@ from api.services.configuration.registry import ( ServiceProviders, ServiceType, ) +from api.services.mps_billing import ensure_hosted_mps_billing_account_v2 +from api.services.organization_context import ( + OrganizationContextResponse, + get_organization_context, +) from api.services.organization_preferences import ( get_organization_preferences, upsert_organization_preferences, @@ -129,6 +138,12 @@ class TelephonyConfigWarningsResponse(BaseModel): telnyx_missing_webhook_public_key_count: int +@router.get("/context", response_model=OrganizationContextResponse) +async def get_current_organization_context(user: UserModel = Depends(get_user)): + """Return organization-scoped configuration signals owned by Dograh.""" + return await get_organization_context(user) + + @router.get( "/telephony-providers/metadata", response_model=TelephonyProvidersMetadataResponse, @@ -349,6 +364,23 @@ async def migrate_model_configuration_v2( except ValueError as exc: raise HTTPException(status_code=422, detail=exc.args[0]) + if DEPLOYMENT_MODE != "oss": + try: + await ensure_hosted_mps_billing_account_v2( + organization_id, + created_by=str(user.provider_id), + ) + except Exception as exc: + logger.error( + "Failed to initialize MPS billing v2 account for organization {}: {}", + organization_id, + exc, + ) + raise HTTPException( + status_code=502, + detail="Failed to initialize MPS billing v2 account", + ) + await upsert_organization_ai_model_configuration_v2( organization_id, configuration, diff --git a/api/routes/organization_usage.py b/api/routes/organization_usage.py index 8e75a2c8..3912745b 100644 --- a/api/routes/organization_usage.py +++ b/api/routes/organization_usage.py @@ -1,16 +1,16 @@ import json from datetime import datetime, timedelta -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Literal, Optional from fastapi import APIRouter, Depends, HTTPException, Query from fastapi.responses import StreamingResponse from loguru import logger from pydantic import BaseModel, Field -from api.constants import DEPLOYMENT_MODE +from api.constants import DEPLOYMENT_MODE, UI_APP_URL from api.db import db_client from api.db.models import UserModel -from api.services.auth.depends import get_user +from api.services.auth.depends import get_user, get_user_with_selected_organization from api.services.mps_service_key_client import mps_service_key_client from api.services.reports import generate_usage_runs_report_csv from api.utils.artifacts import artifact_url @@ -22,14 +22,8 @@ class CurrentUsageResponse(BaseModel): period_start: str period_end: str used_dograh_tokens: float - quota_dograh_tokens: int - percentage_used: float - next_refresh_date: str - quota_enabled: bool total_duration_seconds: int - # New USD fields used_amount_usd: Optional[float] = None - quota_amount_usd: Optional[float] = None currency: Optional[str] = None price_per_second_usd: Optional[float] = None @@ -40,6 +34,61 @@ class MPSCreditsResponse(BaseModel): total_quota: float +class MPSCreditPurchaseUrlResponse(BaseModel): + checkout_url: str + + +class MPSBillingAccountResponse(BaseModel): + id: int + organization_id: int + billing_mode: str + cached_balance_credits: float + currency: str + + +class MPSCreditLedgerEntryResponse(BaseModel): + id: int + entry_type: str + origin: Optional[str] = None + credits_delta: float + balance_after: float + amount_minor: Optional[int] = None + amount_currency: Optional[str] = None + payment_order_id: Optional[int] = None + metric_code: Optional[str] = None + correlation_id: Optional[str] = None + aggregation_key: Optional[str] = None + usage_event_id: Optional[int] = None + workflow_run_id: Optional[int] = None + workflow_id: Optional[int] = None + billable_quantity: Optional[float] = None + quantity_unit: Optional[str] = None + metadata: Dict[str, Any] = Field(default_factory=dict) + created_at: str + + +class MPSBillingCreditsResponse(BaseModel): + billing_version: Literal["legacy", "v2"] + total_credits_used: float = 0.0 + remaining_credits: float = 0.0 + total_quota: float = 0.0 + account: Optional[MPSBillingAccountResponse] = None + ledger_entries: List[MPSCreditLedgerEntryResponse] = Field(default_factory=list) + total_count: int = 0 + page: int = 1 + limit: int = 50 + total_pages: int = 0 + + +def _optional_int(value: Any) -> Optional[int]: + if value is None: + return None + try: + return int(value) + except (TypeError, ValueError): + return None + + class WorkflowRunUsageResponse(BaseModel): id: int workflow_id: int @@ -97,7 +146,7 @@ class DailyUsageBreakdownResponse(BaseModel): @router.get("/usage/current-period", response_model=CurrentUsageResponse) async def get_current_period_usage(user: UserModel = Depends(get_user)): - """Get current billing period usage for the user's organization.""" + """Get current reporting-period usage for the user's organization.""" if not user.selected_organization_id: raise HTTPException(status_code=400, detail="No organization selected") @@ -142,6 +191,202 @@ async def get_mps_credits(user: UserModel = Depends(get_user)): raise HTTPException(status_code=500, detail=str(e)) +async def _get_mps_billing_account_status( + user: UserModel, organization_id: int +) -> Optional[dict]: + return await mps_service_key_client.get_billing_account_status( + organization_id=organization_id, + created_by=str(user.provider_id), + ) + + +def _is_mps_billing_v2(account: Optional[dict]) -> bool: + return bool(account and account.get("billing_mode") == "v2") + + +async def _legacy_mps_credits_response(user: UserModel) -> MPSBillingCreditsResponse: + if DEPLOYMENT_MODE == "oss": + usage = await mps_service_key_client.get_usage_by_created_by( + str(user.provider_id) + ) + else: + if not user.selected_organization_id: + raise HTTPException(status_code=400, detail="No organization selected") + usage = await mps_service_key_client.get_usage_by_organization( + user.selected_organization_id + ) + + total_used = float(usage.get("total_credits_used", 0.0)) + total_remaining = float(usage.get("remaining_credits", 0.0)) + return MPSBillingCreditsResponse( + billing_version="legacy", + total_credits_used=total_used, + remaining_credits=total_remaining, + total_quota=total_used + total_remaining, + ) + + +@router.get("/billing/credits", response_model=MPSBillingCreditsResponse) +async def get_billing_credits( + page: int = Query(1, ge=1), + limit: int = Query(50, ge=1, le=100), + user: UserModel = Depends(get_user), +): + """Return legacy MPS credits or paginated v2 billing ledger details for the org.""" + try: + if DEPLOYMENT_MODE == "oss" or not user.selected_organization_id: + return await _legacy_mps_credits_response(user) + + organization_id = user.selected_organization_id + account_status = await _get_mps_billing_account_status(user, organization_id) + if not _is_mps_billing_v2(account_status): + return await _legacy_mps_credits_response(user) + + ledger = await mps_service_key_client.get_credit_ledger( + organization_id=organization_id, + page=page, + limit=limit, + created_by=str(user.provider_id), + ) + account = ledger.get("account") or {} + ledger_entries = ledger.get("ledger_entries") or [] + total_count = int(ledger.get("total_count") or len(ledger_entries)) + response_limit = int(ledger.get("limit") or limit) + total_pages = int( + ledger.get("total_pages") + or ((total_count + response_limit - 1) // response_limit) + ) + workflow_ids_by_run_id: dict[int, int] = {} + workflow_run_ids = { + workflow_run_id + for entry in ledger_entries + if (workflow_run_id := _optional_int(entry.get("workflow_run_id"))) + is not None + } + for workflow_run_id in workflow_run_ids: + workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id) + if ( + workflow_run + and workflow_run.workflow + and workflow_run.workflow.organization_id == organization_id + ): + workflow_ids_by_run_id[workflow_run_id] = workflow_run.workflow_id + + balance = float(account.get("cached_balance_credits") or 0.0) + total_debits = sum( + abs(float(entry.get("credits_delta") or 0.0)) + for entry in ledger_entries + if float(entry.get("credits_delta") or 0.0) < 0 + ) + if ledger.get("total_debits_credits") is not None: + total_debits = float(ledger["total_debits_credits"]) + + return MPSBillingCreditsResponse( + billing_version="v2", + total_credits_used=total_debits, + remaining_credits=balance, + total_quota=balance + total_debits, + account=MPSBillingAccountResponse( + id=int(account["id"]), + organization_id=int(account["organization_id"]), + billing_mode=str(account["billing_mode"]), + cached_balance_credits=balance, + currency=str(account.get("currency") or "USD"), + ), + ledger_entries=[ + MPSCreditLedgerEntryResponse( + id=int(entry["id"]), + entry_type=str(entry["entry_type"]), + origin=entry.get("origin"), + credits_delta=float(entry.get("credits_delta") or 0.0), + balance_after=float(entry.get("balance_after") or 0.0), + amount_minor=entry.get("amount_minor"), + amount_currency=entry.get("amount_currency"), + payment_order_id=entry.get("payment_order_id"), + metric_code=entry.get("metric_code"), + correlation_id=entry.get("correlation_id"), + aggregation_key=entry.get("aggregation_key"), + usage_event_id=_optional_int(entry.get("usage_event_id")), + workflow_run_id=_optional_int(entry.get("workflow_run_id")), + workflow_id=workflow_ids_by_run_id.get( + _optional_int(entry.get("workflow_run_id")) + ) + if entry.get("workflow_run_id") is not None + else None, + billable_quantity=float(entry["billable_quantity"]) + if entry.get("billable_quantity") is not None + else None, + quantity_unit=entry.get("quantity_unit"), + metadata=entry.get("metadata") or {}, + created_at=str(entry["created_at"]), + ) + for entry in ledger_entries + ], + total_count=total_count, + page=int(ledger.get("page") or page), + limit=response_limit, + total_pages=total_pages, + ) + except HTTPException: + raise + except Exception as exc: + logger.error(f"Failed to fetch billing credits: {exc}") + raise HTTPException(status_code=500, detail=str(exc)) + + +@router.post( + "/usage/mps-credits/purchase-url", + response_model=MPSCreditPurchaseUrlResponse, +) +async def create_mps_credit_purchase_url( + user: UserModel = Depends(get_user_with_selected_organization), +): + """Create a checkout URL for organizations using Dograh-managed MPS v2.""" + if DEPLOYMENT_MODE == "oss": + raise HTTPException( + status_code=404, + detail="Credit purchases are not available in OSS mode", + ) + + organization_id = user.selected_organization_id + assert organization_id is not None + account_status = await _get_mps_billing_account_status(user, organization_id) + if not _is_mps_billing_v2(account_status): + raise HTTPException( + status_code=403, + detail=( + "Credit purchases are available only for organizations using billing v2" + ), + ) + + try: + session = await mps_service_key_client.create_credit_purchase_url( + organization_id=organization_id, + created_by=str(user.provider_id), + return_url=f"{UI_APP_URL.rstrip('/')}/billing", + billing_details={ + "source": "dograh_billing", + "dograh_user_id": str(user.id), + "dograh_provider_id": str(user.provider_id), + }, + ) + except Exception as exc: + logger.error(f"Failed to create MPS credit purchase URL: {exc}") + raise HTTPException( + status_code=502, + detail="Failed to create credit purchase URL", + ) + + checkout_url = session.get("checkout_url") + if not checkout_url: + logger.error(f"MPS checkout session response missing checkout_url: {session}") + raise HTTPException( + status_code=502, + detail="MPS checkout session response missing checkout_url", + ) + return MPSCreditPurchaseUrlResponse(checkout_url=checkout_url) + + FILTERS_DESCRIPTION = """\ JSON-encoded array of filter objects. Each object has the shape: diff --git a/api/routes/workflow.py b/api/routes/workflow.py index 9157c5cf..06e5fdf9 100644 --- a/api/routes/workflow.py +++ b/api/routes/workflow.py @@ -41,12 +41,15 @@ from api.services.configuration.resolve import ( ) from api.services.mps_service_key_client import mps_service_key_client from api.services.posthog_client import capture_event -from api.services.pricing.run_usage_response import format_public_usage_info from api.services.reports import generate_workflow_report_csv from api.services.storage import storage_fs from api.services.workflow.dto import ReactFlowDTO, sanitize_workflow_definition from api.services.workflow.duplicate import duplicate_workflow from api.services.workflow.errors import ItemKind, WorkflowError +from api.services.workflow.run_usage_response import ( + format_public_cost_info, + format_public_usage_info, +) from api.services.workflow.trigger_paths import ( TriggerPathIssue, ensure_trigger_paths, @@ -1053,13 +1056,15 @@ async def update_workflow( user_id=user.id, organization_id=user.selected_organization_id, ) - user_config = resolved_config.effective + effective_config = resolved_config.effective try: enriched_overrides = enrich_overrides_with_api_keys( workflow_configurations["model_overrides"], - user_config, + effective_config, + ) + effective = resolve_effective_config( + effective_config, enriched_overrides ) - effective = resolve_effective_config(user_config, enriched_overrides) if resolved_config.source == "organization_v2": v2_override = convert_legacy_ai_model_configuration_to_v2(effective) await UserConfigurationValidator().validate( @@ -1264,22 +1269,7 @@ async def get_workflow_run( "transcript_public_url": artifact_url(public_access_token, "transcript"), "recording_public_url": artifact_url(public_access_token, "recording"), "public_access_token": public_access_token, - "cost_info": { - "dograh_token_usage": ( - run.cost_info.get("dograh_token_usage") - if run.cost_info and "dograh_token_usage" in run.cost_info - else round(float(run.cost_info.get("total_cost_usd", 0)) * 100, 2) - if run.cost_info and "total_cost_usd" in run.cost_info - else 0 - ), - "call_duration_seconds": int( - round(run.cost_info.get("call_duration_seconds")) - ) - if run.cost_info and run.cost_info.get("call_duration_seconds") is not None - else None, - } - if run.cost_info - else None, + "cost_info": format_public_cost_info(run.cost_info, run.usage_info), "usage_info": format_public_usage_info(run.usage_info), "created_at": run.created_at, "definition_id": run.definition_id, diff --git a/api/schemas/ai_model_configuration.py b/api/schemas/ai_model_configuration.py index dcc3a6e7..c5403b04 100644 --- a/api/schemas/ai_model_configuration.py +++ b/api/schemas/ai_model_configuration.py @@ -1,10 +1,10 @@ from __future__ import annotations +from datetime import datetime from typing import Literal from pydantic import BaseModel, Field, model_validator -from api.schemas.user_configuration import EffectiveAIModelConfiguration from api.services.configuration.registry import ( DograhEmbeddingsConfiguration, DograhLLMService, @@ -23,6 +23,29 @@ DOGRAH_DEFAULT_VOICE = "default" DOGRAH_DEFAULT_LANGUAGE = "multi" +class EffectiveAIModelConfiguration(BaseModel): + llm: LLMConfig | None = None + stt: STTConfig | None = None + tts: TTSConfig | None = None + embeddings: EmbeddingsConfig | None = None + realtime: RealtimeConfig | None = None + is_realtime: bool = False + managed_service_version: int | None = None + test_phone_number: str | None = None + timezone: str | None = None + last_validated_at: datetime | None = None + + @model_validator(mode="before") + @classmethod + def strip_incomplete_realtime_when_disabled(cls, data): + """Skip realtime validation when is_realtime is False and api_key is missing.""" + if isinstance(data, dict) and not data.get("is_realtime", False): + realtime = data.get("realtime") + if isinstance(realtime, dict) and not realtime.get("api_key"): + data.pop("realtime", None) + return data + + class DograhManagedAIModelConfiguration(BaseModel): api_key: str voice: str = DOGRAH_DEFAULT_VOICE @@ -160,6 +183,7 @@ def _compile_dograh_configuration( model="default", ), is_realtime=False, + managed_service_version=2, ) diff --git a/api/schemas/user_configuration.py b/api/schemas/user_configuration.py deleted file mode 100644 index fc958a5b..00000000 --- a/api/schemas/user_configuration.py +++ /dev/null @@ -1,33 +0,0 @@ -from datetime import datetime - -from pydantic import BaseModel, model_validator - -from api.services.configuration.registry import ( - EmbeddingsConfig, - LLMConfig, - RealtimeConfig, - STTConfig, - TTSConfig, -) - - -class EffectiveAIModelConfiguration(BaseModel): - llm: LLMConfig | None = None - stt: STTConfig | None = None - tts: TTSConfig | None = None - embeddings: EmbeddingsConfig | None = None - realtime: RealtimeConfig | None = None - is_realtime: bool = False - test_phone_number: str | None = None - timezone: str | None = None - last_validated_at: datetime | None = None - - @model_validator(mode="before") - @classmethod - def strip_incomplete_realtime_when_disabled(cls, data): - """Skip realtime validation when is_realtime is False and api_key is missing.""" - if isinstance(data, dict) and not data.get("is_realtime", False): - realtime = data.get("realtime") - if isinstance(realtime, dict) and not realtime.get("api_key"): - data.pop("realtime", None) - return data diff --git a/api/services/auth/depends.py b/api/services/auth/depends.py index d9e24684..019dbc2f 100644 --- a/api/services/auth/depends.py +++ b/api/services/auth/depends.py @@ -9,9 +9,10 @@ from api.constants import AUTH_PROVIDER, DOGRAH_MPS_SECRET_KEY, MPS_API_URL from api.db import db_client from api.db.models import UserModel from api.enums import PostHogEvent -from api.schemas.user_configuration import EffectiveAIModelConfiguration +from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration from api.services.auth.stack_auth import stackauth from api.services.configuration.registry import ServiceProviders +from api.services.mps_billing import ensure_hosted_mps_billing_account_v2 from api.services.posthog_client import capture_event from api.utils.auth import decode_jwt_token @@ -110,6 +111,19 @@ async def get_user( # This prevents race conditions where multiple concurrent requests # might try to create configurations if org_was_created: + try: + await ensure_hosted_mps_billing_account_v2( + organization.id, + created_by=str(stack_user["id"]), + ) + except Exception: + logger.warning( + "Failed to initialize hosted MPS billing account for " + "organization {}", + organization.id, + exc_info=True, + ) + existing_cfg = await db_client.get_user_configurations(user_model.id) if not (existing_cfg.llm or existing_cfg.tts or existing_cfg.stt): mps_config = await create_user_configuration_with_mps_key( @@ -232,7 +246,7 @@ async def create_user_configuration_with_mps_key( response = await client.post( f"{MPS_API_URL}/api/v1/service-keys/", json={ - "name": f"Default Dograh Model Service Key", + "name": "Default Dograh Model Service Key", "description": "Auto-generated key for OSS user", "expires_in_days": 7, # Short-lived for OSS "created_by": user_provider_id, @@ -250,7 +264,7 @@ async def create_user_configuration_with_mps_key( response = await client.post( f"{MPS_API_URL}/api/v1/service-keys/", json={ - "name": f"Default Dograh Model Service Key", + "name": "Default Dograh Model Service Key", "description": f"Auto-generated key for organization {organization_id}", "organization_id": organization_id, "expires_in_days": 90, # Longer-lived for authenticated users @@ -285,8 +299,8 @@ async def create_user_configuration_with_mps_key( "model": "default", }, } - user_config = EffectiveAIModelConfiguration(**configuration) - return user_config + effective_config = EffectiveAIModelConfiguration(**configuration) + return effective_config else: logger.warning( f"Failed to get MPS service key: {response.status_code} - {response.text}" diff --git a/api/services/configuration/ai_model_configuration.py b/api/services/configuration/ai_model_configuration.py index 1b9a00f6..c5331515 100644 --- a/api/services/configuration/ai_model_configuration.py +++ b/api/services/configuration/ai_model_configuration.py @@ -21,10 +21,10 @@ from api.schemas.ai_model_configuration import ( BYOKPipelineAIModelConfiguration, BYOKRealtimeAIModelConfiguration, DograhManagedAIModelConfiguration, + EffectiveAIModelConfiguration, OrganizationAIModelConfigurationV2, compile_ai_model_configuration_v2, ) -from api.schemas.user_configuration import EffectiveAIModelConfiguration from api.services.configuration.masking import ( SERVICE_SECRET_FIELDS, contains_masked_key, diff --git a/api/services/configuration/check_validity.py b/api/services/configuration/check_validity.py index e8f5bfa7..b1996879 100644 --- a/api/services/configuration/check_validity.py +++ b/api/services/configuration/check_validity.py @@ -8,7 +8,7 @@ from groq import Groq # from pyneuphonic import Neuphonic # except ImportError: # Neuphonic = None -from api.schemas.user_configuration import ( +from api.schemas.ai_model_configuration import ( EffectiveAIModelConfiguration, ) from api.services.configuration.registry import ServiceConfig, ServiceProviders @@ -75,21 +75,21 @@ class UserConfigurationValidator: status_list = [] status_list.extend(self._validate_service(configuration.llm, "llm")) - status_list.extend(self._validate_service(configuration.stt, "stt")) - status_list.extend(self._validate_service(configuration.tts, "tts")) - # Embeddings is optional - only validate if configured - status_list.extend( - self._validate_service( - configuration.embeddings, "embeddings", required=False - ) - ) - # Realtime is optional - only validate if is_realtime is enabled if configuration.is_realtime: status_list.extend( self._validate_service( configuration.realtime, "realtime", required=True ) ) + else: + status_list.extend(self._validate_service(configuration.stt, "stt")) + status_list.extend(self._validate_service(configuration.tts, "tts")) + # Embeddings is optional - only validate if configured + status_list.extend( + self._validate_service( + configuration.embeddings, "embeddings", required=False + ) + ) if status_list: raise ValueError(status_list) diff --git a/api/services/configuration/masking.py b/api/services/configuration/masking.py index c3fa4bfc..a7e1af6a 100644 --- a/api/services/configuration/masking.py +++ b/api/services/configuration/masking.py @@ -12,7 +12,7 @@ The rules are simple: import copy from typing import Any, Dict, Optional -from api.schemas.user_configuration import EffectiveAIModelConfiguration +from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration from api.services.configuration.registry import ServiceConfig from api.services.integrations import get_node_secret_fields diff --git a/api/services/configuration/merge.py b/api/services/configuration/merge.py index 1b174ee8..3100fa45 100644 --- a/api/services/configuration/merge.py +++ b/api/services/configuration/merge.py @@ -7,7 +7,7 @@ stored, while honouring masked API keys. import copy from typing import Dict -from api.schemas.user_configuration import EffectiveAIModelConfiguration +from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration from api.services.configuration.masking import ( MODEL_OVERRIDE_FIELDS, SERVICE_SECRET_FIELDS, diff --git a/api/services/configuration/resolve.py b/api/services/configuration/resolve.py index a33f5c09..5cbf11ef 100644 --- a/api/services/configuration/resolve.py +++ b/api/services/configuration/resolve.py @@ -4,7 +4,7 @@ from __future__ import annotations import copy -from api.schemas.user_configuration import EffectiveAIModelConfiguration +from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration from api.services.configuration.registry import ( REGISTRY, ServiceType, diff --git a/api/services/gen_ai/embedding/openai_service.py b/api/services/gen_ai/embedding/openai_service.py index da5d3d4d..1081889e 100644 --- a/api/services/gen_ai/embedding/openai_service.py +++ b/api/services/gen_ai/embedding/openai_service.py @@ -38,6 +38,7 @@ class OpenAIEmbeddingService(BaseEmbeddingService): api_key: Optional[str] = None, model_id: str = DEFAULT_MODEL_ID, base_url: Optional[str] = None, + default_headers: Optional[Dict[str, str]] = None, ): """Initialize the OpenAI embedding service. @@ -60,6 +61,8 @@ class OpenAIEmbeddingService(BaseEmbeddingService): field_name="base_url", ) client_kwargs["base_url"] = base_url + if default_headers: + client_kwargs["default_headers"] = default_headers self.client = AsyncOpenAI(**client_kwargs) logger.info(f"OpenAI embedding service initialized with model: {model_id}") else: diff --git a/api/services/managed_model_services.py b/api/services/managed_model_services.py new file mode 100644 index 00000000..b6992aaf --- /dev/null +++ b/api/services/managed_model_services.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +from typing import Any + +from loguru import logger + +from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration +from api.services.configuration.registry import ServiceProviders +from api.services.mps_service_key_client import mps_service_key_client + +MPS_CORRELATION_ID_CONTEXT_KEY = "mps_correlation_id" + + +def uses_managed_model_services_v2( + ai_model_config: EffectiveAIModelConfiguration | None, +) -> bool: + if ( + ai_model_config is None + or getattr(ai_model_config, "managed_service_version", None) != 2 + ): + return False + + return any( + _is_dograh_service(getattr(ai_model_config, section_name, None)) + for section_name in ("llm", "tts", "stt", "embeddings") + ) + + +def get_mps_correlation_id(initial_context: dict[str, Any] | None) -> str | None: + if not initial_context: + return None + correlation_id = initial_context.get(MPS_CORRELATION_ID_CONTEXT_KEY) + if correlation_id is None: + return None + return str(correlation_id) + + +async def ensure_mps_correlation_id( + *, + ai_model_config: EffectiveAIModelConfiguration, + workflow_run_id: int, + initial_context: dict[str, Any] | None, +) -> str | None: + existing = get_mps_correlation_id(initial_context) + if existing: + return existing + + if not uses_managed_model_services_v2(ai_model_config): + return None + + service_key = _get_dograh_service_api_key(ai_model_config) + if not service_key: + raise ValueError( + "Managed model services v2 requires a Dograh service key before the run starts." + ) + + response = await mps_service_key_client.create_correlation_id( + service_key=service_key, + workflow_run_id=workflow_run_id, + ) + correlation_id = response.get("correlation_id") + if not correlation_id: + raise ValueError("MPS correlation-id response did not include correlation_id") + + correlation_id = str(correlation_id) + logger.info( + "Minted MPS correlation id {} for workflow run {}", + correlation_id, + workflow_run_id, + ) + return correlation_id + + +def _is_dograh_service(service: Any) -> bool: + provider = getattr(service, "provider", None) + return ( + provider == ServiceProviders.DOGRAH or provider == ServiceProviders.DOGRAH.value + ) + + +def _get_dograh_service_api_key( + ai_model_config: EffectiveAIModelConfiguration, +) -> str | None: + for section_name in ("llm", "tts", "stt", "embeddings"): + service = getattr(ai_model_config, section_name, None) + if not _is_dograh_service(service): + continue + + if hasattr(service, "get_all_api_keys"): + keys = service.get_all_api_keys() + if keys: + return keys[0] + + api_key = getattr(service, "api_key", None) + if isinstance(api_key, str) and api_key: + return api_key + + return None diff --git a/api/services/mps_billing.py b/api/services/mps_billing.py new file mode 100644 index 00000000..10a27c90 --- /dev/null +++ b/api/services/mps_billing.py @@ -0,0 +1,23 @@ +from typing import Optional + +from api.constants import DEPLOYMENT_MODE +from api.services.mps_service_key_client import mps_service_key_client + + +async def ensure_hosted_mps_billing_account_v2( + organization_id: int, + *, + created_by: Optional[str] = None, +) -> Optional[dict]: + """Ensure hosted orgs have an MPS billing v2 account. + + OSS deployments use legacy per-key quota accounting and do not create MPS + billing accounts. + """ + if DEPLOYMENT_MODE == "oss": + return None + + return await mps_service_key_client.ensure_billing_account_v2( + organization_id=organization_id, + created_by=created_by, + ) diff --git a/api/services/mps_service_key_client.py b/api/services/mps_service_key_client.py index 2c7fc56b..4f30341d 100644 --- a/api/services/mps_service_key_client.py +++ b/api/services/mps_service_key_client.py @@ -4,6 +4,7 @@ This client communicates with the Model Proxy Service (MPS) for service key mana Service keys are stored and managed entirely in MPS, not in the local database. """ +import asyncio from typing import List, Optional import httpx @@ -353,6 +354,234 @@ class MPSServiceKeyClient: response=response, ) + async def create_credit_purchase_url( + self, + organization_id: int, + created_by: Optional[str] = None, + return_url: Optional[str] = None, + billing_details: Optional[dict] = None, + ) -> dict: + """Create a short-lived MPS checkout URL for adding organization credits.""" + payload = { + "created_by": created_by, + "return_url": return_url, + "billing_details": billing_details or {}, + } + + async with httpx.AsyncClient(timeout=self.timeout) as client: + response = await client.post( + f"{self.base_url}/api/v1/billing/accounts/{organization_id}/checkout-sessions", + json=payload, + headers=self._get_headers( + organization_id=organization_id, + created_by=created_by, + ), + ) + + if response.status_code == 200: + return response.json() + + logger.error( + "Failed to create MPS credit purchase URL: " + f"{response.status_code} - {response.text}" + ) + raise httpx.HTTPStatusError( + f"Failed to create MPS credit purchase URL: {response.text}", + request=response.request, + response=response, + ) + + async def get_credit_ledger( + self, + organization_id: int, + page: int = 1, + limit: int = 50, + created_by: Optional[str] = None, + ) -> dict: + """Get the MPS v2 billing account balance and recent credit ledger.""" + async with httpx.AsyncClient(timeout=self.timeout) as client: + response = await client.get( + f"{self.base_url}/api/v1/billing/accounts/{organization_id}/ledger", + params={"page": page, "limit": limit}, + headers=self._get_headers( + organization_id=organization_id, + created_by=created_by, + ), + ) + + if response.status_code == 200: + return response.json() + + logger.error( + "Failed to get MPS credit ledger: " + f"{response.status_code} - {response.text}" + ) + raise httpx.HTTPStatusError( + f"Failed to get MPS credit ledger: {response.text}", + request=response.request, + response=response, + ) + + async def get_billing_account_status( + self, + organization_id: int, + created_by: Optional[str] = None, + ) -> Optional[dict]: + """Get an existing MPS v2 billing account without creating one.""" + async with httpx.AsyncClient(timeout=self.timeout) as client: + response = await client.get( + f"{self.base_url}/api/v1/billing/accounts/{organization_id}/status", + headers=self._get_headers( + organization_id=organization_id, + created_by=created_by, + ), + ) + + if response.status_code == 200: + return response.json() + + logger.error( + "Failed to get MPS billing account status: " + f"{response.status_code} - {response.text}" + ) + raise httpx.HTTPStatusError( + f"Failed to get MPS billing account status: {response.text}", + request=response.request, + response=response, + ) + + async def ensure_billing_account_v2( + self, + organization_id: int, + created_by: Optional[str] = None, + ) -> dict: + """Create or return the MPS v2 billing account for an organization.""" + async with httpx.AsyncClient(timeout=self.timeout) as client: + response = await client.get( + f"{self.base_url}/api/v1/billing/accounts/{organization_id}/balance", + headers=self._get_headers( + organization_id=organization_id, + created_by=created_by, + ), + ) + + if response.status_code == 200: + return response.json() + + logger.error( + "Failed to ensure MPS billing account v2: " + f"{response.status_code} - {response.text}" + ) + raise httpx.HTTPStatusError( + f"Failed to ensure MPS billing account v2: {response.text}", + request=response.request, + response=response, + ) + + async def create_correlation_id( + self, + *, + service_key: str, + workflow_run_id: int | None = None, + ) -> dict: + """Mint a server-generated correlation ID for managed model services.""" + payload: dict[str, int] = {} + if workflow_run_id is not None: + payload["workflow_run_id"] = workflow_run_id + + async with httpx.AsyncClient(timeout=self.timeout) as client: + response = await client.post( + f"{self.base_url}/api/v1/service-keys/correlation-id/self", + json=payload, + headers={ + "Authorization": f"Bearer {service_key}", + "Content-Type": "application/json", + }, + ) + + if response.status_code == 200: + return response.json() + + logger.error( + "Failed to create correlation ID: " + f"{response.status_code} - {response.text}" + ) + raise httpx.HTTPStatusError( + f"Failed to create correlation ID: {response.text}", + request=response.request, + response=response, + ) + + async def report_platform_usage( + self, + *, + organization_id: int, + correlation_id: Optional[str] = None, + duration_seconds: Optional[float] = None, + workflow_run_id: int | None = None, + metadata: Optional[dict] = None, + max_attempts: int = 3, + ) -> dict: + """Report hosted Dograh platform usage for a completed workflow run.""" + if DEPLOYMENT_MODE == "oss": + raise ValueError("OSS deployments must not report platform usage to MPS") + if not correlation_id and duration_seconds is None: + raise ValueError( + "Platform usage reports require correlation_id or duration_seconds" + ) + + payload: dict = { + "metadata": metadata or {}, + } + if correlation_id: + payload["correlation_id"] = correlation_id + if duration_seconds is not None: + payload["duration_seconds"] = duration_seconds + if workflow_run_id is not None: + payload["workflow_run_id"] = workflow_run_id + + max_attempts = max(1, max_attempts) + last_response: httpx.Response | None = None + async with httpx.AsyncClient(timeout=self.timeout) as client: + for attempt in range(1, max_attempts + 1): + response = await client.post( + ( + f"{self.base_url}/api/v1/billing/accounts/" + f"{organization_id}/platform-usage" + ), + json=payload, + headers=self._get_headers(organization_id=organization_id), + ) + last_response = response + + if response.status_code == 200: + return response.json() + + should_retry = ( + response.status_code == 409 + and "usage_not_ready" in response.text + and attempt < max_attempts + ) + if should_retry: + await asyncio.sleep(attempt) + continue + + logger.error( + "Failed to report platform usage: " + f"{response.status_code} - {response.text}" + ) + raise httpx.HTTPStatusError( + f"Failed to report platform usage: {response.text}", + request=response.request, + response=response, + ) + + raise httpx.HTTPStatusError( + "Failed to report platform usage", + request=last_response.request, + response=last_response, + ) + async def transcribe_audio( self, audio_data: bytes, diff --git a/api/services/organization_context.py b/api/services/organization_context.py new file mode 100644 index 00000000..b17b8f4f --- /dev/null +++ b/api/services/organization_context.py @@ -0,0 +1,50 @@ +from typing import Literal, Optional + +from pydantic import BaseModel + +from api.db import db_client +from api.db.models import UserModel +from api.services.configuration.ai_model_configuration import ( + get_resolved_ai_model_configuration, +) + + +class OrganizationModelServicesContext(BaseModel): + config_source: Literal["organization_v2", "legacy_user_v1", "empty"] + has_model_configuration_v2: bool + managed_service_version: Optional[int] = None + uses_managed_service_v2: bool + + +class OrganizationContextResponse(BaseModel): + organization_id: Optional[int] = None + organization_provider_id: Optional[str] = None + model_services: OrganizationModelServicesContext + + +async def get_organization_context(user: UserModel) -> OrganizationContextResponse: + organization_id = user.selected_organization_id + organization = ( + await db_client.get_organization_by_id(organization_id) + if organization_id + else None + ) + + resolved = await get_resolved_ai_model_configuration( + user_id=user.id, + organization_id=organization_id, + ) + managed_service_version = resolved.effective.managed_service_version + + return OrganizationContextResponse( + organization_id=organization_id, + organization_provider_id=organization.provider_id if organization else None, + model_services=OrganizationModelServicesContext( + config_source=resolved.source, + has_model_configuration_v2=resolved.source == "organization_v2", + managed_service_version=managed_service_version, + uses_managed_service_v2=( + resolved.source == "organization_v2" and managed_service_version == 2 + ), + ), + ) diff --git a/api/services/pipecat/run_pipeline.py b/api/services/pipecat/run_pipeline.py index 63c11f53..07286901 100644 --- a/api/services/pipecat/run_pipeline.py +++ b/api/services/pipecat/run_pipeline.py @@ -162,15 +162,13 @@ async def run_pipeline_telephony( workflow_id: Workflow being executed. workflow_run_id: Workflow run row. user_id: Owner of the workflow. - call_id: Provider call identifier (stored in cost_info for billing). + call_id: Provider call identifier. transport_kwargs: Provider-specific kwargs forwarded to the transport factory (e.g. stream_sid + call_sid for Twilio). """ logger.debug(f"Running {provider_name} pipeline for workflow_run {workflow_run_id}") set_current_run_id(workflow_run_id) - await db_client.update_workflow_run(workflow_run_id, cost_info={"call_id": call_id}) - workflow = await db_client.get_workflow(workflow_id, user_id) if workflow: set_current_org_id(workflow.organization_id) @@ -340,7 +338,7 @@ async def _run_pipeline( if workflow_run.is_completed: raise HTTPException(status_code=400, detail="Workflow run already completed") - merged_call_context_vars = workflow_run.initial_context + merged_call_context_vars = dict(workflow_run.initial_context or {}) # If there is some extra call_context_vars, fold them in. Persistence # happens once below, after runtime_configuration is also resolved. if call_context_vars: @@ -398,6 +396,19 @@ async def _run_pipeline( else: user_config = resolved_user_config + from api.services.managed_model_services import ( + MPS_CORRELATION_ID_CONTEXT_KEY, + ensure_mps_correlation_id, + ) + + mps_correlation_id = await ensure_mps_correlation_id( + ai_model_config=user_config, + workflow_run_id=workflow_run_id, + initial_context=merged_call_context_vars, + ) + if mps_correlation_id: + merged_call_context_vars[MPS_CORRELATION_ID_CONTEXT_KEY] = mps_correlation_id + # Detect realtime mode (speech-to-speech services like OpenAI Realtime, Gemini Live) is_realtime = user_config.is_realtime and user_config.realtime is not None @@ -409,11 +420,23 @@ async def _run_pipeline( # Realtime services don't implement run_inference, so create a # separate text LLM for variable extraction and other out-of-band # inference calls. - inference_llm = create_llm_service(user_config) + inference_llm = create_llm_service( + user_config, + correlation_id=mps_correlation_id, + ) else: - stt = create_stt_service(user_config, audio_config, keyterms=keyterms) - tts = create_tts_service(user_config, audio_config) - llm = create_llm_service(user_config) + stt = create_stt_service( + user_config, + audio_config, + keyterms=keyterms, + correlation_id=mps_correlation_id, + ) + tts = create_tts_service( + user_config, + audio_config, + correlation_id=mps_correlation_id, + ) + llm = create_llm_service(user_config, correlation_id=mps_correlation_id) inference_llm = None # Stamp the providers/models actually resolved for this run onto @@ -695,7 +718,10 @@ async def _run_pipeline( # Create a separate LLM instance for the voicemail sub-pipeline # (can't share with main pipeline as it would mess up frame linking) if voicemail_config.get("use_workflow_llm", True): - voicemail_llm = create_llm_service(user_config) + voicemail_llm = create_llm_service( + user_config, + correlation_id=mps_correlation_id, + ) else: voicemail_llm = create_llm_service_from_provider( provider=voicemail_config.get("provider", "openai"), diff --git a/api/services/pipecat/service_factory.py b/api/services/pipecat/service_factory.py index 8ed96e40..ec5e9911 100644 --- a/api/services/pipecat/service_factory.py +++ b/api/services/pipecat/service_factory.py @@ -78,7 +78,10 @@ def _validate_runtime_service_url(url: str, field_name: str) -> None: def create_stt_service( - user_config, audio_config: "AudioConfig", keyterms: list[str] | None = None + user_config, + audio_config: "AudioConfig", + keyterms: list[str] | None = None, + correlation_id: str | None = None, ): """Create and return appropriate STT service based on user configuration @@ -160,6 +163,7 @@ def create_stt_service( return DograhSTTService( base_url=base_url, api_key=user_config.stt.api_key, + correlation_id=correlation_id, settings=DograhSTTSettings( model=user_config.stt.model, language=language, @@ -286,7 +290,9 @@ def create_stt_service( ) -def create_tts_service(user_config, audio_config: "AudioConfig"): +def create_tts_service( + user_config, audio_config: "AudioConfig", correlation_id: str | None = None +): """Create and return appropriate TTS service based on user configuration Args: @@ -404,6 +410,7 @@ def create_tts_service(user_config, audio_config: "AudioConfig"): return DograhTTSService( base_url=base_url, api_key=user_config.tts.api_key, + correlation_id=correlation_id, settings=DograhTTSSettings( model=user_config.tts.model, voice=user_config.tts.voice, @@ -564,6 +571,7 @@ def create_llm_service_from_provider( model: str, api_key: str | None, *, + correlation_id: str | None = None, base_url: str | None = None, endpoint: str | None = None, aws_access_key: str | None = None, @@ -637,6 +645,7 @@ def create_llm_service_from_provider( return DograhLLMService( base_url=f"{MPS_API_URL}/api/v1/llm", api_key=api_key, + correlation_id=correlation_id, settings=OpenAILLMSettings(model=model), ) elif provider == ServiceProviders.AWS_BEDROCK.value: @@ -851,7 +860,7 @@ def create_realtime_llm_service(user_config, audio_config: "AudioConfig"): ) -def create_llm_service(user_config): +def create_llm_service(user_config, correlation_id: str | None = None): """Create and return appropriate LLM service based on user configuration.""" provider = user_config.llm.provider model = user_config.llm.model @@ -880,4 +889,10 @@ def create_llm_service(user_config): elif provider == ServiceProviders.SARVAM.value: kwargs["temperature"] = user_config.llm.temperature - return create_llm_service_from_provider(provider, model, api_key, **kwargs) + return create_llm_service_from_provider( + provider, + model, + api_key, + correlation_id=correlation_id, + **kwargs, + ) diff --git a/api/services/pricing/README.md b/api/services/pricing/README.md deleted file mode 100644 index 4f834c28..00000000 --- a/api/services/pricing/README.md +++ /dev/null @@ -1,76 +0,0 @@ -# Pricing Module - -This module contains pricing models and registries for different AI services used in workflow cost calculations. - -## Structure - -``` -pricing/ -├── __init__.py # Main module exports -├── models.py # Base pricing model classes -├── llm.py # LLM pricing configurations -├── tts.py # TTS pricing configurations -├── stt.py # STT pricing configurations -├── registry.py # Combined pricing registry -└── README.md # This file -``` - -## Pricing Models - -### TokenPricingModel -Used for LLM services that charge based on tokens: -- `prompt_token_price`: Cost per prompt token -- `completion_token_price`: Cost per completion token -- `cache_read_discount`: Discount for cache read tokens (default 50%) -- `cache_creation_multiplier`: Premium for cache creation tokens (default 25%) - -### CharacterPricingModel -Used for TTS services that charge based on character count: -- `character_price`: Cost per character - -### TimePricingModel -Used for STT services that charge based on time: -- `second_price`: Cost per second - -## Adding New Pricing - -### Adding a New LLM Model -Edit `llm.py` and add the model to the appropriate provider: - -```python -ServiceProviders.OPENAI: { - "new-model": TokenPricingModel( - prompt_token_price=Decimal("2.00") / 1000000, - completion_token_price=Decimal("8.00") / 1000000, - ), - # ... existing models -} -``` - -### Adding a New Provider -1. Add pricing configurations to the appropriate service file (llm.py, tts.py, stt.py) -2. The registry will automatically include them - -### Adding a New Service Type -1. Create a new pricing file (e.g., `image.py`) -2. Define the pricing models -3. Import and add to `registry.py` - -## Usage - -The pricing registry is automatically imported and used by the cost calculator: - -```python -from api.services.pricing import PRICING_REGISTRY -from api.services.workflow.cost_calculator import cost_calculator - -# The cost calculator uses the pricing registry automatically -result = cost_calculator.calculate_total_cost(usage_info) -``` - -## Maintenance - -- Update pricing when providers change their rates -- All prices should use `Decimal` for precision -- Include comments with current pricing from provider documentation -- Test changes with existing test suite \ No newline at end of file diff --git a/api/services/pricing/__init__.py b/api/services/pricing/__init__.py deleted file mode 100644 index 1fa0eedf..00000000 --- a/api/services/pricing/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -""" -Pricing module for workflow cost calculation. - -This module contains pricing models and registries for different AI services. -""" - -from .registry import PRICING_REGISTRY - -__all__ = ["PRICING_REGISTRY"] diff --git a/api/services/pricing/cost_calculator.py b/api/services/pricing/cost_calculator.py deleted file mode 100644 index 14344752..00000000 --- a/api/services/pricing/cost_calculator.py +++ /dev/null @@ -1,228 +0,0 @@ -""" -Cost Calculator for Workflow Runs - -This module provides a comprehensive cost calculation system for workflow runs based on usage metrics -from different AI service providers (OpenAI, Groq, Deepgram, etc.). - -Features: -- Token-based pricing for LLM services with cache optimization support -- Character-based pricing for TTS services -- Time-based pricing for STT services -- Configurable pricing models that can be updated -- Support for multiple providers and models -- Automatic provider inference from model names -- JSON serialization support for database storage - -Usage: - from api.tasks.cost_calculator import cost_calculator - - usage_info = { - "llm": { - "processor_name|||gpt-4o": { - "prompt_tokens": 1000, - "completion_tokens": 500, - "total_tokens": 1500, - "cache_read_input_tokens": 0, - "cache_creation_input_tokens": 0 - } - }, - "tts": { - "processor_name|||aura-2-helena-en": 2000 # character count - } - } - - cost_breakdown = cost_calculator.calculate_total_cost(usage_info) - print(f"Total cost: ${cost_breakdown['total']:.6f}") -""" - -from decimal import Decimal -from typing import Any, Dict, Optional, Tuple - -from api.services.configuration.registry import ServiceProviders -from api.services.pricing import PRICING_REGISTRY -from api.services.pricing.models import ( - PricingModel, -) - - -class CostCalculator: - """Main cost calculator class""" - - def __init__(self, pricing_registry: Dict = None): - self.pricing_registry = pricing_registry or PRICING_REGISTRY - - def get_pricing_model( - self, service_type: str, provider: str, model: str - ) -> Optional[PricingModel]: - """Get pricing model for a specific service, provider, and model""" - try: - service_pricing = self.pricing_registry.get(service_type, {}) - - # Try to get pricing for the specific provider - provider_pricing = service_pricing.get(provider, {}) - pricing_model = provider_pricing.get(model) or provider_pricing.get( - "default" - ) - - if pricing_model: - return pricing_model - - # If not found, try the "default" provider for this service type - default_provider_pricing = service_pricing.get("default", {}) - return default_provider_pricing.get(model) or default_provider_pricing.get( - "default" - ) - - except (KeyError, AttributeError): - return None - - def calculate_llm_cost( - self, provider: str, model: str, usage: Dict[str, int] - ) -> Decimal: - """Calculate cost for LLM usage""" - pricing_model = self.get_pricing_model("llm", provider, model) - if not pricing_model: - return Decimal("0") - return pricing_model.calculate_cost(usage) - - def calculate_tts_cost( - self, provider: str, model: str, character_count: int - ) -> Decimal: - """Calculate cost for TTS usage""" - pricing_model = self.get_pricing_model("tts", provider, model) - if not pricing_model: - return Decimal("0") - return pricing_model.calculate_cost(character_count) - - def calculate_stt_cost(self, provider: str, model: str, seconds: float) -> Decimal: - """Calculate cost for STT usage""" - pricing_model = self.get_pricing_model("stt", provider, model) - if not pricing_model: - return Decimal("0") - return pricing_model.calculate_cost(seconds) - - def calculate_total_cost(self, usage_info: Dict) -> Dict[str, Any]: - llm_cost_total = Decimal("0") - tts_cost_total = Decimal("0") - stt_cost_total = Decimal("0") - - # Calculate LLM costs - llm_usage = usage_info.get("llm", {}) - for key, usage in llm_usage.items(): - processor, model = self._parse_key(key) - # Try to determine provider from processor name or model - provider = self._infer_provider_from_model(model, "llm") - cost = self.calculate_llm_cost(provider, model, usage) - llm_cost_total += cost - - # Calculate TTS costs - tts_usage = usage_info.get("tts", {}) - for key, character_count in tts_usage.items(): - processor, model = self._parse_key(key) - # Handle the case where model is "None" - infer from processor - if model.lower() in ["none", "null", ""]: - provider = self._infer_provider_from_processor(processor, "tts") - model = "default" # Use default model for the provider - else: - provider = self._infer_provider_from_model(model, "tts") - cost = self.calculate_tts_cost(provider, model, character_count) - tts_cost_total += cost - - # Calculate STT costs from explicit stt usage - stt_usage = usage_info.get("stt", {}) - for key, seconds in stt_usage.items(): - processor, model = self._parse_key(key) - provider = self._infer_provider_from_model(model, "stt") - cost = self.calculate_stt_cost(provider, model, seconds) - stt_cost_total += cost - - total_cost = llm_cost_total + tts_cost_total + stt_cost_total - - return { - "llm_cost": float(llm_cost_total), - "tts_cost": float(tts_cost_total), - "stt_cost": float(stt_cost_total), - "total": float(total_cost), - } - - def _parse_key(self, key) -> Tuple[str, str]: - """Parse key which is in format 'processor|||model'""" - if isinstance(key, str) and "|||" in key: - parts = key.split("|||", 1) - return parts[0], parts[1] - else: - # Fallback for backwards compatibility or malformed keys - return str(key), "unknown" - - def _infer_provider_from_model(self, model: str, service_type: str) -> str: - """Infer provider from model name""" - if not model: - return "unknown" - - model_lower = model.lower() - - # OpenAI models - if any(keyword in model_lower for keyword in ["gpt", "whisper", "openai"]): - return ServiceProviders.OPENAI - - # Groq models - if any(keyword in model_lower for keyword in ["groq"]): - return ServiceProviders.GROQ - - # Elevenlabs models - if any(keyword in model_lower for keyword in ["eleven"]): - return ServiceProviders.ELEVENLABS - - # Deepgram models - if any( - keyword in model_lower - for keyword in ["deepgram", "nova", "phonecall", "general"] - ): - return ServiceProviders.DEEPGRAM - - # Default to first available provider for the service type - service_providers = self.pricing_registry.get(service_type, {}) - if service_providers: - return list(service_providers.keys())[0] - - return "unknown" - - def _infer_provider_from_processor(self, processor: str, service_type: str) -> str: - """Infer provider from processor name""" - if not processor: - return "unknown" - - processor_lower = processor.lower() - - # OpenAI processors - if any(keyword in processor_lower for keyword in ["openai", "gpt"]): - return ServiceProviders.OPENAI - - # Groq processors - if any(keyword in processor_lower for keyword in ["groq"]): - return ServiceProviders.GROQ - - # Deepgram processors - if any(keyword in processor_lower for keyword in ["deepgram"]): - return ServiceProviders.DEEPGRAM - - # Default to first available provider for the service type - service_providers = self.pricing_registry.get(service_type, {}) - if service_providers: - return list(service_providers.keys())[0] - - return "unknown" - - def update_pricing( - self, service_type: str, provider: str, model: str, pricing_model: PricingModel - ): - """Update pricing for a specific service/provider/model combination""" - if service_type not in self.pricing_registry: - self.pricing_registry[service_type] = {} - if provider not in self.pricing_registry[service_type]: - self.pricing_registry[service_type][provider] = {} - self.pricing_registry[service_type][provider][model] = pricing_model - - -# Global cost calculator instance -cost_calculator = CostCalculator() diff --git a/api/services/pricing/embeddings.py b/api/services/pricing/embeddings.py deleted file mode 100644 index a58a8caa..00000000 --- a/api/services/pricing/embeddings.py +++ /dev/null @@ -1,44 +0,0 @@ -""" -Embeddings pricing models for different providers. - -Prices are per token for embedding models. -""" - -from decimal import Decimal -from typing import Dict - -from api.services.configuration.registry import ServiceProviders - -from .models import PricingModel - - -class EmbeddingPricingModel(PricingModel): - """Pricing model for token-based embedding services.""" - - def __init__(self, token_price: Decimal): - """Initialize with price per token. - - Args: - token_price: Cost per token for embedding - """ - self.token_price = token_price - - def calculate_cost(self, token_count: int) -> Decimal: - """Calculate cost for embedding token usage.""" - return Decimal(token_count) * self.token_price - - -# Embeddings pricing registry -EMBEDDINGS_PRICING: Dict[str, Dict[str, EmbeddingPricingModel]] = { - ServiceProviders.OPENAI: { - "text-embedding-3-small": EmbeddingPricingModel( - token_price=Decimal("0.02") / 1_000_000, # $0.02 per 1M tokens - ), - "text-embedding-3-large": EmbeddingPricingModel( - token_price=Decimal("0.13") / 1_000_000, # $0.13 per 1M tokens - ), - "text-embedding-ada-002": EmbeddingPricingModel( - token_price=Decimal("0.10") / 1_000_000, # $0.10 per 1M tokens (legacy) - ), - }, -} diff --git a/api/services/pricing/llm.py b/api/services/pricing/llm.py deleted file mode 100644 index addb59bc..00000000 --- a/api/services/pricing/llm.py +++ /dev/null @@ -1,143 +0,0 @@ -""" -LLM pricing models for different providers. - -Prices are per 1000 tokens for most models, with some newer models priced per million tokens. -""" - -from decimal import Decimal -from typing import Dict - -from api.services.configuration.registry import ServiceProviders - -from .models import TokenPricingModel - -# LLM pricing registry -LLM_PRICING: Dict[str, Dict[str, TokenPricingModel]] = { - ServiceProviders.OPENAI: { - "gpt-3.5-turbo": TokenPricingModel( - prompt_token_price=Decimal("0.0015") / 1000, # $0.0015 per 1K tokens - completion_token_price=Decimal("0.002") / 1000, # $0.002 per 1K tokens - ), - "gpt-4": TokenPricingModel( - prompt_token_price=Decimal("0.03") / 1000, # $0.03 per 1K tokens - completion_token_price=Decimal("0.06") / 1000, # $0.06 per 1K tokens - ), - "gpt-4.1": TokenPricingModel( - prompt_token_price=Decimal("2.00") / 1000000, # $2.00 per 1M tokens - completion_token_price=Decimal("8.00") / 1000000, # $8.00 per 1M tokens - ), - "gpt-4.1-mini": TokenPricingModel( - prompt_token_price=Decimal("0.40") / 1000000, # $0.40 per 1M tokens - completion_token_price=Decimal("1.60") / 1000000, # $1.60 per 1M tokens - ), - "gpt-4.1-nano": TokenPricingModel( - prompt_token_price=Decimal("0.10") / 1000000, # $0.10 per 1M tokens - completion_token_price=Decimal("0.40") / 1000000, # $0.40 per 1M tokens - ), - "gpt-4.5-preview": TokenPricingModel( - prompt_token_price=Decimal("75.00") / 1000000, # $75.00 per 1M tokens - completion_token_price=Decimal("150.00") / 1000000, # $150.00 per 1M tokens - ), - "gpt-4o": TokenPricingModel( - prompt_token_price=Decimal("2.50") / 1000000, # $2.50 per 1M tokens - FIXED - completion_token_price=Decimal("10.00") - / 1000000, # $10.00 per 1M tokens - FIXED - ), - "gpt-4o-audio-preview": TokenPricingModel( - prompt_token_price=Decimal("2.50") / 1000000, # $2.50 per 1M tokens - completion_token_price=Decimal("10.00") / 1000000, # $10.00 per 1M tokens - ), - "gpt-4o-realtime-preview": TokenPricingModel( - prompt_token_price=Decimal("5.00") / 1000000, # $5.00 per 1M tokens - completion_token_price=Decimal("20.00") / 1000000, # $20.00 per 1M tokens - ), - "gpt-4o-mini": TokenPricingModel( - prompt_token_price=Decimal("0.15") / 1000000, # $0.15 per 1M tokens - completion_token_price=Decimal("0.60") / 1000000, # $0.60 per 1M tokens - ), - "gpt-4o-mini-audio-preview": TokenPricingModel( - prompt_token_price=Decimal("0.15") / 1000000, # $0.15 per 1M tokens - completion_token_price=Decimal("0.60") / 1000000, # $0.60 per 1M tokens - ), - "gpt-4o-mini-realtime-preview": TokenPricingModel( - prompt_token_price=Decimal("0.60") / 1000000, # $0.60 per 1M tokens - completion_token_price=Decimal("2.40") / 1000000, # $2.40 per 1M tokens - ), - "gpt-4o-search-preview": TokenPricingModel( - prompt_token_price=Decimal("2.50") / 1000000, # $2.50 per 1M tokens - completion_token_price=Decimal("10.00") / 1000000, # $10.00 per 1M tokens - ), - "gpt-4o-mini-search-preview": TokenPricingModel( - prompt_token_price=Decimal("0.15") / 1000000, # $0.15 per 1M tokens - completion_token_price=Decimal("0.60") / 1000000, # $0.60 per 1M tokens - ), - "o1": TokenPricingModel( - prompt_token_price=Decimal("15.00") / 1000000, # $15.00 per 1M tokens - completion_token_price=Decimal("60.00") / 1000000, # $60.00 per 1M tokens - ), - "o1-pro": TokenPricingModel( - prompt_token_price=Decimal("150.00") / 1000000, # $150.00 per 1M tokens - completion_token_price=Decimal("600.00") / 1000000, # $600.00 per 1M tokens - ), - "o1-mini": TokenPricingModel( - prompt_token_price=Decimal("1.10") / 1000000, # $1.10 per 1M tokens - completion_token_price=Decimal("4.40") / 1000000, # $4.40 per 1M tokens - ), - "o3": TokenPricingModel( - prompt_token_price=Decimal("10.00") / 1000000, # $10.00 per 1M tokens - completion_token_price=Decimal("40.00") / 1000000, # $40.00 per 1M tokens - ), - "o3-mini": TokenPricingModel( - prompt_token_price=Decimal("1.10") / 1000000, # $1.10 per 1M tokens - completion_token_price=Decimal("4.40") / 1000000, # $4.40 per 1M tokens - ), - "o4-mini": TokenPricingModel( - prompt_token_price=Decimal("1.10") / 1000000, # $1.10 per 1M tokens - completion_token_price=Decimal("4.40") / 1000000, # $4.40 per 1M tokens - ), - "computer-use-preview": TokenPricingModel( - prompt_token_price=Decimal("3.00") / 1000000, # $3.00 per 1M tokens - completion_token_price=Decimal("12.00") / 1000000, # $12.00 per 1M tokens - ), - "gpt-image-1": TokenPricingModel( - prompt_token_price=Decimal("5.00") / 1000000, # $5.00 per 1M tokens - completion_token_price=Decimal("0") / 1000000, # No output pricing shown - ), - "codex-mini-latest": TokenPricingModel( - prompt_token_price=Decimal("1.50") / 1000000, # $1.50 per 1M tokens - completion_token_price=Decimal("6.00") / 1000000, # $6.00 per 1M tokens - ), - # Transcription models - "gpt-4o-transcribe": TokenPricingModel( - prompt_token_price=Decimal("2.50") / 1000000, # $2.50 per 1M tokens - completion_token_price=Decimal("10.00") / 1000000, # $10.00 per 1M tokens - ), - "gpt-4o-mini-transcribe": TokenPricingModel( - prompt_token_price=Decimal("1.25") / 1000000, # $1.25 per 1M tokens - completion_token_price=Decimal("5.00") / 1000000, # $5.00 per 1M tokens - ), - # TTS models with token-based pricing - "gpt-4o-mini-tts": TokenPricingModel( - prompt_token_price=Decimal("0.60") / 1000000, # $0.60 per 1M tokens - completion_token_price=Decimal("0") - / 1000000, # No completion tokens for TTS - ), - }, - ServiceProviders.GROQ: { - "llama-3.3-70b-versatile": TokenPricingModel( - prompt_token_price=Decimal("0.00059") / 1000, # $0.00059 per 1K tokens - completion_token_price=Decimal("0.00079") / 1000, # $0.00079 per 1K tokens - ), - "deepseek-r1-distill-llama-70b": TokenPricingModel( - prompt_token_price=Decimal("0.00059") / 1000, # Assuming similar pricing - completion_token_price=Decimal("0.00079") / 1000, - ), - }, - ServiceProviders.AZURE: { - "gpt-4.1-mini": TokenPricingModel( - prompt_token_price=Decimal("0.44") / 1000000, # $0.40 per 1M tokens - completion_token_price=Decimal("8.80") - / 1000000, # $1.60 per 1M tokens if using data zone - ) - }, -} diff --git a/api/services/pricing/models.py b/api/services/pricing/models.py deleted file mode 100644 index 58e197ac..00000000 --- a/api/services/pricing/models.py +++ /dev/null @@ -1,89 +0,0 @@ -""" -Base pricing models for different service types. -""" - -from decimal import Decimal -from enum import Enum -from typing import Any, Dict - - -class CostType(Enum): - LLM_TOKENS = "llm_tokens" - TTS_CHARACTERS = "tts_characters" - STT_SECONDS = "stt_seconds" - - -class PricingModel: - """Base class for pricing models""" - - def calculate_cost(self, usage: Any) -> Decimal: - """Calculate cost based on usage""" - raise NotImplementedError - - -class TokenPricingModel(PricingModel): - """Pricing model for token-based services (LLM)""" - - def __init__( - self, - prompt_token_price: Decimal, - completion_token_price: Decimal, - cache_read_discount: Decimal = Decimal("0.5"), # 50% discount for cache reads - cache_creation_multiplier: Decimal = Decimal( - "1.25" - ), # 25% premium for cache creation - ): - self.prompt_token_price = prompt_token_price - self.completion_token_price = completion_token_price - self.cache_read_discount = cache_read_discount - self.cache_creation_multiplier = cache_creation_multiplier - - def calculate_cost(self, usage: Dict[str, int]) -> Decimal: - """Calculate cost for LLM token usage""" - prompt_tokens = usage.get("prompt_tokens", 0) - completion_tokens = usage.get("completion_tokens", 0) - cache_read_tokens = usage.get("cache_read_input_tokens") or 0 - cache_creation_tokens = usage.get("cache_creation_input_tokens") or 0 - - # Base cost - prompt_cost = Decimal(prompt_tokens) * self.prompt_token_price - completion_cost = Decimal(completion_tokens) * self.completion_token_price - - # Cache adjustments - cache_read_savings = ( - Decimal(cache_read_tokens) - * self.prompt_token_price - * self.cache_read_discount - ) - cache_creation_premium = ( - Decimal(cache_creation_tokens) - * self.prompt_token_price - * (self.cache_creation_multiplier - 1) - ) - - total_cost = ( - prompt_cost + completion_cost - cache_read_savings + cache_creation_premium - ) - return max(total_cost, Decimal("0")) # Ensure non-negative - - -class CharacterPricingModel(PricingModel): - """Pricing model for character-based services (TTS)""" - - def __init__(self, character_price: Decimal): - self.character_price = character_price - - def calculate_cost(self, character_count: int) -> Decimal: - """Calculate cost for TTS character usage""" - return Decimal(character_count) * self.character_price - - -class TimePricingModel(PricingModel): - """Pricing model for time-based services (STT)""" - - def __init__(self, second_price: Decimal): - self.second_price = second_price - - def calculate_cost(self, seconds: float) -> Decimal: - """Calculate cost for STT time usage""" - return Decimal(str(seconds)) * self.second_price diff --git a/api/services/pricing/registry.py b/api/services/pricing/registry.py deleted file mode 100644 index 294a94a2..00000000 --- a/api/services/pricing/registry.py +++ /dev/null @@ -1,18 +0,0 @@ -""" -Main pricing registry that combines all service type pricing models. -""" - -from typing import Dict - -from .embeddings import EMBEDDINGS_PRICING -from .llm import LLM_PRICING -from .stt import STT_PRICING -from .tts import TTS_PRICING - -# Combined pricing registry for all service types -PRICING_REGISTRY: Dict = { - "llm": LLM_PRICING, - "tts": TTS_PRICING, - "stt": STT_PRICING, - "embeddings": EMBEDDINGS_PRICING, -} diff --git a/api/services/pricing/run_usage_response.py b/api/services/pricing/run_usage_response.py deleted file mode 100644 index a1f85a47..00000000 --- a/api/services/pricing/run_usage_response.py +++ /dev/null @@ -1,13 +0,0 @@ -"""Format workflow run usage for public API responses.""" - - -def format_public_usage_info(usage_info: dict | None) -> dict | None: - if not usage_info: - return None - - return { - "llm": usage_info.get("llm") or {}, - "tts": usage_info.get("tts") or {}, - "stt": usage_info.get("stt") or {}, - "call_duration_seconds": usage_info.get("call_duration_seconds"), - } diff --git a/api/services/pricing/stt.py b/api/services/pricing/stt.py deleted file mode 100644 index ca00ff4c..00000000 --- a/api/services/pricing/stt.py +++ /dev/null @@ -1,26 +0,0 @@ -""" -STT (Speech-to-Text) pricing models for different providers. - -Prices are per second for STT services. -""" - -from decimal import Decimal -from typing import Dict - -from api.services.configuration.registry import ServiceProviders - -from .models import TimePricingModel - -# STT pricing registry -STT_PRICING: Dict[str, Dict[str, TimePricingModel]] = { - ServiceProviders.DEEPGRAM: { - "nova-3-general": TimePricingModel(Decimal("0.0077") / 60), - "nova-2": TimePricingModel(Decimal("0.0058") / 60), - "default": TimePricingModel(Decimal("0.0077") / 60), - }, - ServiceProviders.OPENAI: { - "gpt-4o-transcribe": TimePricingModel(Decimal("0.015") / 60), - "default": TimePricingModel(Decimal("0.015") / 60), - }, - "default": {"default": TimePricingModel(Decimal("0.0077") / 60)}, -} diff --git a/api/services/pricing/tts.py b/api/services/pricing/tts.py deleted file mode 100644 index 7485cc7f..00000000 --- a/api/services/pricing/tts.py +++ /dev/null @@ -1,30 +0,0 @@ -""" -TTS (Text-to-Speech) pricing models for different providers. - -Prices are per character for TTS services. -""" - -from decimal import Decimal -from typing import Dict - -from api.services.configuration.registry import ServiceProviders - -from .models import CharacterPricingModel - -# TTS pricing registry -TTS_PRICING: Dict[str, Dict[str, CharacterPricingModel]] = { - ServiceProviders.OPENAI: { - "gpt-4o-mini-tts": CharacterPricingModel(Decimal("0.6") / 1_00_00_000), - "default": CharacterPricingModel(Decimal("0.6") / 1_00_00_000), - }, - ServiceProviders.DEEPGRAM: { - "aura-2": CharacterPricingModel(Decimal("0.030") / 1_000), - "aura-1": CharacterPricingModel(Decimal("0.015") / 1_000), - "default": CharacterPricingModel(Decimal("0.030") / 1_000), - }, - ServiceProviders.ELEVENLABS: { - # 6400 usd per 250*1e6 characters - "default": CharacterPricingModel(Decimal("0.0256") / 1_000) - }, - "default": {"default": CharacterPricingModel(Decimal("0.030") / 1_000)}, -} diff --git a/api/services/pricing/workflow_run_cost.py b/api/services/pricing/workflow_run_cost.py deleted file mode 100644 index 6d6010c3..00000000 --- a/api/services/pricing/workflow_run_cost.py +++ /dev/null @@ -1,230 +0,0 @@ -from decimal import Decimal - -from loguru import logger - -from api.db import db_client -from api.enums import WorkflowRunMode -from api.services.pricing.cost_calculator import cost_calculator -from api.services.telephony.factory import get_telephony_provider_for_run - - -async def _fetch_telephony_cost(workflow_run) -> dict | None: - """Fetch telephony call cost. Returns a dict with cost_usd and provider_name, or None.""" - if ( - workflow_run.mode - not in [WorkflowRunMode.TWILIO.value, WorkflowRunMode.VONAGE.value] - or not workflow_run.cost_info - ): - return None - - call_id = workflow_run.cost_info.get("call_id") - if not call_id: - logger.warning(f"call_id not found in cost_info") - return None - - provider_name = workflow_run.mode.lower() if workflow_run.mode else "" - - workflow = await db_client.get_workflow_by_id(workflow_run.workflow_id) - if not workflow: - logger.warning("Workflow not found for workflow run") - raise Exception("Workflow not found") - - provider = await get_telephony_provider_for_run( - workflow_run, workflow.organization_id - ) - call_cost_info = await provider.get_call_cost(call_id) - - if call_cost_info.get("status") == "error": - logger.error( - f"Failed to fetch {provider_name} call cost: {call_cost_info.get('error')}" - ) - return None - - cost_usd = call_cost_info.get("cost_usd", 0.0) - logger.info( - f"{provider_name.title()} call cost: ${cost_usd:.6f} USD for call {call_id}" - ) - return {"cost_usd": cost_usd, "provider_name": provider_name} - - -async def _update_organization_usage( - org, dograh_tokens: float, duration_seconds: float, charge_usd: float | None -) -> None: - """Update organization usage after a workflow run.""" - org_id = org.id - await db_client.update_usage_after_run( - org_id, dograh_tokens, duration_seconds, charge_usd - ) - if charge_usd is not None: - logger.info( - f"Updated organization usage with ${charge_usd:.2f} USD ({dograh_tokens} Dograh Tokens) and {duration_seconds}s duration for org {org_id}" - ) - else: - logger.info( - f"Updated organization usage with {dograh_tokens} Dograh Tokens and {duration_seconds}s duration for org {org_id}" - ) - - -async def _get_pricing_organization(workflow_run): - workflow = getattr(workflow_run, "workflow", None) - organization_id = getattr(workflow, "organization_id", None) - if organization_id is None and workflow and workflow.user: - organization_id = workflow.user.selected_organization_id - if organization_id is None: - return None - return await db_client.get_organization_by_id(organization_id) - - -async def _build_usage_cost_snapshot( - usage_info: dict | None, - *, - workflow_run=None, - include_telephony_cost: bool = False, - organization=None, - calculated_at: str | None = None, -) -> dict | None: - if not usage_info: - logger.warning("No usage info available for workflow run") - return None - - cost_breakdown = cost_calculator.calculate_total_cost(usage_info) - - if include_telephony_cost and workflow_run is not None: - try: - telephony_cost = await _fetch_telephony_cost(workflow_run) - if telephony_cost: - telephony_cost_usd = telephony_cost["cost_usd"] - provider_name = telephony_cost["provider_name"] - cost_breakdown["telephony_call"] = telephony_cost_usd - cost_breakdown[f"{provider_name}_call"] = telephony_cost_usd - cost_breakdown["total"] = ( - float(cost_breakdown["total"]) + telephony_cost_usd - ) - except Exception as e: - logger.error(f"Failed to fetch telephony call cost: {e}") - # Don't fail the whole cost calculation if telephony API fails - - total_cost_usd = Decimal(str(cost_breakdown["total"])) - dograh_tokens = float(total_cost_usd * Decimal("100")) - - if organization is None and workflow_run is not None: - organization = await _get_pricing_organization(workflow_run) - - charge_usd = None - if organization and organization.price_per_second_usd: - duration_seconds = usage_info.get("call_duration_seconds", 0) - charge_usd = float( - Decimal(str(duration_seconds)) - * Decimal(str(organization.price_per_second_usd)) - ) - - cost_info = { - "cost_breakdown": cost_breakdown, - "total_cost_usd": float(total_cost_usd), - "dograh_token_usage": dograh_tokens, - "calculated_at": calculated_at - or (workflow_run.created_at.isoformat() if workflow_run is not None else None), - "call_duration_seconds": usage_info.get("call_duration_seconds", 0), - } - - if charge_usd is not None: - cost_info["charge_usd"] = charge_usd - cost_info["price_per_second_usd"] = organization.price_per_second_usd - - return cost_info - - -async def build_workflow_run_cost_info(workflow_run) -> dict | None: - cost_info = await _build_usage_cost_snapshot( - workflow_run.usage_info, - workflow_run=workflow_run, - include_telephony_cost=True, - calculated_at=workflow_run.created_at.isoformat(), - ) - if cost_info is None: - return None - return { - **(workflow_run.cost_info or {}), - **cost_info, - } - - -async def save_workflow_run_cost_info( - workflow_run_id: int, cost_info: dict | None -) -> None: - if cost_info is None: - return - await db_client.update_workflow_run(run_id=workflow_run_id, cost_info=cost_info) - - -async def apply_workflow_run_usage_to_organization( - workflow_run, cost_info: dict | None -) -> None: - if cost_info is None: - return - - org = await _get_pricing_organization(workflow_run) - if not org: - return - - await _update_organization_usage( - org, - float(cost_info.get("dograh_token_usage") or 0), - float(cost_info.get("call_duration_seconds") or 0), - cost_info.get("charge_usd"), - ) - - -async def apply_usage_delta_to_organization( - workflow_run, usage_info: dict | None -) -> dict | None: - org = await _get_pricing_organization(workflow_run) - if not org: - return None - - cost_info = await _build_usage_cost_snapshot(usage_info, organization=org) - if cost_info is None: - return None - - await _update_organization_usage( - org, - float(cost_info.get("dograh_token_usage") or 0), - float(cost_info.get("call_duration_seconds") or 0), - cost_info.get("charge_usd"), - ) - return cost_info - - -async def calculate_workflow_run_cost(workflow_run_id: int): - logger.debug("Calculating cost for workflow run") - - workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id) - if not workflow_run: - logger.warning("Workflow run not found") - return - - try: - cost_info = await build_workflow_run_cost_info(workflow_run) - if cost_info is None: - return - - await save_workflow_run_cost_info(workflow_run_id, cost_info) - - try: - await apply_workflow_run_usage_to_organization(workflow_run, cost_info) - except Exception as e: - org = await _get_pricing_organization(workflow_run) - if org: - logger.error( - f"Failed to update organization usage for org {org.id}: {e}" - ) - else: - logger.error(f"Failed to update organization usage: {e}") - # Don't fail the whole cost calculation if usage update fails - - logger.info( - f"Calculated cost for workflow run: ${cost_info['total_cost_usd']:.6f} USD ({cost_info['dograh_token_usage']} Dograh Tokens)" - ) - except Exception as e: - logger.error(f"Error calculating cost for workflow run: {e}") - raise diff --git a/api/services/reports/run_report.py b/api/services/reports/run_report.py index b84a6f96..a5e64819 100644 --- a/api/services/reports/run_report.py +++ b/api/services/reports/run_report.py @@ -53,7 +53,7 @@ def build_run_report_csv(runs: List[Any]) -> io.StringIO: for run in runs: initial = run.initial_context or {} gathered = run.gathered_context or {} - cost = run.cost_info or {} + usage = run.usage_info or {} call_tags = gathered.get("call_tags", []) if isinstance(call_tags, list): @@ -67,7 +67,7 @@ def build_run_report_csv(runs: List[Any]) -> io.StringIO: run.created_at.isoformat() if run.created_at else "", initial.get("phone_number", ""), gathered.get("mapped_call_disposition", ""), - cost.get("call_duration_seconds", ""), + usage.get("call_duration_seconds", ""), ] extracted = gathered.get("extracted_variables", {}) diff --git a/api/services/telephony/providers/vonage/routes.py b/api/services/telephony/providers/vonage/routes.py index a4cca35d..c862e745 100644 --- a/api/services/telephony/providers/vonage/routes.py +++ b/api/services/telephony/providers/vonage/routes.py @@ -66,34 +66,6 @@ async def handle_vonage_events( logger.error(f"[run {workflow_run_id}] Workflow run not found") return {"status": "error", "message": "Workflow run not found"} - # For a completed call that includes cost info, capture it immediately - if event_data.get("status") == "completed": - # Vonage sometimes includes price info in the webhook - if "price" in event_data or "rate" in event_data: - try: - if workflow_run.cost_info: - # Store immediate cost info if available - cost_info = workflow_run.cost_info.copy() - if "price" in event_data: - cost_info["vonage_webhook_price"] = float(event_data["price"]) - if "rate" in event_data: - cost_info["vonage_webhook_rate"] = float(event_data["rate"]) - if "duration" in event_data: - cost_info["vonage_webhook_duration"] = int( - event_data["duration"] - ) - - await db_client.update_workflow_run( - run_id=workflow_run_id, cost_info=cost_info - ) - logger.info( - f"[run {workflow_run_id}] Captured Vonage cost info from webhook" - ) - except Exception as e: - logger.error( - f"[run {workflow_run_id}] Failed to capture Vonage cost from webhook: {e}" - ) - # Get workflow and provider workflow = await db_client.get_workflow_by_id(workflow_run.workflow_id) if not workflow: diff --git a/api/services/workflow/pipecat_engine.py b/api/services/workflow/pipecat_engine.py index cea1d21f..a0d67947 100644 --- a/api/services/workflow/pipecat_engine.py +++ b/api/services/workflow/pipecat_engine.py @@ -35,6 +35,7 @@ import asyncio from loguru import logger +from api.services.managed_model_services import MPS_CORRELATION_ID_CONTEXT_KEY from api.services.workflow import pipecat_engine_callbacks as engine_callbacks from api.services.workflow.mcp_tool_session import McpToolSession from api.services.workflow.pipecat_engine_context_composer import ( @@ -382,6 +383,9 @@ class PipecatEngine: embeddings_provider=self._embeddings_provider, embeddings_endpoint=self._embeddings_endpoint, embeddings_api_version=self._embeddings_api_version, + correlation_id=self._call_context_vars.get( + MPS_CORRELATION_ID_CONTEXT_KEY + ), tracing_context=self._get_otel_context(), ) diff --git a/api/services/workflow/run_usage_response.py b/api/services/workflow/run_usage_response.py new file mode 100644 index 00000000..c289e565 --- /dev/null +++ b/api/services/workflow/run_usage_response.py @@ -0,0 +1,41 @@ +"""Format workflow run usage for public API responses.""" + + +def format_public_usage_info(usage_info: dict | None) -> dict | None: + if not usage_info: + return None + + return { + "llm": usage_info.get("llm") or {}, + "tts": usage_info.get("tts") or {}, + "stt": usage_info.get("stt") or {}, + "call_duration_seconds": usage_info.get("call_duration_seconds"), + } + + +def format_public_cost_info( + cost_info: dict | None, usage_info: dict | None +) -> dict | None: + """Return the legacy response shape without doing local cost accounting.""" + duration = None + if usage_info and usage_info.get("call_duration_seconds") is not None: + duration = int(round(usage_info.get("call_duration_seconds") or 0)) + elif cost_info and cost_info.get("call_duration_seconds") is not None: + duration = int(round(cost_info.get("call_duration_seconds") or 0)) + + dograh_token_usage = 0 + if cost_info: + if "dograh_token_usage" in cost_info: + dograh_token_usage = cost_info.get("dograh_token_usage") or 0 + elif "total_cost_usd" in cost_info: + dograh_token_usage = round( + float(cost_info.get("total_cost_usd", 0)) * 100, 2 + ) + + if duration is None and dograh_token_usage == 0: + return None + + return { + "dograh_token_usage": dograh_token_usage, + "call_duration_seconds": duration, + } diff --git a/api/services/workflow/text_chat_runner.py b/api/services/workflow/text_chat_runner.py index 59073c80..7f6c5a0b 100644 --- a/api/services/workflow/text_chat_runner.py +++ b/api/services/workflow/text_chat_runner.py @@ -421,7 +421,19 @@ async def execute_text_chat_pending_turn( if user_config.llm is None: raise ValueError("Text chat requires an LLM configuration") - llm = create_llm_service(user_config) + from api.services.managed_model_services import ( + MPS_CORRELATION_ID_CONTEXT_KEY, + ensure_mps_correlation_id, + ) + + base_initial_context = dict(workflow_run.initial_context or {}) + mps_correlation_id = await ensure_mps_correlation_id( + ai_model_config=user_config, + workflow_run_id=workflow_run_id, + initial_context=base_initial_context, + ) + + llm = create_llm_service(user_config, correlation_id=mps_correlation_id) inference_llm = llm runtime_configuration = { @@ -429,9 +441,15 @@ async def execute_text_chat_pending_turn( "llm_model": user_config.llm.model, } initial_context = { - **(workflow_run.initial_context or {}), + **base_initial_context, "runtime_configuration": runtime_configuration, } + if mps_correlation_id: + initial_context[MPS_CORRELATION_ID_CONTEXT_KEY] = mps_correlation_id + await db_client.update_workflow_run( + workflow_run_id, + initial_context=initial_context, + ) workflow_graph = WorkflowGraph( ReactFlowDTO.model_validate(run_definition.workflow_json) diff --git a/api/services/workflow/text_chat_session_service.py b/api/services/workflow/text_chat_session_service.py index 53354d5f..81749960 100644 --- a/api/services/workflow/text_chat_session_service.py +++ b/api/services/workflow/text_chat_session_service.py @@ -4,17 +4,11 @@ from datetime import UTC, datetime from typing import Any from uuid import uuid4 -from loguru import logger - from api.db import db_client from api.db.models import WorkflowRunTextSessionModel from api.db.workflow_run_text_session_client import ( WorkflowRunTextSessionRevisionConflictError, ) -from api.services.pricing.workflow_run_cost import ( - apply_usage_delta_to_organization, - build_workflow_run_cost_info, -) from api.services.workflow.text_chat_logs import ( build_text_chat_realtime_feedback_events, ) @@ -261,20 +255,6 @@ async def execute_pending_text_chat_turn( state=execution.state, is_completed=execution.is_completed, ) - workflow_run = await db_client.get_workflow_run_by_id(run_id) - if workflow_run: - try: - # Apply the per-turn delta so org usage tracks cumulative run cost - # without replaying the full session totals on every turn. - await apply_usage_delta_to_organization(workflow_run, execution.usage) - except Exception as e: - logger.error( - f"Failed to update organization usage for text chat run {run_id}: {e}" - ) - - cost_info = await build_workflow_run_cost_info(workflow_run) - if cost_info is not None: - await db_client.update_workflow_run(run_id, cost_info=cost_info) return await _reload_text_chat_session(run_id) diff --git a/api/services/workflow/tools/knowledge_base.py b/api/services/workflow/tools/knowledge_base.py index 6ce8f8c7..7b93aea7 100644 --- a/api/services/workflow/tools/knowledge_base.py +++ b/api/services/workflow/tools/knowledge_base.py @@ -29,6 +29,7 @@ async def retrieve_from_knowledge_base( embeddings_provider: Optional[str] = None, embeddings_endpoint: Optional[str] = None, embeddings_api_version: Optional[str] = None, + correlation_id: Optional[str] = None, tracing_context=None, ) -> Dict[str, Any]: """Retrieve relevant information from the knowledge base using vector similarity search. @@ -75,6 +76,7 @@ async def retrieve_from_knowledge_base( embeddings_provider, embeddings_endpoint, embeddings_api_version, + correlation_id, ) # Create span with parent context @@ -115,6 +117,7 @@ async def retrieve_from_knowledge_base( embeddings_provider, embeddings_endpoint, embeddings_api_version, + correlation_id, ) # Add result metadata to span @@ -192,6 +195,7 @@ async def retrieve_from_knowledge_base( embeddings_provider, embeddings_endpoint, embeddings_api_version, + correlation_id, ) else: # Tracing is disabled - perform retrieval without tracing @@ -206,6 +210,7 @@ async def retrieve_from_knowledge_base( embeddings_provider, embeddings_endpoint, embeddings_api_version, + correlation_id, ) @@ -220,6 +225,7 @@ async def _perform_retrieval( embeddings_provider: Optional[str] = None, embeddings_endpoint: Optional[str] = None, embeddings_api_version: Optional[str] = None, + correlation_id: Optional[str] = None, ) -> Dict[str, Any]: """Internal function to perform the actual retrieval operation. @@ -272,11 +278,20 @@ async def _perform_retrieval( api_version=embeddings_api_version or "2024-02-15-preview", ) else: + default_headers = None + if ( + embeddings_provider == ServiceProviders.DOGRAH.value + and correlation_id + ): + default_headers = { + "X-Dograh-Correlation-Id": correlation_id, + } embedding_service = OpenAIEmbeddingService( db_client=db_client, api_key=embeddings_api_key, model_id=embeddings_model or "text-embedding-3-small", base_url=embeddings_base_url, + default_headers=default_headers, ) results = await embedding_service.search_similar_chunks( diff --git a/api/services/workflow_run_billing.py b/api/services/workflow_run_billing.py new file mode 100644 index 00000000..ab8a3121 --- /dev/null +++ b/api/services/workflow_run_billing.py @@ -0,0 +1,111 @@ +"""Workflow-run billing hooks. + +Dograh does not rate or deduct credits locally. MPS owns credit accounting. +For hosted deployments, Dograh reports completed platform usage to MPS. +When a server-minted MPS correlation id exists, MPS uses model-service usage +as the canonical duration. Otherwise Dograh reports the completed run duration. +""" + +from typing import Any + +from loguru import logger + +from api.constants import DEPLOYMENT_MODE +from api.db import db_client +from api.services.managed_model_services import get_mps_correlation_id +from api.services.mps_service_key_client import mps_service_key_client + + +def _workflow_run_organization_id(workflow_run) -> int | None: + workflow = getattr(workflow_run, "workflow", None) + return getattr(workflow, "organization_id", None) + + +def _duration_seconds_from_usage_info(workflow_run) -> float | None: + usage_info: dict[str, Any] = getattr(workflow_run, "usage_info", None) or {} + duration = usage_info.get("call_duration_seconds") + try: + duration_seconds = float(duration) + except (TypeError, ValueError): + return None + + return duration_seconds if duration_seconds > 0 else None + + +async def _organization_uses_mps_billing_v2(organization_id: int) -> bool: + account = await mps_service_key_client.get_billing_account_status( + organization_id=organization_id + ) + return bool(account and account.get("billing_mode") == "v2") + + +async def report_workflow_run_platform_usage(workflow_run) -> None: + """Report hosted platform usage for a completed workflow run to MPS.""" + if DEPLOYMENT_MODE == "oss": + return + + if not getattr(workflow_run, "is_completed", False): + return + + organization_id = _workflow_run_organization_id(workflow_run) + if organization_id is None: + logger.warning( + "Skipping platform usage report for workflow run {}: no organization_id", + workflow_run.id, + ) + return + + correlation_id = get_mps_correlation_id( + getattr(workflow_run, "initial_context", None) + ) + duration_seconds = ( + None if correlation_id else _duration_seconds_from_usage_info(workflow_run) + ) + if not correlation_id and duration_seconds is None: + logger.warning( + "Skipping platform usage report for workflow run {}: no billable duration", + workflow_run.id, + ) + return + + try: + if not await _organization_uses_mps_billing_v2(organization_id): + return + + result = await mps_service_key_client.report_platform_usage( + organization_id=organization_id, + correlation_id=correlation_id, + duration_seconds=duration_seconds, + workflow_run_id=workflow_run.id, + metadata={ + "source": "workflow_run_completion", + "workflow_id": getattr(workflow_run, "workflow_id", None), + "duration_source": ( + "mps_correlation" if correlation_id else "dograh_usage_info" + ), + }, + ) + logger.info( + "Reported platform usage for workflow run {} to MPS: {}", + workflow_run.id, + result, + ) + except Exception as e: + logger.error( + "Failed to report platform usage for workflow run {}: {}", + workflow_run.id, + e, + ) + + +async def report_completed_workflow_run_platform_usage(workflow_run_id: int) -> None: + """Load a completed workflow run and report platform usage to MPS.""" + workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id) + if not workflow_run: + logger.warning( + "Skipping platform usage report: workflow run {} not found", + workflow_run_id, + ) + return + + await report_workflow_run_platform_usage(workflow_run) diff --git a/api/tasks/arq.py b/api/tasks/arq.py index a948a578..442114e6 100644 --- a/api/tasks/arq.py +++ b/api/tasks/arq.py @@ -45,10 +45,8 @@ from api.tasks.campaign_tasks import ( ) from api.tasks.knowledge_base_processing import process_knowledge_base_document from api.tasks.run_integrations import run_integrations_post_workflow_run -from api.tasks.s3_upload import ( - process_workflow_completion, - upload_voicemail_audio_to_s3, -) +from api.tasks.s3_upload import upload_voicemail_audio_to_s3 +from api.tasks.workflow_completion import process_workflow_completion class WorkerSettings: diff --git a/api/tasks/knowledge_base_processing.py b/api/tasks/knowledge_base_processing.py index f496ac0e..a6ca0d6d 100644 --- a/api/tasks/knowledge_base_processing.py +++ b/api/tasks/knowledge_base_processing.py @@ -166,18 +166,22 @@ async def process_knowledge_base_document( user_id=document.created_by, organization_id=document.organization_id, ) - user_config = resolved_config.effective - if user_config.embeddings: - embeddings_provider = getattr(user_config.embeddings, "provider", None) - embeddings_api_key = user_config.embeddings.api_key - embeddings_model = user_config.embeddings.model + effective_config = resolved_config.effective + if effective_config.embeddings: + embeddings_provider = getattr( + effective_config.embeddings, "provider", None + ) + embeddings_api_key = effective_config.embeddings.api_key + embeddings_model = effective_config.embeddings.model embeddings_base_url = apply_managed_embeddings_base_url( provider=embeddings_provider, - base_url=getattr(user_config.embeddings, "base_url", None), + base_url=getattr(effective_config.embeddings, "base_url", None), + ) + embeddings_endpoint = getattr( + effective_config.embeddings, "endpoint", None ) - embeddings_endpoint = getattr(user_config.embeddings, "endpoint", None) embeddings_api_version = getattr( - user_config.embeddings, "api_version", None + effective_config.embeddings, "api_version", None ) logger.info( f"Using user embeddings config: provider={embeddings_provider}, " diff --git a/api/tasks/s3_upload.py b/api/tasks/s3_upload.py index b2086c09..bbbc8bf4 100644 --- a/api/tasks/s3_upload.py +++ b/api/tasks/s3_upload.py @@ -1,13 +1,9 @@ import os -from typing import Optional from loguru import logger from pipecat.utils.run_context import set_current_run_id -from api.db import db_client -from api.services.pricing.workflow_run_cost import calculate_workflow_run_cost -from api.services.storage import get_current_storage_backend, storage_fs -from api.tasks.run_integrations import run_integrations_post_workflow_run +from api.services.storage import storage_fs async def upload_voicemail_audio_to_s3( @@ -69,110 +65,3 @@ async def upload_voicemail_audio_to_s3( logger.warning( f"Failed to clean up temp voicemail audio file {temp_file_path}: {e}" ) - - -async def process_workflow_completion( - _ctx, - workflow_run_id: int, - audio_temp_path: Optional[str] = None, - transcript_temp_path: Optional[str] = None, -): - """Process workflow completion: upload artifacts and run integrations. - - This task combines audio upload, transcript upload, and webhook integrations - into a single sequential task to ensure integrations run after uploads complete. - - Args: - _ctx: ARQ context (unused) - workflow_run_id: The workflow run ID - audio_temp_path: Optional path to temp audio file - transcript_temp_path: Optional path to temp transcript file - """ - run_id = str(workflow_run_id) - set_current_run_id(run_id) - - logger.info(f"Processing workflow completion for run {workflow_run_id}") - - storage_backend = get_current_storage_backend() - - # Step 1: Upload audio if provided - if audio_temp_path: - try: - if os.path.exists(audio_temp_path): - file_size = os.path.getsize(audio_temp_path) - logger.debug(f"Audio file size: {file_size} bytes") - - recording_url = f"recordings/{workflow_run_id}.wav" - logger.info( - f"Uploading audio to {storage_backend.name} - workflow_run_id: {workflow_run_id}" - ) - - await storage_fs.aupload_file(audio_temp_path, recording_url) - await db_client.update_workflow_run( - run_id=workflow_run_id, - recording_url=recording_url, - storage_backend=storage_backend.value, - ) - logger.info(f"Successfully uploaded audio: {recording_url}") - else: - logger.warning(f"Audio temp file not found: {audio_temp_path}") - except Exception as e: - logger.error(f"Error uploading audio for workflow {workflow_run_id}: {e}") - finally: - if audio_temp_path and os.path.exists(audio_temp_path): - try: - os.remove(audio_temp_path) - logger.debug(f"Cleaned up temp audio file: {audio_temp_path}") - except Exception as e: - logger.warning(f"Failed to clean up temp audio file: {e}") - - # Step 2: Upload transcript if provided - if transcript_temp_path: - try: - if os.path.exists(transcript_temp_path): - file_size = os.path.getsize(transcript_temp_path) - logger.debug(f"Transcript file size: {file_size} bytes") - - transcript_url = f"transcripts/{workflow_run_id}.txt" - logger.info( - f"Uploading transcript to {storage_backend.name} - workflow_run_id: {workflow_run_id}" - ) - - await storage_fs.aupload_file(transcript_temp_path, transcript_url) - await db_client.update_workflow_run( - run_id=workflow_run_id, - transcript_url=transcript_url, - storage_backend=storage_backend.value, - ) - logger.info(f"Successfully uploaded transcript: {transcript_url}") - else: - logger.warning( - f"Transcript temp file not found: {transcript_temp_path}" - ) - except Exception as e: - logger.error( - f"Error uploading transcript for workflow {workflow_run_id}: {e}" - ) - finally: - if transcript_temp_path and os.path.exists(transcript_temp_path): - try: - os.remove(transcript_temp_path) - logger.debug( - f"Cleaned up temp transcript file: {transcript_temp_path}" - ) - except Exception as e: - logger.warning(f"Failed to clean up temp transcript file: {e}") - - # Step 3: Run integrations including QA analysis (after uploads are complete) - try: - await run_integrations_post_workflow_run(_ctx, workflow_run_id) - except Exception as e: - logger.error(f"Error running integrations for workflow {workflow_run_id}: {e}") - - # Step 4: Calculate cost after integrations (so QA token usage is included) - try: - await calculate_workflow_run_cost(workflow_run_id) - except Exception as e: - logger.error(f"Error calculating cost for workflow {workflow_run_id}: {e}") - - logger.info(f"Completed workflow completion processing for run {workflow_run_id}") diff --git a/api/tasks/workflow_completion.py b/api/tasks/workflow_completion.py new file mode 100644 index 00000000..ff0482d2 --- /dev/null +++ b/api/tasks/workflow_completion.py @@ -0,0 +1,121 @@ +import os +from typing import Optional + +from loguru import logger +from pipecat.utils.run_context import set_current_run_id + +from api.db import db_client +from api.services.storage import get_current_storage_backend, storage_fs +from api.services.workflow_run_billing import ( + report_completed_workflow_run_platform_usage, +) +from api.tasks.run_integrations import run_integrations_post_workflow_run + + +async def process_workflow_completion( + _ctx, + workflow_run_id: int, + audio_temp_path: Optional[str] = None, + transcript_temp_path: Optional[str] = None, +): + """Process workflow completion: upload artifacts and run integrations. + + This task combines audio upload, transcript upload, and webhook integrations + into a single sequential task to ensure integrations run after uploads complete. + + Args: + _ctx: ARQ context (unused) + workflow_run_id: The workflow run ID + audio_temp_path: Optional path to temp audio file + transcript_temp_path: Optional path to temp transcript file + """ + run_id = str(workflow_run_id) + set_current_run_id(run_id) + + logger.info(f"Processing workflow completion for run {workflow_run_id}") + + storage_backend = get_current_storage_backend() + + # Step 1: Upload audio if provided + if audio_temp_path: + try: + if os.path.exists(audio_temp_path): + file_size = os.path.getsize(audio_temp_path) + logger.debug(f"Audio file size: {file_size} bytes") + + recording_url = f"recordings/{workflow_run_id}.wav" + logger.info( + f"Uploading audio to {storage_backend.name} - workflow_run_id: {workflow_run_id}" + ) + + await storage_fs.aupload_file(audio_temp_path, recording_url) + await db_client.update_workflow_run( + run_id=workflow_run_id, + recording_url=recording_url, + storage_backend=storage_backend.value, + ) + logger.info(f"Successfully uploaded audio: {recording_url}") + else: + logger.warning(f"Audio temp file not found: {audio_temp_path}") + except Exception as e: + logger.error(f"Error uploading audio for workflow {workflow_run_id}: {e}") + finally: + if audio_temp_path and os.path.exists(audio_temp_path): + try: + os.remove(audio_temp_path) + logger.debug(f"Cleaned up temp audio file: {audio_temp_path}") + except Exception as e: + logger.warning(f"Failed to clean up temp audio file: {e}") + + # Step 2: Upload transcript if provided + if transcript_temp_path: + try: + if os.path.exists(transcript_temp_path): + file_size = os.path.getsize(transcript_temp_path) + logger.debug(f"Transcript file size: {file_size} bytes") + + transcript_url = f"transcripts/{workflow_run_id}.txt" + logger.info( + f"Uploading transcript to {storage_backend.name} - workflow_run_id: {workflow_run_id}" + ) + + await storage_fs.aupload_file(transcript_temp_path, transcript_url) + await db_client.update_workflow_run( + run_id=workflow_run_id, + transcript_url=transcript_url, + storage_backend=storage_backend.value, + ) + logger.info(f"Successfully uploaded transcript: {transcript_url}") + else: + logger.warning( + f"Transcript temp file not found: {transcript_temp_path}" + ) + except Exception as e: + logger.error( + f"Error uploading transcript for workflow {workflow_run_id}: {e}" + ) + finally: + if transcript_temp_path and os.path.exists(transcript_temp_path): + try: + os.remove(transcript_temp_path) + logger.debug( + f"Cleaned up temp transcript file: {transcript_temp_path}" + ) + except Exception as e: + logger.warning(f"Failed to clean up temp transcript file: {e}") + + # Step 3: Run integrations including QA analysis (after uploads are complete) + try: + await run_integrations_post_workflow_run(_ctx, workflow_run_id) + except Exception as e: + logger.error(f"Error running integrations for workflow {workflow_run_id}: {e}") + + # Step 4: Notify MPS after completion. MPS owns credit accounting. + try: + await report_completed_workflow_run_platform_usage(workflow_run_id) + except Exception as e: + logger.error( + f"Error reporting platform usage for workflow {workflow_run_id}: {e}" + ) + + logger.info(f"Completed workflow completion processing for run {workflow_run_id}") diff --git a/api/tests/integrations/_run_pipeline_helpers.py b/api/tests/integrations/_run_pipeline_helpers.py index 1a3251a0..58b4ffd2 100644 --- a/api/tests/integrations/_run_pipeline_helpers.py +++ b/api/tests/integrations/_run_pipeline_helpers.py @@ -203,7 +203,7 @@ async def create_workflow_run_rows( Returns: Tuple of (workflow_run, user, workflow). """ - from api.schemas.user_configuration import EffectiveAIModelConfiguration + from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration org = OrganizationModel(provider_id=f"test-org-{provider_id_suffix}") async_session.add(org) diff --git a/api/tests/test_ai_model_configuration_v2.py b/api/tests/test_ai_model_configuration_v2.py index 98f431e8..57f7cf83 100644 --- a/api/tests/test_ai_model_configuration_v2.py +++ b/api/tests/test_ai_model_configuration_v2.py @@ -1,12 +1,16 @@ +from types import SimpleNamespace +from unittest.mock import AsyncMock + import pytest from pydantic import ValidationError from api.schemas.ai_model_configuration import ( DograhManagedAIModelConfiguration, + EffectiveAIModelConfiguration, + OrganizationAIModelConfigurationResponse, OrganizationAIModelConfigurationV2, compile_ai_model_configuration_v2, ) -from api.schemas.user_configuration import EffectiveAIModelConfiguration from api.services.configuration.ai_model_configuration import ( WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY, check_for_masked_keys_in_ai_model_configuration_v2, @@ -15,6 +19,7 @@ from api.services.configuration.ai_model_configuration import ( merge_ai_model_configuration_v2_secrets, migrate_workflow_configuration_model_override_to_v2, ) +from api.services.configuration.check_validity import UserConfigurationValidator from api.services.configuration.masking import mask_key from api.services.configuration.registry import ( DeepgramSTTConfiguration, @@ -22,6 +27,8 @@ from api.services.configuration.registry import ( DograhSTTService, DograhTTSService, ElevenlabsTTSConfiguration, + GoogleLLMService, + GoogleRealtimeLLMConfiguration, OpenAIEmbeddingsConfiguration, OpenAILLMService, ) @@ -49,6 +56,7 @@ def test_dograh_v2_compiles_to_effective_managed_pipeline_with_embeddings(): assert effective.stt.language == "multi" assert effective.embeddings.provider == "dograh" assert effective.embeddings.model == "default" + assert effective.managed_service_version == 2 def test_dograh_v2_rejects_non_predefined_speed(): @@ -92,6 +100,67 @@ def test_byok_v2_rejects_dograh_provider(): ) +@pytest.mark.asyncio +async def test_byok_realtime_validator_does_not_require_stt_or_tts(): + config = OrganizationAIModelConfigurationV2.model_validate( + { + "mode": "byok", + "byok": { + "mode": "realtime", + "realtime": { + "realtime": { + "provider": "google_realtime", + "api_key": "google-realtime-key", + "model": "gemini-3.1-flash-live-preview", + "voice": "Puck", + "language": "en", + }, + "llm": { + "provider": "google", + "api_key": "google-llm-key", + "model": "gemini-2.0-flash", + }, + }, + }, + } + ) + effective = compile_ai_model_configuration_v2(config) + + assert effective.is_realtime is True + assert effective.stt is None + assert effective.tts is None + assert await UserConfigurationValidator().validate(effective) == { + "status": [{"model": "all", "message": "ok"}] + } + + +@pytest.mark.asyncio +async def test_pipeline_validator_requires_stt_and_tts_when_not_realtime(): + effective = EffectiveAIModelConfiguration( + llm=GoogleLLMService( + provider="google", + api_key="google-llm-key", + model="gemini-2.0-flash", + ), + realtime=GoogleRealtimeLLMConfiguration( + provider="google_realtime", + api_key="google-realtime-key", + model="gemini-3.1-flash-live-preview", + voice="Puck", + language="en", + ), + is_realtime=False, + ) + + with pytest.raises(ValueError) as exc_info: + await UserConfigurationValidator().validate(effective) + + assert exc_info.value.args[0] == [ + {"model": "stt", "message": "API key is missing"}, + {"model": "tts", "message": "API key is missing"}, + ] + + def test_masked_dograh_key_is_preserved_when_saving_same_mode(): existing = OrganizationAIModelConfigurationV2( mode="dograh", @@ -293,3 +362,98 @@ def test_workflow_model_override_migration_removes_invalid_v1_override_marker(): assert changed is True assert "model_overrides" not in migrated assert migrated["ambient_noise_configuration"] == {"enabled": False} + + +@pytest.mark.asyncio +async def test_migrate_model_configuration_v2_initializes_hosted_mps_billing( + monkeypatch, +): + from api.routes import organization as organization_routes + + legacy = EffectiveAIModelConfiguration( + llm=DograhLLMService( + provider="dograh", + api_key=["mps-secret"], + model="default", + ), + tts=DograhTTSService( + provider="dograh", + api_key=["mps-secret"], + model="default", + voice="default", + ), + stt=DograhSTTService( + provider="dograh", + api_key=["mps-secret"], + model="default", + ), + ) + expected_response = OrganizationAIModelConfigurationResponse( + configuration={"version": 2, "mode": "dograh"}, + effective_configuration={}, + source="organization_v2", + ) + + class FakeValidator: + async def validate(self, *args, **kwargs): + return {"status": [{"model": "all", "message": "ok"}]} + + ensure_billing = AsyncMock(return_value={"billing_mode": "v2"}) + upsert = AsyncMock() + migrate_workflows = AsyncMock() + + monkeypatch.setattr(organization_routes, "DEPLOYMENT_MODE", "saas") + monkeypatch.setattr( + organization_routes, + "get_organization_ai_model_configuration_v2", + AsyncMock(return_value=None), + ) + monkeypatch.setattr( + organization_routes.db_client, + "get_user_configurations", + AsyncMock(return_value=legacy), + ) + monkeypatch.setattr( + organization_routes, + "UserConfigurationValidator", + lambda: FakeValidator(), + ) + monkeypatch.setattr( + organization_routes, + "ensure_hosted_mps_billing_account_v2", + ensure_billing, + ) + monkeypatch.setattr( + organization_routes, + "upsert_organization_ai_model_configuration_v2", + upsert, + ) + monkeypatch.setattr( + organization_routes, + "migrate_workflow_model_configurations_to_v2", + migrate_workflows, + ) + monkeypatch.setattr( + organization_routes, + "_model_configuration_v2_response", + AsyncMock(return_value=expected_response), + ) + + user = SimpleNamespace( + id=7, + provider_id="provider-123", + selected_organization_id=42, + ) + + response = await organization_routes.migrate_model_configuration_v2( + force=False, + user=user, + ) + + ensure_billing.assert_awaited_once_with(42, created_by="provider-123") + upsert.assert_awaited_once() + migrate_workflows.assert_awaited_once_with( + organization_id=42, + fallback_user_config=legacy, + ) + assert response == expected_response diff --git a/api/tests/test_auth_depends.py b/api/tests/test_auth_depends.py new file mode 100644 index 00000000..2f33ff58 --- /dev/null +++ b/api/tests/test_auth_depends.py @@ -0,0 +1,68 @@ +from types import SimpleNamespace +from unittest.mock import AsyncMock + +import pytest + +from api.services.auth import depends as auth_depends + + +@pytest.mark.asyncio +async def test_get_user_initializes_hosted_mps_billing_for_new_org(monkeypatch): + stack_user = { + "id": "stack-user-1", + "selected_team_id": "team-1", + "primary_email_verified": False, + } + user = SimpleNamespace( + id=7, + email=None, + provider_id="stack-user-1", + selected_organization_id=None, + ) + organization = SimpleNamespace(id=42) + existing_config = SimpleNamespace(llm=object(), tts=None, stt=None) + + ensure_billing = AsyncMock(return_value={"billing_mode": "v2"}) + + monkeypatch.setattr(auth_depends, "AUTH_PROVIDER", "stack") + monkeypatch.setattr( + auth_depends.stackauth, + "get_user", + AsyncMock(return_value=stack_user), + ) + monkeypatch.setattr( + auth_depends.db_client, + "get_or_create_user_by_provider_id", + AsyncMock(return_value=(user, False)), + ) + monkeypatch.setattr( + auth_depends.db_client, + "get_or_create_organization_by_provider_id", + AsyncMock(return_value=(organization, True)), + ) + monkeypatch.setattr( + auth_depends.db_client, + "add_user_to_organization", + AsyncMock(), + ) + monkeypatch.setattr( + auth_depends.db_client, + "update_user_selected_organization", + AsyncMock(), + ) + monkeypatch.setattr( + auth_depends.db_client, + "get_user_configurations", + AsyncMock(return_value=existing_config), + ) + monkeypatch.setattr( + auth_depends, + "ensure_hosted_mps_billing_account_v2", + ensure_billing, + ) + + result = await auth_depends.get_user(authorization="Bearer token") + + assert result is user + assert result.selected_organization_id == 42 + ensure_billing.assert_awaited_once_with(42, created_by="stack-user-1") diff --git a/api/tests/test_cost_calculator.py b/api/tests/test_cost_calculator.py deleted file mode 100644 index 940ac582..00000000 --- a/api/tests/test_cost_calculator.py +++ /dev/null @@ -1,31 +0,0 @@ -from api.services.pricing.cost_calculator import cost_calculator - - -def test_cost_calculator(): - """Test function to verify cost calculation works""" - sample_usage = { - "llm": { - "OpenAILLMService#0|||gpt-4.1-mini": { - "prompt_tokens": 45380, - "completion_tokens": 496, - "total_tokens": 45876, - "cache_read_input_tokens": 0, - "cache_creation_input_tokens": 0, - } - }, - "tts": {"ElevenLabsTTSService#0|||eleven_flash_v2_5": 2399}, - "stt": {"DeepgramSTTService#0|||nova-3-general": 177.21536946296692}, - "call_duration_seconds": 179, - } - - result = cost_calculator.calculate_total_cost(sample_usage) - assert result["llm_cost"] == 45380 * 0.40 / 1_000_000 + 496 * 1.60 / 1_000_000 - assert result["tts_cost"] == 2399 * 0.0256 / 1_000 - assert result["stt_cost"] == 177.21536946296692 / 60 * 0.0077 - assert ( - abs( - result["total"] - - (result["llm_cost"] + result["tts_cost"] + result["stt_cost"]) - ) - < 1e-10 - ) diff --git a/api/tests/test_dograh_managed_correlation.py b/api/tests/test_dograh_managed_correlation.py new file mode 100644 index 00000000..b0cb52c0 --- /dev/null +++ b/api/tests/test_dograh_managed_correlation.py @@ -0,0 +1,110 @@ +import json + +import pytest +from openai._types import NOT_GIVEN as OPENAI_NOT_GIVEN +from pipecat.frames.frames import TTSStartedFrame +from pipecat.services.dograh.llm import DograhLLMService +from pipecat.services.dograh.stt import DograhSTTService +from pipecat.services.dograh.tts import DograhTTSService +from pipecat.services.openai.base_llm import OpenAILLMSettings +from websockets.protocol import State + + +class _FakeWebSocket: + def __init__(self): + self.state = State.OPEN + self.messages: list[dict] = [] + + async def send(self, message: str) -> None: + self.messages.append(json.loads(message)) + + async def close(self, *args, **kwargs) -> None: + self.state = State.CLOSED + + +def test_dograh_llm_uses_explicit_mps_correlation_id(): + service = DograhLLMService( + api_key="mps-secret", + correlation_id="mps-corr-123", + settings=OpenAILLMSettings(model="default"), + ) + service._start_metadata = {"workflow_run_id": 99} + + params = service.build_chat_completion_params( + { + "messages": [], + "tools": OPENAI_NOT_GIVEN, + "tool_choice": OPENAI_NOT_GIVEN, + } + ) + + assert params["metadata"]["correlation_id"] == "mps-corr-123" + assert params["metadata"]["mps_billing_version"] == "2" + + +@pytest.mark.asyncio +async def test_dograh_stt_config_uses_explicit_mps_correlation_id(monkeypatch): + fake_ws = _FakeWebSocket() + + async def fake_connect(url, additional_headers): + return fake_ws + + monkeypatch.setattr( + "pipecat.services.dograh.stt.websocket_connect", + fake_connect, + ) + + service = DograhSTTService( + api_key="mps-secret", + correlation_id="mps-corr-123", + sample_rate=16000, + ) + service._start_metadata = {"workflow_run_id": 99} + + await service._connect_websocket() + + assert fake_ws.messages[0]["type"] == "config" + assert fake_ws.messages[0]["correlation_id"] == "mps-corr-123" + assert fake_ws.messages[0]["mps_billing_version"] == "2" + + +@pytest.mark.asyncio +async def test_dograh_tts_messages_use_explicit_mps_correlation_id(monkeypatch): + fake_ws = _FakeWebSocket() + + async def fake_connect(url, additional_headers): + return fake_ws + + monkeypatch.setattr( + "pipecat.services.dograh.tts.websocket_connect", + fake_connect, + ) + + service = DograhTTSService( + api_key="mps-secret", + correlation_id="mps-corr-123", + sample_rate=24000, + ) + service._start_metadata = {"workflow_run_id": 99} + + await service._connect_websocket() + assert fake_ws.messages[0]["type"] == "config" + assert fake_ws.messages[0]["correlation_id"] == "mps-corr-123" + assert fake_ws.messages[0]["mps_billing_version"] == "2" + + async def _noop(*args, **kwargs): + return None + + service.audio_context_available = lambda context_id: False + service.create_audio_context = _noop + service.start_ttfb_metrics = _noop + service.start_tts_usage_metrics = _noop + + frames = [] + async for frame in service.run_tts("hello", "ctx-1"): + frames.append(frame) + + assert isinstance(frames[0], TTSStartedFrame) + assert fake_ws.messages[1]["type"] == "create_context" + assert fake_ws.messages[1]["correlation_id"] == "mps-corr-123" + assert fake_ws.messages[1]["mps_billing_version"] == "2" diff --git a/api/tests/test_grok_realtime_wrapper.py b/api/tests/test_grok_realtime_wrapper.py index 7f7359dc..19cae657 100644 --- a/api/tests/test_grok_realtime_wrapper.py +++ b/api/tests/test_grok_realtime_wrapper.py @@ -7,7 +7,7 @@ from pipecat.processors.aggregators.llm_context import LLMContext from pipecat.processors.frame_processor import FrameDirection from pipecat.services.xai.realtime import events -from api.schemas.user_configuration import EffectiveAIModelConfiguration +from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration from api.services.configuration.registry import GrokRealtimeLLMConfiguration from api.services.pipecat.realtime.grok_realtime import ( DograhGrokRealtimeLLMService, @@ -120,7 +120,7 @@ async def test_completed_input_transcription_is_broadcast_as_finalized(): def test_factory_creates_dograh_grok_realtime_service(): - user_config = EffectiveAIModelConfiguration( + effective_config = EffectiveAIModelConfiguration( is_realtime=True, realtime=GrokRealtimeLLMConfiguration( provider="grok_realtime", @@ -131,7 +131,7 @@ def test_factory_creates_dograh_grok_realtime_service(): ) service = create_realtime_llm_service( - user_config, + effective_config, audio_config=SimpleNamespace(), ) diff --git a/api/tests/test_masked_key_rejection.py b/api/tests/test_masked_key_rejection.py index 2012c60b..45782335 100644 --- a/api/tests/test_masked_key_rejection.py +++ b/api/tests/test_masked_key_rejection.py @@ -5,7 +5,7 @@ from fastapi import FastAPI from fastapi.testclient import TestClient from api.routes.user import router -from api.schemas.user_configuration import EffectiveAIModelConfiguration +from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration from api.services.auth.depends import get_user from api.services.configuration.masking import mask_key from api.services.configuration.registry import ( diff --git a/api/tests/test_mps_service_key_client.py b/api/tests/test_mps_service_key_client.py index 9cd629e3..032f07bf 100644 --- a/api/tests/test_mps_service_key_client.py +++ b/api/tests/test_mps_service_key_client.py @@ -87,3 +87,317 @@ async def test_check_service_key_usage_uses_bearer_self_usage(monkeypatch): "Content-Type": "application/json", }, ) + + +@pytest.mark.asyncio +async def test_create_correlation_id_uses_bearer_auth(monkeypatch): + calls = [] + + class FakeAsyncClient: + def __init__(self, timeout): + self.timeout = timeout + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return None + + async def post(self, url, json, headers): + calls.append(("POST", url, json, headers)) + return _Response(200, {"correlation_id": "mps-corr-123"}) + + monkeypatch.setattr( + "api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient + ) + + client = MPSServiceKeyClient() + + assert await client.create_correlation_id( + service_key="mps_sk_paid", + workflow_run_id=42, + ) == {"correlation_id": "mps-corr-123"} + assert calls == [ + ( + "POST", + f"{client.base_url}/api/v1/service-keys/correlation-id/self", + {"workflow_run_id": 42}, + { + "Authorization": "Bearer mps_sk_paid", + "Content-Type": "application/json", + }, + ) + ] + + +@pytest.mark.asyncio +async def test_get_billing_account_status_uses_hosted_org_auth(monkeypatch): + calls = [] + + class FakeAsyncClient: + def __init__(self, timeout): + self.timeout = timeout + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return None + + async def get(self, url, headers): + calls.append(("GET", url, headers)) + return _Response(200, {"organization_id": 42, "billing_mode": "v2"}) + + monkeypatch.setattr( + "api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient + ) + monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas") + monkeypatch.setattr( + "api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret" + ) + + client = MPSServiceKeyClient() + + assert await client.get_billing_account_status(organization_id=42) == { + "organization_id": 42, + "billing_mode": "v2", + } + assert calls == [ + ( + "GET", + f"{client.base_url}/api/v1/billing/accounts/42/status", + { + "Content-Type": "application/json", + "X-Secret-Key": "mps-secret", + "X-Organization-Id": "42", + }, + ) + ] + + +@pytest.mark.asyncio +async def test_ensure_billing_account_v2_uses_balance_endpoint(monkeypatch): + calls = [] + + class FakeAsyncClient: + def __init__(self, timeout): + self.timeout = timeout + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return None + + async def get(self, url, headers): + calls.append(("GET", url, headers)) + return _Response( + 200, + { + "id": 7, + "organization_id": 42, + "billing_mode": "v2", + "cached_balance_credits": "0.0000", + "currency": "USD", + }, + ) + + monkeypatch.setattr( + "api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient + ) + monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas") + monkeypatch.setattr( + "api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret" + ) + + client = MPSServiceKeyClient() + + assert await client.ensure_billing_account_v2( + organization_id=42, + created_by="provider-123", + ) == { + "id": 7, + "organization_id": 42, + "billing_mode": "v2", + "cached_balance_credits": "0.0000", + "currency": "USD", + } + assert calls == [ + ( + "GET", + f"{client.base_url}/api/v1/billing/accounts/42/balance", + { + "Content-Type": "application/json", + "X-Secret-Key": "mps-secret", + "X-Organization-Id": "42", + }, + ) + ] + + +@pytest.mark.asyncio +async def test_get_credit_ledger_sends_page_and_limit(monkeypatch): + calls = [] + + class FakeAsyncClient: + def __init__(self, timeout): + self.timeout = timeout + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return None + + async def get(self, url, params, headers): + calls.append(("GET", url, params, headers)) + return _Response( + 200, + { + "account": {"organization_id": 42}, + "ledger_entries": [], + "total_count": 0, + "page": 3, + "limit": 25, + "total_pages": 0, + }, + ) + + monkeypatch.setattr( + "api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient + ) + monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas") + monkeypatch.setattr( + "api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret" + ) + + client = MPSServiceKeyClient() + + assert await client.get_credit_ledger( + organization_id=42, + page=3, + limit=25, + ) == { + "account": {"organization_id": 42}, + "ledger_entries": [], + "total_count": 0, + "page": 3, + "limit": 25, + "total_pages": 0, + } + assert calls == [ + ( + "GET", + f"{client.base_url}/api/v1/billing/accounts/42/ledger", + {"page": 3, "limit": 25}, + { + "Content-Type": "application/json", + "X-Secret-Key": "mps-secret", + "X-Organization-Id": "42", + }, + ) + ] + + +@pytest.mark.asyncio +async def test_report_platform_usage_uses_hosted_secret_auth(monkeypatch): + calls = [] + + class FakeAsyncClient: + def __init__(self, timeout): + self.timeout = timeout + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return None + + async def post(self, url, json, headers): + calls.append(("POST", url, json, headers)) + return _Response(200, {"metered": True}) + + monkeypatch.setattr( + "api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient + ) + monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas") + monkeypatch.setattr( + "api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret" + ) + + client = MPSServiceKeyClient() + + assert await client.report_platform_usage( + organization_id=42, + correlation_id="mps-corr-123", + workflow_run_id=123, + metadata={"source": "workflow_run_completion"}, + ) == {"metered": True} + assert calls == [ + ( + "POST", + f"{client.base_url}/api/v1/billing/accounts/42/platform-usage", + { + "correlation_id": "mps-corr-123", + "workflow_run_id": 123, + "metadata": {"source": "workflow_run_completion"}, + }, + { + "Content-Type": "application/json", + "X-Secret-Key": "mps-secret", + "X-Organization-Id": "42", + }, + ) + ] + + +@pytest.mark.asyncio +async def test_report_platform_usage_sends_duration_without_correlation(monkeypatch): + calls = [] + + class FakeAsyncClient: + def __init__(self, timeout): + self.timeout = timeout + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return None + + async def post(self, url, json, headers): + calls.append(("POST", url, json, headers)) + return _Response(200, {"metered": True}) + + monkeypatch.setattr( + "api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient + ) + monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas") + monkeypatch.setattr( + "api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret" + ) + + client = MPSServiceKeyClient() + + assert await client.report_platform_usage( + organization_id=42, + duration_seconds=87.0, + workflow_run_id=123, + metadata={"source": "workflow_run_completion"}, + ) == {"metered": True} + assert calls == [ + ( + "POST", + f"{client.base_url}/api/v1/billing/accounts/42/platform-usage", + { + "duration_seconds": 87.0, + "workflow_run_id": 123, + "metadata": {"source": "workflow_run_completion"}, + }, + { + "Content-Type": "application/json", + "X-Secret-Key": "mps-secret", + "X-Organization-Id": "42", + }, + ) + ] diff --git a/api/tests/test_organization_usage_billing.py b/api/tests/test_organization_usage_billing.py new file mode 100644 index 00000000..2f813eac --- /dev/null +++ b/api/tests/test_organization_usage_billing.py @@ -0,0 +1,99 @@ +from types import SimpleNamespace +from unittest.mock import AsyncMock + +import pytest + +from api.routes import organization_usage + + +def test_is_mps_billing_v2_depends_only_on_account_mode(): + assert organization_usage._is_mps_billing_v2({"billing_mode": "v2"}) is True + assert organization_usage._is_mps_billing_v2({"billing_mode": "v1"}) is False + assert organization_usage._is_mps_billing_v2({"billing_mode": "shadow"}) is False + assert organization_usage._is_mps_billing_v2(None) is False + + +@pytest.mark.asyncio +async def test_get_mps_billing_account_status_uses_user_provider_id(monkeypatch): + get_status = AsyncMock(return_value={"billing_mode": "v2"}) + monkeypatch.setattr( + organization_usage.mps_service_key_client, + "get_billing_account_status", + get_status, + ) + + user = SimpleNamespace(provider_id="provider-123") + + assert await organization_usage._get_mps_billing_account_status(user, 42) == { + "billing_mode": "v2" + } + get_status.assert_awaited_once_with( + organization_id=42, + created_by="provider-123", + ) + + +@pytest.mark.asyncio +async def test_get_billing_credits_pages_v2_ledger(monkeypatch): + monkeypatch.setattr(organization_usage, "DEPLOYMENT_MODE", "saas") + monkeypatch.setattr( + organization_usage, + "_get_mps_billing_account_status", + AsyncMock(return_value={"billing_mode": "v2"}), + ) + get_ledger = AsyncMock( + return_value={ + "account": { + "id": 7, + "organization_id": 42, + "billing_mode": "v2", + "cached_balance_credits": 250, + "currency": "USD", + }, + "ledger_entries": [ + { + "id": 99, + "entry_type": "grant", + "origin": "account_creation", + "credits_delta": 250, + "balance_after": 250, + "created_at": "2026-06-12T00:00:00Z", + } + ], + "total_debits_credits": 75, + "total_count": 101, + "page": 3, + "limit": 25, + "total_pages": 5, + } + ) + monkeypatch.setattr( + organization_usage.mps_service_key_client, + "get_credit_ledger", + get_ledger, + ) + + user = SimpleNamespace( + provider_id="provider-123", + selected_organization_id=42, + ) + + response = await organization_usage.get_billing_credits( + page=3, + limit=25, + user=user, + ) + + get_ledger.assert_awaited_once_with( + organization_id=42, + page=3, + limit=25, + created_by="provider-123", + ) + assert response.billing_version == "v2" + assert response.total_credits_used == 75 + assert response.total_count == 101 + assert response.page == 3 + assert response.limit == 25 + assert response.total_pages == 5 + assert response.ledger_entries[0].id == 99 diff --git a/api/tests/test_resolve_effective_config.py b/api/tests/test_resolve_effective_config.py index 85afe30f..1b9ad8c6 100644 --- a/api/tests/test_resolve_effective_config.py +++ b/api/tests/test_resolve_effective_config.py @@ -9,7 +9,7 @@ Module under test: api.services.configuration.resolve import pytest -from api.schemas.user_configuration import EffectiveAIModelConfiguration +from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration from api.services.configuration.masking import ( contains_masked_key, mask_workflow_configurations, diff --git a/api/tests/test_run_usage_response.py b/api/tests/test_run_usage_response.py index c17d4a9f..044c6563 100644 --- a/api/tests/test_run_usage_response.py +++ b/api/tests/test_run_usage_response.py @@ -1,4 +1,4 @@ -from api.services.pricing.run_usage_response import format_public_usage_info +from api.services.workflow.run_usage_response import format_public_usage_info def test_format_public_usage_info(): diff --git a/api/tests/test_ultravox_realtime_wrapper.py b/api/tests/test_ultravox_realtime_wrapper.py index 65b062b6..32888439 100644 --- a/api/tests/test_ultravox_realtime_wrapper.py +++ b/api/tests/test_ultravox_realtime_wrapper.py @@ -10,7 +10,7 @@ from pipecat.processors.frame_processor import FrameDirection from websockets.exceptions import ConnectionClosedError from websockets.frames import Close -from api.schemas.user_configuration import EffectiveAIModelConfiguration +from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration from api.services.configuration.registry import UltravoxRealtimeLLMConfiguration from api.services.pipecat.realtime.ultravox_realtime import ( _RESUMPTION_USER_MESSAGE, @@ -430,7 +430,7 @@ async def test_receive_messages_reports_unexpected_websocket_close(): def test_factory_creates_dograh_ultravox_realtime_service(): - user_config = EffectiveAIModelConfiguration( + effective_config = EffectiveAIModelConfiguration( is_realtime=True, realtime=UltravoxRealtimeLLMConfiguration( provider="ultravox_realtime", @@ -441,7 +441,7 @@ def test_factory_creates_dograh_ultravox_realtime_service(): ) service = create_realtime_llm_service( - user_config, + effective_config, audio_config=SimpleNamespace(), ) diff --git a/api/tests/test_workflow_run_billing.py b/api/tests/test_workflow_run_billing.py new file mode 100644 index 00000000..2837317f --- /dev/null +++ b/api/tests/test_workflow_run_billing.py @@ -0,0 +1,212 @@ +from types import SimpleNamespace +from unittest.mock import AsyncMock + +import pytest + +from api.services import workflow_run_billing as workflow_run_billing_mod +from api.services.workflow_run_billing import ( + report_completed_workflow_run_platform_usage, + report_workflow_run_platform_usage, +) + + +def _make_workflow_run(): + return SimpleNamespace( + id=123, + workflow_id=456, + is_completed=True, + initial_context={"mps_correlation_id": "mps-corr-123"}, + usage_info={"call_duration_seconds": 87}, + workflow=SimpleNamespace( + organization_id=42, + user=SimpleNamespace(selected_organization_id=42), + ), + ) + + +@pytest.mark.asyncio +async def test_report_workflow_run_platform_usage_reports_hosted_completion( + monkeypatch, +): + workflow_run = _make_workflow_run() + get_status = AsyncMock(return_value={"billing_mode": "v2"}) + report_usage = AsyncMock(return_value={"metered": True}) + + monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas") + monkeypatch.setattr( + workflow_run_billing_mod.mps_service_key_client, + "get_billing_account_status", + get_status, + ) + monkeypatch.setattr( + workflow_run_billing_mod.mps_service_key_client, + "report_platform_usage", + report_usage, + ) + + await report_workflow_run_platform_usage(workflow_run) + + report_usage.assert_awaited_once_with( + organization_id=42, + correlation_id="mps-corr-123", + duration_seconds=None, + workflow_run_id=workflow_run.id, + metadata={ + "source": "workflow_run_completion", + "workflow_id": workflow_run.workflow_id, + "duration_source": "mps_correlation", + }, + ) + + +@pytest.mark.asyncio +async def test_report_workflow_run_platform_usage_reports_duration_without_correlation( + monkeypatch, +): + workflow_run = _make_workflow_run() + workflow_run.initial_context = {} + get_status = AsyncMock(return_value={"billing_mode": "v2"}) + report_usage = AsyncMock(return_value={"metered": True}) + + monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas") + monkeypatch.setattr( + workflow_run_billing_mod.mps_service_key_client, + "get_billing_account_status", + get_status, + ) + monkeypatch.setattr( + workflow_run_billing_mod.mps_service_key_client, + "report_platform_usage", + report_usage, + ) + + await report_workflow_run_platform_usage(workflow_run) + + report_usage.assert_awaited_once_with( + organization_id=42, + correlation_id=None, + duration_seconds=87.0, + workflow_run_id=workflow_run.id, + metadata={ + "source": "workflow_run_completion", + "workflow_id": workflow_run.workflow_id, + "duration_source": "dograh_usage_info", + }, + ) + + +@pytest.mark.asyncio +async def test_report_workflow_run_platform_usage_skips_non_v2_account(monkeypatch): + workflow_run = _make_workflow_run() + get_status = AsyncMock(return_value={"billing_mode": "v1"}) + report_usage = AsyncMock() + + monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas") + monkeypatch.setattr( + workflow_run_billing_mod.mps_service_key_client, + "get_billing_account_status", + get_status, + ) + monkeypatch.setattr( + workflow_run_billing_mod.mps_service_key_client, + "report_platform_usage", + report_usage, + ) + + await report_workflow_run_platform_usage(workflow_run) + + get_status.assert_awaited_once_with(organization_id=42) + report_usage.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_report_workflow_run_platform_usage_skips_missing_duration_without_correlation( + monkeypatch, +): + workflow_run = _make_workflow_run() + workflow_run.initial_context = {} + workflow_run.usage_info = {} + get_status = AsyncMock(return_value={"billing_mode": "v2"}) + report_usage = AsyncMock() + + monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas") + monkeypatch.setattr( + workflow_run_billing_mod.mps_service_key_client, + "get_billing_account_status", + get_status, + ) + monkeypatch.setattr( + workflow_run_billing_mod.mps_service_key_client, + "report_platform_usage", + report_usage, + ) + + await report_workflow_run_platform_usage(workflow_run) + + get_status.assert_not_awaited() + report_usage.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_report_workflow_run_platform_usage_skips_oss(monkeypatch): + workflow_run = _make_workflow_run() + report_usage = AsyncMock() + + monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "oss") + monkeypatch.setattr( + workflow_run_billing_mod.mps_service_key_client, + "report_platform_usage", + report_usage, + ) + + await report_workflow_run_platform_usage(workflow_run) + + report_usage.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_report_workflow_run_platform_usage_skips_incomplete(monkeypatch): + workflow_run = _make_workflow_run() + workflow_run.is_completed = False + report_usage = AsyncMock() + + monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas") + monkeypatch.setattr( + workflow_run_billing_mod.mps_service_key_client, + "report_platform_usage", + report_usage, + ) + + await report_workflow_run_platform_usage(workflow_run) + + report_usage.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_report_completed_workflow_run_platform_usage_loads_run(monkeypatch): + workflow_run = _make_workflow_run() + get_run = AsyncMock(return_value=workflow_run) + get_status = AsyncMock(return_value={"billing_mode": "v2"}) + report_usage = AsyncMock(return_value={"metered": True}) + + monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas") + monkeypatch.setattr( + workflow_run_billing_mod.db_client, + "get_workflow_run_by_id", + get_run, + ) + monkeypatch.setattr( + workflow_run_billing_mod.mps_service_key_client, + "get_billing_account_status", + get_status, + ) + monkeypatch.setattr( + workflow_run_billing_mod.mps_service_key_client, + "report_platform_usage", + report_usage, + ) + + await report_completed_workflow_run_platform_usage(workflow_run.id) + + get_run.assert_awaited_once_with(workflow_run.id) + report_usage.assert_awaited_once() diff --git a/api/tests/test_workflow_run_cost.py b/api/tests/test_workflow_run_cost.py deleted file mode 100644 index c77424c8..00000000 --- a/api/tests/test_workflow_run_cost.py +++ /dev/null @@ -1,181 +0,0 @@ -from datetime import UTC, datetime -from types import SimpleNamespace -from unittest.mock import AsyncMock - -import pytest - -from api.services.pricing import workflow_run_cost as workflow_run_cost_mod -from api.services.pricing.workflow_run_cost import ( - apply_usage_delta_to_organization, - build_workflow_run_cost_info, - calculate_workflow_run_cost, -) - - -def _make_workflow_run(): - return SimpleNamespace( - id=123, - workflow_id=456, - mode="textchat", - created_at=datetime.now(UTC), - usage_info={ - "llm": {}, - "tts": {}, - "stt": {}, - "call_duration_seconds": 7, - }, - cost_info={}, - workflow=SimpleNamespace( - organization_id=42, - user=SimpleNamespace(selected_organization_id=42), - ), - ) - - -@pytest.mark.asyncio -async def test_build_workflow_run_cost_info_does_not_update_org_usage(monkeypatch): - workflow_run = _make_workflow_run() - get_org = AsyncMock(return_value=SimpleNamespace(id=42, price_per_second_usd=1.5)) - update_usage = AsyncMock() - - monkeypatch.setattr( - workflow_run_cost_mod.db_client, "get_organization_by_id", get_org - ) - monkeypatch.setattr( - workflow_run_cost_mod.db_client, "update_usage_after_run", update_usage - ) - - cost_info = await build_workflow_run_cost_info(workflow_run) - - assert cost_info is not None - assert cost_info["call_duration_seconds"] == 7 - assert "cost_breakdown" in cost_info - assert "dograh_token_usage" in cost_info - assert cost_info["charge_usd"] == 10.5 - update_usage.assert_not_called() - - -@pytest.mark.asyncio -async def test_calculate_workflow_run_cost_keeps_org_usage_side_effect_in_wrapper( - monkeypatch, -): - workflow_run = _make_workflow_run() - get_org = AsyncMock(return_value=SimpleNamespace(id=42, price_per_second_usd=None)) - update_run = AsyncMock() - update_usage = AsyncMock() - - monkeypatch.setattr( - workflow_run_cost_mod.db_client, - "get_workflow_run_by_id", - AsyncMock(return_value=workflow_run), - ) - monkeypatch.setattr( - workflow_run_cost_mod.db_client, "get_organization_by_id", get_org - ) - monkeypatch.setattr( - workflow_run_cost_mod.db_client, "update_workflow_run", update_run - ) - monkeypatch.setattr( - workflow_run_cost_mod.db_client, "update_usage_after_run", update_usage - ) - - await calculate_workflow_run_cost(workflow_run.id) - - update_run.assert_awaited_once() - saved_kwargs = update_run.await_args.kwargs - assert saved_kwargs["run_id"] == workflow_run.id - assert "cost_breakdown" in saved_kwargs["cost_info"] - update_usage.assert_awaited_once() - - -@pytest.mark.asyncio -async def test_apply_usage_delta_to_organization_uses_incremental_costs( - monkeypatch, -): - workflow_run = _make_workflow_run() - workflow_run.cost_info = {"call_id": "preserve-me"} - - usage_delta_one = { - "llm": { - "OpenAILLMService#0|||gpt-4.1-mini": { - "prompt_tokens": 1_000, - "completion_tokens": 100, - "total_tokens": 1_100, - "cache_read_input_tokens": 0, - "cache_creation_input_tokens": 0, - } - }, - "tts": {}, - "stt": {}, - "call_duration_seconds": 3, - } - usage_delta_two = { - "llm": { - "OpenAILLMService#0|||gpt-4.1-mini": { - "prompt_tokens": 2_000, - "completion_tokens": 50, - "total_tokens": 2_050, - "cache_read_input_tokens": 0, - "cache_creation_input_tokens": 0, - } - }, - "tts": {}, - "stt": {}, - "call_duration_seconds": 4, - } - merged_usage = { - "llm": { - "OpenAILLMService#0|||gpt-4.1-mini": { - "prompt_tokens": 3_000, - "completion_tokens": 150, - "total_tokens": 3_150, - "cache_read_input_tokens": 0, - "cache_creation_input_tokens": 0, - } - }, - "tts": {}, - "stt": {}, - "call_duration_seconds": 7, - } - - get_org = AsyncMock(return_value=SimpleNamespace(id=42, price_per_second_usd=1.5)) - update_usage = AsyncMock() - - monkeypatch.setattr( - workflow_run_cost_mod.db_client, "get_organization_by_id", get_org - ) - monkeypatch.setattr( - workflow_run_cost_mod.db_client, "update_usage_after_run", update_usage - ) - - first_delta = await apply_usage_delta_to_organization(workflow_run, usage_delta_one) - second_delta = await apply_usage_delta_to_organization( - workflow_run, usage_delta_two - ) - total_workflow_run = SimpleNamespace(**workflow_run.__dict__) - total_workflow_run.usage_info = merged_usage - total_cost = await build_workflow_run_cost_info(total_workflow_run) - - assert first_delta is not None - assert second_delta is not None - assert total_cost is not None - assert update_usage.await_count == 2 - assert update_usage.await_args_list[0].args == ( - 42, - first_delta["dograh_token_usage"], - 3.0, - first_delta["charge_usd"], - ) - assert update_usage.await_args_list[1].args == ( - 42, - second_delta["dograh_token_usage"], - 4.0, - second_delta["charge_usd"], - ) - assert ( - first_delta["dograh_token_usage"] + second_delta["dograh_token_usage"] - ) == pytest.approx(total_cost["dograh_token_usage"]) - assert ( - first_delta["charge_usd"] + second_delta["charge_usd"] - == total_cost["charge_usd"] - ) diff --git a/api/tests/test_workflow_text_chat.py b/api/tests/test_workflow_text_chat.py index e69e7c0a..3be8a613 100644 --- a/api/tests/test_workflow_text_chat.py +++ b/api/tests/test_workflow_text_chat.py @@ -4,7 +4,7 @@ from unittest.mock import AsyncMock, patch import pytest from api.db.models import OrganizationModel, UserModel -from api.schemas.user_configuration import EffectiveAIModelConfiguration +from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration from api.tests.integrations._run_pipeline_helpers import USER_CONFIGURATION from pipecat.tests import MockLLMService @@ -176,11 +176,7 @@ async def test_text_chat_session_creation_executes_initial_assistant_turn( assert "Start" in (created["gathered_context"] or {}).get("nodes_visited", []) workflow_run = await db_session.get_workflow_run_by_id(created["workflow_run_id"]) assert workflow_run is not None - assert workflow_run.cost_info[ - "call_duration_seconds" - ] == workflow_run.usage_info.get("call_duration_seconds", 0) - assert "cost_breakdown" in workflow_run.cost_info - assert "dograh_token_usage" in workflow_run.cost_info + assert "call_duration_seconds" in workflow_run.usage_info assert _log_texts(run_payload["logs"], "rtf-bot-text") == [ "Hello from the workflow tester." ] @@ -296,11 +292,7 @@ async def test_text_chat_message_executes_assistant_turn( assert "Start" in (payload["gathered_context"] or {}).get("nodes_visited", []) workflow_run = await db_session.get_workflow_run_by_id(created["workflow_run_id"]) assert workflow_run is not None - assert workflow_run.cost_info[ - "call_duration_seconds" - ] == workflow_run.usage_info.get("call_duration_seconds", 0) - assert "cost_breakdown" in workflow_run.cost_info - assert "dograh_token_usage" in workflow_run.cost_info + assert "call_duration_seconds" in workflow_run.usage_info assert _log_texts(run_payload["logs"], "rtf-user-transcription") == ["Hi there"] assert _log_texts(run_payload["logs"], "rtf-bot-text") == [ "Welcome to the workflow tester.", diff --git a/pipecat b/pipecat index 228324a1..0d64dc6e 160000 --- a/pipecat +++ b/pipecat @@ -1 +1 @@ -Subproject commit 228324a146a6765c6b8d610963bc80d7bc8cb9f7 +Subproject commit 0d64dc6e0e3e6b3c46cc66373e34b4f54f980268 diff --git a/ui/src/app/billing/page.tsx b/ui/src/app/billing/page.tsx new file mode 100644 index 00000000..0a9732c9 --- /dev/null +++ b/ui/src/app/billing/page.tsx @@ -0,0 +1,416 @@ +"use client"; + +import { + ChevronLeft, + ChevronRight, + CircleDollarSign, + CreditCard, + RefreshCw, +} from "lucide-react"; +import Link from "next/link"; +import { useRouter, useSearchParams } from "next/navigation"; +import { useCallback, useEffect, useMemo, useState } from "react"; +import { toast } from "sonner"; + +import { createMpsCreditPurchaseUrlApiV1OrganizationsUsageMpsCreditsPurchaseUrlPost, getBillingCreditsApiV1OrganizationsBillingCreditsGet } from "@/client/sdk.gen"; +import type { MpsBillingCreditsResponse, MpsCreditLedgerEntryResponse } from "@/client/types.gen"; +import { Badge } from "@/components/ui/badge"; +import { Button } from "@/components/ui/button"; +import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card"; +import { Progress } from "@/components/ui/progress"; +import { Skeleton } from "@/components/ui/skeleton"; +import { + Table, + TableBody, + TableCell, + TableHead, + TableHeader, + TableRow, +} from "@/components/ui/table"; +import { useAppConfig } from "@/context/AppConfigContext"; +import { useAuth } from "@/lib/auth"; + +const LEDGER_PAGE_SIZE = 50; + +const formatCredits = (value: number | null | undefined) => ( + (value ?? 0).toLocaleString(undefined, { + maximumFractionDigits: 2, + minimumFractionDigits: 0, + }) +); + +const formatAmount = (amountMinor?: number | null, currency?: string | null) => { + if (amountMinor == null) { + return "-"; + } + + return new Intl.NumberFormat(undefined, { + style: "currency", + currency: currency || "USD", + }).format(amountMinor / 100); +}; + +const formatDate = (value: string) => ( + new Date(value).toLocaleString(undefined, { + month: "short", + day: "numeric", + year: "numeric", + hour: "2-digit", + minute: "2-digit", + }) +); + +const metricLabels: Record = { + voice_minutes: "Voice usage", + platform_usage: "Platform usage", +}; + +const formatTitleCase = (value: string | null | undefined) => ( + value ? value.replaceAll("_", " ").replace(/\b\w/g, (letter) => letter.toUpperCase()) : "-" +); + +const getLedgerEntryLabel = (entry: MpsCreditLedgerEntryResponse) => { + if (entry.metric_code) { + return metricLabels[entry.metric_code] ?? formatTitleCase(entry.metric_code); + } + + if (entry.entry_type === "grant") { + return "Credit grant"; + } + + if (entry.entry_type === "purchase") { + return "Credit purchase"; + } + + return formatTitleCase(entry.entry_type); +}; + +const formatBillableQuantity = (entry: MpsCreditLedgerEntryResponse) => { + if (entry.billable_quantity == null || !entry.quantity_unit) { + return null; + } + + const unit = entry.quantity_unit === "minute" ? "min" : entry.quantity_unit; + return `${formatCredits(entry.billable_quantity)} ${unit}`; +}; + +const getRunHref = (entry: MpsCreditLedgerEntryResponse) => { + if (!entry.workflow_id || !entry.workflow_run_id) { + return null; + } + + return `/workflow/${entry.workflow_id}/run/${entry.workflow_run_id}`; +}; + +const getPageFromSearchParams = ( + searchParams: { get: (name: string) => string | null }, +) => { + const pageParam = searchParams.get("page"); + const page = pageParam ? Number.parseInt(pageParam, 10) : 1; + return Number.isFinite(page) && page > 0 ? page : 1; +}; + +export default function BillingPage() { + const router = useRouter(); + const searchParams = useSearchParams(); + const auth = useAuth(); + const { config } = useAppConfig(); + const [credits, setCredits] = useState(null); + const [loading, setLoading] = useState(true); + const [refreshing, setRefreshing] = useState(false); + const [purchasing, setPurchasing] = useState(false); + const [currentPage, setCurrentPage] = useState( + () => getPageFromSearchParams(searchParams), + ); + + const isBillingV2 = credits?.billing_version === "v2"; + const canPurchaseCredits = isBillingV2 && config?.deploymentMode !== "oss"; + const totalQuota = credits?.total_quota ?? 0; + const remainingCredits = credits?.remaining_credits ?? 0; + const usedCredits = credits?.total_credits_used ?? 0; + const usagePercent = totalQuota > 0 ? Math.min(100, Math.round((usedCredits / totalQuota) * 100)) : 0; + + const ledgerEntries = useMemo(() => credits?.ledger_entries ?? [], [credits?.ledger_entries]); + const ledgerPage = credits?.page ?? currentPage; + const ledgerTotalCount = credits?.total_count ?? ledgerEntries.length; + const ledgerTotalPages = credits?.total_pages ?? 0; + + const fetchCredits = useCallback(async ( + page: number, + { silent = false }: { silent?: boolean } = {}, + ) => { + if (auth.loading) { + return; + } + + if (!auth.isAuthenticated) { + setLoading(false); + return; + } + + if (silent) { + setRefreshing(true); + } else { + setLoading(true); + } + + try { + const response = await getBillingCreditsApiV1OrganizationsBillingCreditsGet({ + query: { page, limit: LEDGER_PAGE_SIZE }, + }); + + if (response.error) { + throw new Error("Failed to fetch billing credits"); + } + + setCredits(response.data ?? null); + } catch (error) { + console.error("Failed to fetch billing credits:", error); + toast.error("Failed to fetch billing credits"); + } finally { + setLoading(false); + setRefreshing(false); + } + }, [auth.isAuthenticated, auth.loading]); + + useEffect(() => { + const nextPage = getPageFromSearchParams(searchParams); + setCurrentPage((previousPage) => ( + previousPage === nextPage ? previousPage : nextPage + )); + }, [searchParams]); + + useEffect(() => { + fetchCredits(currentPage); + }, [currentPage, fetchCredits]); + + const handleRefresh = () => { + fetchCredits(currentPage, { silent: true }); + }; + + const updateUrlPage = useCallback((page: number) => { + const newParams = new URLSearchParams(searchParams.toString()); + if (page > 1) { + newParams.set("page", page.toString()); + } else { + newParams.delete("page"); + } + + const queryString = newParams.toString(); + router.push(queryString ? `/billing?${queryString}` : "/billing"); + }, [router, searchParams]); + + const handlePageChange = (page: number) => { + const nextPage = Math.max(1, page); + setCurrentPage(nextPage); + updateUrlPage(nextPage); + }; + + const handlePurchaseCredits = async () => { + if (!canPurchaseCredits) { + return; + } + + setPurchasing(true); + try { + const response = await createMpsCreditPurchaseUrlApiV1OrganizationsUsageMpsCreditsPurchaseUrlPost(); + const checkoutUrl = response.data?.checkout_url; + if (!checkoutUrl) { + throw new Error("Missing checkout URL"); + } + window.location.href = checkoutUrl; + } catch (error) { + console.error("Failed to create credit purchase URL:", error); + toast.error("Failed to open checkout"); + setPurchasing(false); + } + }; + + if (loading) { + return ( +
+
+ + +
+
+ + +
+ +
+ ); + } + + return ( +
+
+
+

Billing

+

+ Credits, balance, and account usage for your organization. +

+
+
+ + {canPurchaseCredits && ( + + )} +
+
+ +
+ + + {isBillingV2 ? "Credit balance" : "Credits remaining"} + + + {formatCredits(remainingCredits)} + + + +

1 credit = 1 cent

+
+
+ + + + Credits used + {formatCredits(usedCredits)} + + +

+ {isBillingV2 ? "Total ledger debits" : "Current allocation usage"} +

+
+
+
+ + {isBillingV2 ? ( + + + Credit Ledger + Recent grants, purchases, and usage debits. + + + {ledgerEntries.length > 0 ? ( +
+ + + + Date + Activity + Origin + Run + Delta + Balance + Amount + + + + {ledgerEntries.map((entry) => { + const delta = entry.credits_delta ?? 0; + const runHref = getRunHref(entry); + const billableQuantity = formatBillableQuantity(entry); + return ( + + {formatDate(entry.created_at)} + +
+ {getLedgerEntryLabel(entry)} + {billableQuantity && ( + {billableQuantity} + )} +
+
+ + {entry.origin ? ( + {formatTitleCase(entry.origin)} + ) : ( + "-" + )} + + + {entry.workflow_run_id ? ( + runHref ? ( + + #{entry.workflow_run_id} + + ) : ( + #{entry.workflow_run_id} + ) + ) : ( + "-" + )} + + = 0 ? "text-green-600" : "text-destructive"}`}> + {delta >= 0 ? "+" : ""} + {formatCredits(delta)} + + {formatCredits(entry.balance_after)} + + {formatAmount(entry.amount_minor, entry.amount_currency)} + +
+ ); + })} +
+
+
+ ) : ( +
+ No ledger entries yet +
+ )} + {ledgerTotalPages > 1 && ( +
+

+ Page {ledgerPage} of {ledgerTotalPages} ({ledgerTotalCount} total entries) +

+
+ + +
+
+ )} +
+
+ ) : ( + + + Credit Usage + + + +
+ {usagePercent}% used + {formatCredits(remainingCredits)} of {formatCredits(totalQuota)} remaining +
+
+
+ )} +
+ ); +} diff --git a/ui/src/app/layout.tsx b/ui/src/app/layout.tsx index b7961425..9b346354 100644 --- a/ui/src/app/layout.tsx +++ b/ui/src/app/layout.tsx @@ -12,8 +12,8 @@ import SpinLoader from "@/components/SpinLoader"; import { Toaster } from "@/components/ui/sonner"; import { AppConfigProvider } from "@/context/AppConfigContext"; import { OnboardingProvider } from "@/context/OnboardingContext"; +import { OrgConfigProvider } from "@/context/OrgConfigContext"; import { TelephonyConfigWarningsProvider } from "@/context/TelephonyConfigWarningsContext"; -import { UserConfigProvider } from "@/context/UserConfigContext"; import { AuthProvider } from "@/lib/auth"; @@ -65,7 +65,7 @@ export default function RootLayout({ }> - + @@ -76,7 +76,7 @@ export default function RootLayout({ - + diff --git a/ui/src/app/reports/page.tsx b/ui/src/app/reports/page.tsx index 770e1724..3f703bdc 100644 --- a/ui/src/app/reports/page.tsx +++ b/ui/src/app/reports/page.tsx @@ -2,7 +2,7 @@ import { addDays, format, subDays } from 'date-fns'; import { Calendar, ChevronLeft, ChevronRight, Download } from 'lucide-react'; -import { useEffect,useState } from 'react'; +import { useEffect, useState } from 'react'; import { getDailyReportApiV1OrganizationsReportsDailyGet, @@ -201,7 +201,9 @@ export default function ReportsPage() {
{/* Header */}
-

Daily Reports

+
+

Daily Reports

+
{/* Date Navigation & Workflow Selector */}
diff --git a/ui/src/app/usage/page.tsx b/ui/src/app/usage/page.tsx index 181d1791..a5d69839 100644 --- a/ui/src/app/usage/page.tsx +++ b/ui/src/app/usage/page.tsx @@ -6,8 +6,8 @@ import { useCallback, useEffect, useId, useState } from 'react'; import TimezoneSelect, { type ITimezoneOption } from 'react-timezone-select'; import { toast } from 'sonner'; -import { downloadUsageRunsReportApiV1OrganizationsUsageRunsReportGet, getDailyUsageBreakdownApiV1OrganizationsUsageDailyBreakdownGet, getMpsCreditsApiV1OrganizationsUsageMpsCreditsGet, getPreferencesApiV1OrganizationsPreferencesGet, getUsageHistoryApiV1OrganizationsUsageRunsGet, savePreferencesApiV1OrganizationsPreferencesPut } from '@/client/sdk.gen'; -import type { DailyUsageBreakdownResponse, MpsCreditsResponse, OrganizationPreferences, UsageHistoryResponse, WorkflowRunUsageResponse } from '@/client/types.gen'; +import { downloadUsageRunsReportApiV1OrganizationsUsageRunsReportGet, getDailyUsageBreakdownApiV1OrganizationsUsageDailyBreakdownGet, getPreferencesApiV1OrganizationsPreferencesGet, getUsageHistoryApiV1OrganizationsUsageRunsGet, savePreferencesApiV1OrganizationsPreferencesPut } from '@/client/sdk.gen'; +import type { DailyUsageBreakdownResponse, OrganizationPreferences, UsageHistoryResponse, WorkflowRunUsageResponse } from '@/client/types.gen'; import { CallTypeCell } from '@/components/CallTypeCell'; import { DailyUsageTable } from '@/components/DailyUsageTable'; import { FilterBuilder } from '@/components/filters/FilterBuilder'; @@ -15,7 +15,6 @@ import { MediaPreviewButton, MediaPreviewDialog } from '@/components/MediaPrevie import { Badge } from '@/components/ui/badge'; import { Button } from '@/components/ui/button'; import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card'; -import { Progress } from '@/components/ui/progress'; import { Table, TableBody, @@ -39,10 +38,6 @@ export default function UsagePage() { const { organizationPricing } = useUserConfig(); const auth = useAuth(); - // MPS credits state - const [mpsCredits, setMpsCredits] = useState(null); - const [isLoadingCredits, setIsLoadingCredits] = useState(true); - // Usage history state const [usageHistory, setUsageHistory] = useState(null); const [isLoadingHistory, setIsLoadingHistory] = useState(false); @@ -78,21 +73,6 @@ export default function UsagePage() { const [preferencesLoading, setPreferencesLoading] = useState(true); const timezoneSelectId = useId(); // Stable ID for react-select to prevent hydration mismatch - // Fetch MPS credits - const fetchMpsCredits = useCallback(async () => { - if (!auth.isAuthenticated) return; - try { - const response = await getMpsCreditsApiV1OrganizationsUsageMpsCreditsGet(); - if (response.data) { - setMpsCredits(response.data); - } - } catch (error) { - console.error('Failed to fetch MPS credits:', error); - } finally { - setIsLoadingCredits(false); - } - }, [auth.isAuthenticated]); - // Translate the FilterBuilder state into the query-param shape the // backend expects. Shared between the listing fetch and the CSV export // so they stay in lockstep. @@ -251,10 +231,9 @@ export default function UsagePage() { // Initial load - fetch when auth becomes available useEffect(() => { if (auth.isAuthenticated) { - fetchMpsCredits(); fetchUsageHistory(currentPage, appliedFilters); } - }, [auth.isAuthenticated, currentPage, appliedFilters, fetchUsageHistory, fetchMpsCredits]); + }, [auth.isAuthenticated, currentPage, appliedFilters, fetchUsageHistory]); // Fetch daily usage when organizationPricing becomes available useEffect(() => { @@ -428,46 +407,6 @@ export default function UsagePage() {
- {/* MPS Credits Card */} - - - Dograh Model Credits - - These track usage of Dograh models using Dograh Service Keys. - - - - {isLoadingCredits ? ( -
-
-
-
-
- ) : mpsCredits ? ( -
-
-
-

- {mpsCredits.total_credits_used.toFixed(2)} / {mpsCredits.total_quota.toFixed(2)} -

-

Credits Used

-
-
-

{mpsCredits.remaining_credits.toFixed(2)}

-

Remaining

-
-
- - {mpsCredits.total_quota > 0 && ( - - )} -
- ) : ( -

No Dograh service keys configured. Set up a service key in your model configuration to see usage.

- )} -
-
- {/* Daily Usage Table - Only for paid organizations */} {organizationPricing?.price_per_second_usd && (
@@ -535,9 +474,9 @@ export default function UsagePage() { Disposition Date Duration - - {organizationPricing?.price_per_second_usd ? 'Cost (USD)' : 'Tokens'} - + {organizationPricing?.price_per_second_usd && ( + Cost (USD) + )} Actions @@ -574,12 +513,14 @@ export default function UsagePage() { {formatDuration(run.call_duration_seconds)} - - {organizationPricing?.price_per_second_usd && run.charge_usd !== undefined && run.charge_usd !== null - ? `$${run.charge_usd.toFixed(2)}` - : run.dograh_token_usage.toLocaleString() - } - + {organizationPricing?.price_per_second_usd && ( + + {run.charge_usd !== undefined && run.charge_usd !== null + ? `$${run.charge_usd.toFixed(2)}` + : '-' + } + + )} = Options2 & { /** @@ -915,6 +915,13 @@ export const refreshMcpToolsApiV1ToolsToolUuidMcpRefreshPost = (options: Options) => (options.client ?? client).post({ url: '/api/v1/tools/{tool_uuid}/unarchive', ...options }); +/** + * Get Current Organization Context + * + * Return organization-scoped configuration signals owned by Dograh. + */ +export const getCurrentOrganizationContextApiV1OrganizationsContextGet = (options?: Options) => (options?.client ?? client).get({ url: '/api/v1/organizations/context', ...options }); + /** * Get Telephony Providers Metadata * @@ -1232,7 +1239,7 @@ export const reactivateServiceKeyApiV1UserServiceKeysServiceKeyIdReactivatePut = /** * Get Current Period Usage * - * Get current billing period usage for the user's organization. + * Get current reporting-period usage for the user's organization. */ export const getCurrentPeriodUsageApiV1OrganizationsUsageCurrentPeriodGet = (options?: Options) => (options?.client ?? client).get({ url: '/api/v1/organizations/usage/current-period', ...options }); @@ -1246,6 +1253,20 @@ export const getCurrentPeriodUsageApiV1OrganizationsUsageCurrentPeriodGet = (options?: Options) => (options?.client ?? client).get({ url: '/api/v1/organizations/usage/mps-credits', ...options }); +/** + * Get Billing Credits + * + * Return legacy MPS credits or paginated v2 billing ledger details for the org. + */ +export const getBillingCreditsApiV1OrganizationsBillingCreditsGet = (options?: Options) => (options?.client ?? client).get({ url: '/api/v1/organizations/billing/credits', ...options }); + +/** + * Create Mps Credit Purchase Url + * + * Create a checkout URL for organizations using Dograh-managed MPS v2. + */ +export const createMpsCreditPurchaseUrlApiV1OrganizationsUsageMpsCreditsPurchaseUrlPost = (options?: Options) => (options?.client ?? client).post({ url: '/api/v1/organizations/usage/mps-credits/purchase-url', ...options }); + /** * Get Usage History * diff --git a/ui/src/client/types.gen.ts b/ui/src/client/types.gen.ts index a44e7565..4c6afb36 100644 --- a/ui/src/client/types.gen.ts +++ b/ui/src/client/types.gen.ts @@ -1642,22 +1642,6 @@ export type CurrentUsageResponse = { * Used Dograh Tokens */ used_dograh_tokens: number; - /** - * Quota Dograh Tokens - */ - quota_dograh_tokens: number; - /** - * Percentage Used - */ - percentage_used: number; - /** - * Next Refresh Date - */ - next_refresh_date: string; - /** - * Quota Enabled - */ - quota_enabled: boolean; /** * Total Duration Seconds */ @@ -1666,10 +1650,6 @@ export type CurrentUsageResponse = { * Used Amount Usd */ used_amount_usd?: number | null; - /** - * Quota Amount Usd - */ - quota_amount_usd?: number | null; /** * Currency */ @@ -3107,6 +3087,165 @@ export type LoginRequest = { password: string; }; +/** + * MPSBillingAccountResponse + */ +export type MpsBillingAccountResponse = { + /** + * Id + */ + id: number; + /** + * Organization Id + */ + organization_id: number; + /** + * Billing Mode + */ + billing_mode: string; + /** + * Cached Balance Credits + */ + cached_balance_credits: number; + /** + * Currency + */ + currency: string; +}; + +/** + * MPSBillingCreditsResponse + */ +export type MpsBillingCreditsResponse = { + /** + * Billing Version + */ + billing_version: 'legacy' | 'v2'; + /** + * Total Credits Used + */ + total_credits_used?: number; + /** + * Remaining Credits + */ + remaining_credits?: number; + /** + * Total Quota + */ + total_quota?: number; + account?: MpsBillingAccountResponse | null; + /** + * Ledger Entries + */ + ledger_entries?: Array; + /** + * Total Count + */ + total_count?: number; + /** + * Page + */ + page?: number; + /** + * Limit + */ + limit?: number; + /** + * Total Pages + */ + total_pages?: number; +}; + +/** + * MPSCreditLedgerEntryResponse + */ +export type MpsCreditLedgerEntryResponse = { + /** + * Id + */ + id: number; + /** + * Entry Type + */ + entry_type: string; + /** + * Origin + */ + origin?: string | null; + /** + * Credits Delta + */ + credits_delta: number; + /** + * Balance After + */ + balance_after: number; + /** + * Amount Minor + */ + amount_minor?: number | null; + /** + * Amount Currency + */ + amount_currency?: string | null; + /** + * Payment Order Id + */ + payment_order_id?: number | null; + /** + * Metric Code + */ + metric_code?: string | null; + /** + * Correlation Id + */ + correlation_id?: string | null; + /** + * Aggregation Key + */ + aggregation_key?: string | null; + /** + * Usage Event Id + */ + usage_event_id?: number | null; + /** + * Workflow Run Id + */ + workflow_run_id?: number | null; + /** + * Workflow Id + */ + workflow_id?: number | null; + /** + * Billable Quantity + */ + billable_quantity?: number | null; + /** + * Quantity Unit + */ + quantity_unit?: string | null; + /** + * Metadata + */ + metadata?: { + [key: string]: unknown; + }; + /** + * Created At + */ + created_at: string; +}; + +/** + * MPSCreditPurchaseUrlResponse + */ +export type MpsCreditPurchaseUrlResponse = { + /** + * Checkout Url + */ + checkout_url: string; +}; + /** * MPSCreditsResponse */ @@ -3618,6 +3757,43 @@ export type OrganizationAiModelConfigurationV2 = { byok?: ByokaiModelConfiguration | null; }; +/** + * OrganizationContextResponse + */ +export type OrganizationContextResponse = { + /** + * Organization Id + */ + organization_id?: number | null; + /** + * Organization Provider Id + */ + organization_provider_id?: string | null; + model_services: OrganizationModelServicesContext; +}; + +/** + * OrganizationModelServicesContext + */ +export type OrganizationModelServicesContext = { + /** + * Config Source + */ + config_source: 'organization_v2' | 'legacy_user_v1' | 'empty'; + /** + * Has Model Configuration V2 + */ + has_model_configuration_v2: boolean; + /** + * Managed Service Version + */ + managed_service_version?: number | null; + /** + * Uses Managed Service V2 + */ + uses_managed_service_v2: boolean; +}; + /** * OrganizationPreferences */ @@ -9750,6 +9926,45 @@ export type UnarchiveToolApiV1ToolsToolUuidUnarchivePostResponses = { export type UnarchiveToolApiV1ToolsToolUuidUnarchivePostResponse = UnarchiveToolApiV1ToolsToolUuidUnarchivePostResponses[keyof UnarchiveToolApiV1ToolsToolUuidUnarchivePostResponses]; +export type GetCurrentOrganizationContextApiV1OrganizationsContextGetData = { + body?: never; + headers?: { + /** + * Authorization + */ + authorization?: string | null; + /** + * X-Api-Key + */ + 'X-API-Key'?: string | null; + }; + path?: never; + query?: never; + url: '/api/v1/organizations/context'; +}; + +export type GetCurrentOrganizationContextApiV1OrganizationsContextGetErrors = { + /** + * Not found + */ + 404: unknown; + /** + * Validation Error + */ + 422: HttpValidationError; +}; + +export type GetCurrentOrganizationContextApiV1OrganizationsContextGetError = GetCurrentOrganizationContextApiV1OrganizationsContextGetErrors[keyof GetCurrentOrganizationContextApiV1OrganizationsContextGetErrors]; + +export type GetCurrentOrganizationContextApiV1OrganizationsContextGetResponses = { + /** + * Successful Response + */ + 200: OrganizationContextResponse; +}; + +export type GetCurrentOrganizationContextApiV1OrganizationsContextGetResponse = GetCurrentOrganizationContextApiV1OrganizationsContextGetResponses[keyof GetCurrentOrganizationContextApiV1OrganizationsContextGetResponses]; + export type GetTelephonyProvidersMetadataApiV1OrganizationsTelephonyProvidersMetadataGetData = { body?: never; headers?: { @@ -11269,6 +11484,93 @@ export type GetMpsCreditsApiV1OrganizationsUsageMpsCreditsGetResponses = { export type GetMpsCreditsApiV1OrganizationsUsageMpsCreditsGetResponse = GetMpsCreditsApiV1OrganizationsUsageMpsCreditsGetResponses[keyof GetMpsCreditsApiV1OrganizationsUsageMpsCreditsGetResponses]; +export type GetBillingCreditsApiV1OrganizationsBillingCreditsGetData = { + body?: never; + headers?: { + /** + * Authorization + */ + authorization?: string | null; + /** + * X-Api-Key + */ + 'X-API-Key'?: string | null; + }; + path?: never; + query?: { + /** + * Page + */ + page?: number; + /** + * Limit + */ + limit?: number; + }; + url: '/api/v1/organizations/billing/credits'; +}; + +export type GetBillingCreditsApiV1OrganizationsBillingCreditsGetErrors = { + /** + * Not found + */ + 404: unknown; + /** + * Validation Error + */ + 422: HttpValidationError; +}; + +export type GetBillingCreditsApiV1OrganizationsBillingCreditsGetError = GetBillingCreditsApiV1OrganizationsBillingCreditsGetErrors[keyof GetBillingCreditsApiV1OrganizationsBillingCreditsGetErrors]; + +export type GetBillingCreditsApiV1OrganizationsBillingCreditsGetResponses = { + /** + * Successful Response + */ + 200: MpsBillingCreditsResponse; +}; + +export type GetBillingCreditsApiV1OrganizationsBillingCreditsGetResponse = GetBillingCreditsApiV1OrganizationsBillingCreditsGetResponses[keyof GetBillingCreditsApiV1OrganizationsBillingCreditsGetResponses]; + +export type CreateMpsCreditPurchaseUrlApiV1OrganizationsUsageMpsCreditsPurchaseUrlPostData = { + body?: never; + headers?: { + /** + * Authorization + */ + authorization?: string | null; + /** + * X-Api-Key + */ + 'X-API-Key'?: string | null; + }; + path?: never; + query?: never; + url: '/api/v1/organizations/usage/mps-credits/purchase-url'; +}; + +export type CreateMpsCreditPurchaseUrlApiV1OrganizationsUsageMpsCreditsPurchaseUrlPostErrors = { + /** + * Not found + */ + 404: unknown; + /** + * Validation Error + */ + 422: HttpValidationError; +}; + +export type CreateMpsCreditPurchaseUrlApiV1OrganizationsUsageMpsCreditsPurchaseUrlPostError = CreateMpsCreditPurchaseUrlApiV1OrganizationsUsageMpsCreditsPurchaseUrlPostErrors[keyof CreateMpsCreditPurchaseUrlApiV1OrganizationsUsageMpsCreditsPurchaseUrlPostErrors]; + +export type CreateMpsCreditPurchaseUrlApiV1OrganizationsUsageMpsCreditsPurchaseUrlPostResponses = { + /** + * Successful Response + */ + 200: MpsCreditPurchaseUrlResponse; +}; + +export type CreateMpsCreditPurchaseUrlApiV1OrganizationsUsageMpsCreditsPurchaseUrlPostResponse = CreateMpsCreditPurchaseUrlApiV1OrganizationsUsageMpsCreditsPurchaseUrlPostResponses[keyof CreateMpsCreditPurchaseUrlApiV1OrganizationsUsageMpsCreditsPurchaseUrlPostResponses]; + export type GetUsageHistoryApiV1OrganizationsUsageRunsGetData = { body?: never; headers?: { diff --git a/ui/src/components/AIModelConfigurationV2Editor.tsx b/ui/src/components/AIModelConfigurationV2Editor.tsx index 13ee2edd..bbe66658 100644 --- a/ui/src/components/AIModelConfigurationV2Editor.tsx +++ b/ui/src/components/AIModelConfigurationV2Editor.tsx @@ -17,7 +17,7 @@ import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from "@ import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs"; import { LANGUAGE_DISPLAY_NAMES } from "@/constants/languages"; -type ModelMode = "dograh" | "byok"; +type ModelMode = "realtime" | "dograh" | "byok"; interface DograhDefaults { voices: string[]; @@ -125,24 +125,35 @@ function effectiveConfigToLegacyShape(config: Record | null): R }; } -function emptyByokInitialConfig(): Record { +function emptyByokInitialConfig(isRealtime: boolean): Record { return { - is_realtime: false, + is_realtime: isRealtime, }; } +// The v2 editor surfaces realtime ("Speech to Speech") and pipeline (BYOK) as +// separate tabs, so each tab gets its own initial config. A tab is pre-filled +// only when the saved (or effective) configuration matches that tab's mode; +// otherwise it starts empty so the other tab's data does not leak across. function getByokInitialConfig( configuration: Record | null, effectiveConfiguration: Record | null, + wantRealtime: boolean, ): Record { - const byokConfiguration = byokConfigToLegacyShape(configuration); - if (byokConfiguration) return byokConfiguration; + const matchesTab = (config: Record | null) => + config ? Boolean(config.is_realtime) === wantRealtime : false; - if (configuration?.mode === "dograh" || isDograhEffectiveConfig(effectiveConfiguration)) { - return emptyByokInitialConfig(); + const byokConfiguration = byokConfigToLegacyShape(configuration); + if (byokConfiguration) { + return matchesTab(byokConfiguration) ? byokConfiguration : emptyByokInitialConfig(wantRealtime); } - return effectiveConfigToLegacyShape(effectiveConfiguration) || emptyByokInitialConfig(); + if (configuration?.mode === "dograh" || isDograhEffectiveConfig(effectiveConfiguration)) { + return emptyByokInitialConfig(wantRealtime); + } + + const effective = effectiveConfigToLegacyShape(effectiveConfiguration); + return matchesTab(effective) ? (effective as Record) : emptyByokInitialConfig(wantRealtime); } function buildDograhState( @@ -185,10 +196,12 @@ function preferredMode( configuration: Record | null, effectiveConfiguration: Record | null, ): ModelMode { - if (configuration?.mode === "dograh" || configuration?.mode === "byok") { - return configuration.mode; + if (configuration?.mode === "dograh") return "dograh"; + if (configuration?.mode === "byok") { + return asRecord(configuration.byok)?.mode === "realtime" ? "realtime" : "byok"; } - return isDograhEffectiveConfig(effectiveConfiguration) ? "dograh" : "byok"; + if (isDograhEffectiveConfig(effectiveConfiguration)) return "dograh"; + return Boolean(effectiveConfiguration?.is_realtime) ? "realtime" : "byok"; } function hasRequiredApiKey( @@ -249,7 +262,8 @@ export function AIModelConfigurationV2Editor({ speed: defaults.dograh.defaults.speed, language: defaults.dograh.defaults.language, })); - const [byokInitialConfig, setByokInitialConfig] = useState | null>(null); + const [realtimeInitialConfig, setRealtimeInitialConfig] = useState | null>(null); + const [pipelineInitialConfig, setPipelineInitialConfig] = useState | null>(null); const [isSavingDograh, setIsSavingDograh] = useState(false); const [error, setError] = useState(null); @@ -258,7 +272,8 @@ export function AIModelConfigurationV2Editor({ const rawEffectiveConfiguration = asRecord(effectiveConfiguration); setMode(preferredMode(rawConfiguration, rawEffectiveConfiguration)); setDograh(buildDograhState(defaults, rawConfiguration, rawEffectiveConfiguration)); - setByokInitialConfig(getByokInitialConfig(rawConfiguration, rawEffectiveConfiguration)); + setRealtimeInitialConfig(getByokInitialConfig(rawConfiguration, rawEffectiveConfiguration, true)); + setPipelineInitialConfig(getByokInitialConfig(rawConfiguration, rawEffectiveConfiguration, false)); }, [configuration, defaults, effectiveConfiguration]); const saveDograhConfiguration = async () => { @@ -322,28 +337,30 @@ export function AIModelConfigurationV2Editor({ )} setMode(value as ModelMode)} className="space-y-6"> - + + Speech to Speech Dograh BYOK + +

+ A single speech-to-speech model handles the conversation in realtime (no separate transcriber or voice). An LLM is still required for variable extraction and QA. +

+ +
+
-
- -
- - setDograh({ ...dograh, api_key: event.target.value })} - placeholder="Enter API key" - /> -
-
-
+ +
+ +
+ + setDograh({ ...dograh, api_key: event.target.value })} + placeholder="Enter API key" + /> +
+