feat: billing and credit management v2 (#429)

* feat: use mps generated correlation ID

* chore: update pipecat submodule

* feat: add credit purchase URL

* feat: carve out billing page and show credit ledger

* feat: deprecate dograh based quota tracking

* fix: remove cost calculation from dograh codebase

* fix: create mps account on migrate to v2

* chore: update pipecat
This commit is contained in:
Abhishek 2026-06-12 14:55:30 +05:30 committed by GitHub
parent 97d7103480
commit 1f1149f4d5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
80 changed files with 3335 additions and 2057 deletions

View file

@ -18,6 +18,9 @@ branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
DEPRECATED_QUOTA_COMMENT = "Deprecated. MPS owns quota and credit ledger state."
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
# 1) Create the `quota_type` enum *before* we add the column that references it.
@ -34,7 +37,12 @@ def upgrade() -> None:
sa.Column("organization_id", sa.Integer(), nullable=False),
sa.Column("period_start", sa.DateTime(), nullable=False),
sa.Column("period_end", sa.DateTime(), nullable=False),
sa.Column("quota_dograh_tokens", sa.Integer(), nullable=False),
sa.Column(
"quota_dograh_tokens",
sa.Integer(),
nullable=False,
comment=DEPRECATED_QUOTA_COMMENT,
),
sa.Column("used_dograh_tokens", sa.Integer(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True),
@ -63,7 +71,11 @@ def upgrade() -> None:
op.add_column(
"organizations",
sa.Column(
"quota_type", quota_type_enum, nullable=False, server_default="monthly"
"quota_type",
quota_type_enum,
nullable=False,
server_default="monthly",
comment=DEPRECATED_QUOTA_COMMENT,
),
)
op.add_column(
@ -73,6 +85,7 @@ def upgrade() -> None:
sa.Integer(),
nullable=False,
server_default=sa.text("0"),
comment=DEPRECATED_QUOTA_COMMENT,
),
)
op.add_column(
@ -82,10 +95,17 @@ def upgrade() -> None:
sa.Integer(),
nullable=False,
server_default=sa.text("LEAST(EXTRACT(DAY FROM CURRENT_DATE)::int, 28)"),
comment=DEPRECATED_QUOTA_COMMENT,
),
)
op.add_column(
"organizations", sa.Column("quota_start_date", sa.DateTime(), nullable=True)
"organizations",
sa.Column(
"quota_start_date",
sa.DateTime(),
nullable=True,
comment=DEPRECATED_QUOTA_COMMENT,
),
)
op.add_column(
"organizations",
@ -94,6 +114,7 @@ def upgrade() -> None:
sa.Boolean(),
nullable=False,
server_default=sa.text("false"),
comment=DEPRECATED_QUOTA_COMMENT,
),
)
# ### end Alembic commands ###

View file

@ -18,6 +18,9 @@ branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
DEPRECATED_QUOTA_COMMENT = "Deprecated. MPS owns quota and credit ledger state."
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
@ -26,7 +29,12 @@ def upgrade() -> None:
)
op.add_column(
"organization_usage_cycles",
sa.Column("quota_amount_usd", sa.Float(), nullable=True),
sa.Column(
"quota_amount_usd",
sa.Float(),
nullable=True,
comment=DEPRECATED_QUOTA_COMMENT,
),
)
# ### end Alembic commands ###

View file

@ -9,6 +9,7 @@ from api.db.base_client import BaseDBClient
from api.db.filters import apply_workflow_run_filters, get_workflow_run_order_clause
from api.db.models import CampaignModel, QueuedRunModel, WorkflowRunModel
from api.schemas.workflow import WorkflowRunResponseSchema
from api.services.workflow.run_usage_response import format_public_cost_info
class CampaignClient(BaseDBClient):
@ -215,26 +216,9 @@ class CampaignClient(BaseDBClient):
"is_completed": run.is_completed,
"recording_url": run.recording_url,
"transcript_url": run.transcript_url,
"cost_info": {
"dograh_token_usage": (
run.cost_info.get("dograh_token_usage")
if run.cost_info
and "dograh_token_usage" in run.cost_info
else round(
float(run.cost_info.get("total_cost_usd", 0)) * 100,
2,
)
if run.cost_info and "total_cost_usd" in run.cost_info
else 0
),
"call_duration_seconds": int(
round(run.cost_info.get("call_duration_seconds") or 0)
)
if run.cost_info
else None,
}
if run.cost_info
else None,
"cost_info": format_public_cost_info(
run.cost_info, run.usage_info
),
"definition_id": run.definition_id,
"initial_context": run.initial_context,
"gathered_context": run.gathered_context,
@ -662,7 +646,7 @@ class CampaignClient(BaseDBClient):
async with self.async_session() as session:
conditions = [
WorkflowRunModel.is_completed.is_(True),
WorkflowRunModel.cost_info["call_duration_seconds"]
WorkflowRunModel.usage_info["call_duration_seconds"]
.as_string()
.isnot(None),
]
@ -685,6 +669,7 @@ class CampaignClient(BaseDBClient):
WorkflowRunModel.initial_context,
WorkflowRunModel.gathered_context,
WorkflowRunModel.cost_info,
WorkflowRunModel.usage_info,
WorkflowRunModel.public_access_token,
)
.where(*conditions)

View file

@ -53,7 +53,7 @@ class DBClient(
- UserClient: handles user and user configuration operations
- OrganizationClient: handles organization operations
- OrganizationConfigurationClient: handles organization configuration operations
- OrganizationUsageClient: handles organization usage and quota operations
- OrganizationUsageClient: handles organization usage reporting aggregates
- IntegrationClient: handles integration operations
- WorkflowTemplateClient: handles workflow template operations
- CampaignClient: handles campaign operations

View file

@ -25,7 +25,7 @@ def get_workflow_run_order_clause(
"""
# Determine sort column
if sort_by == "duration":
sort_column = WorkflowRunModel.cost_info.op("->>")(
sort_column = WorkflowRunModel.usage_info.op("->>")(
"call_duration_seconds"
).cast(Float)
else:
@ -43,7 +43,7 @@ def get_workflow_run_order_clause(
ATTRIBUTE_FIELD_MAPPING = {
"dateRange": "created_at",
"dispositionCode": "gathered_context.mapped_call_disposition",
"duration": "cost_info.call_duration_seconds",
"duration": "usage_info.call_duration_seconds",
"status": "is_completed",
"tokenUsage": "cost_info.total_cost_usd",
"runId": "id",
@ -208,7 +208,7 @@ def apply_workflow_run_filters(
min_val = value.get("min")
max_val = value.get("max")
if field == "cost_info.call_duration_seconds":
if field == "usage_info.call_duration_seconds":
# Use ->> operator for compatibility with all PostgreSQL versions
# (subscript [] only works in PostgreSQL 14+)
duration_text = cast(WorkflowRunModel.usage_info, JSONB).op("->>")(

View file

@ -97,22 +97,44 @@ class OrganizationModel(Base):
provider_id = Column(String, unique=True, index=True, nullable=False)
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
# Quota fields
# Deprecated: MPS owns quota and credit ledger state.
quota_type = Column(
Enum("monthly", "annual", name="quota_type"),
nullable=False,
default="monthly",
server_default=text("'monthly'::quota_type"),
comment="Deprecated. MPS owns quota and credit ledger state.",
info={"deprecated": True},
)
quota_dograh_tokens = Column(
Integer, nullable=False, default=0, server_default=text("0")
Integer,
nullable=False,
default=0,
server_default=text("0"),
comment="Deprecated. MPS owns quota and credit ledger state.",
info={"deprecated": True},
)
quota_reset_day = Column(
Integer, nullable=False, default=1, server_default=text("1")
) # 1-28, only for monthly
quota_start_date = Column(DateTime(timezone=True), nullable=True) # Only for annual
Integer,
nullable=False,
default=1,
server_default=text("1"),
comment="Deprecated. MPS owns quota and credit ledger state.",
info={"deprecated": True},
)
quota_start_date = Column(
DateTime(timezone=True),
nullable=True,
comment="Deprecated. MPS owns quota and credit ledger state.",
info={"deprecated": True},
)
quota_enabled = Column(
Boolean, nullable=False, default=False, server_default=text("false")
Boolean,
nullable=False,
default=False,
server_default=text("false"),
comment="Deprecated. MPS owns quota and credit ledger state.",
info={"deprecated": True},
)
price_per_second_usd = Column(Float, nullable=True)
@ -593,8 +615,9 @@ class WorkflowRunTextSessionModel(Base):
class OrganizationUsageCycleModel(Base):
"""
This model is used to track the usage of Dograh tokens for an organization for a given usage
cycle.
This model is used to track reporting aggregates for an organization for a given
usage cycle. Quota fields on this model are deprecated; MPS owns quota and
credit ledger state.
"""
__tablename__ = "organization_usage_cycles"
@ -603,14 +626,24 @@ class OrganizationUsageCycleModel(Base):
organization_id = Column(Integer, ForeignKey("organizations.id"), nullable=False)
period_start = Column(DateTime(timezone=True), nullable=False)
period_end = Column(DateTime(timezone=True), nullable=False)
quota_dograh_tokens = Column(Integer, nullable=False)
quota_dograh_tokens = Column(
Integer,
nullable=False,
comment="Deprecated. MPS owns quota and credit ledger state.",
info={"deprecated": True},
)
used_dograh_tokens = Column(Float, nullable=False, default=0)
total_duration_seconds = Column(
Integer, nullable=False, default=0, server_default=text("0")
)
# New USD tracking fields
used_amount_usd = Column(Float, nullable=True, default=0)
quota_amount_usd = Column(Float, nullable=True)
quota_amount_usd = Column(
Float,
nullable=True,
comment="Deprecated. MPS owns quota and credit ledger state.",
info={"deprecated": True},
)
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
updated_at = Column(
DateTime(timezone=True),

View file

@ -19,11 +19,11 @@ from api.db.models import (
WorkflowRunModel,
)
from api.enums import OrganizationConfigurationKey
from api.schemas.user_configuration import EffectiveAIModelConfiguration
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
class OrganizationUsageClient(BaseDBClient):
"""Client for managing organization usage and quota operations."""
"""Client for managing organization usage reporting aggregates."""
async def get_or_create_current_cycle(
self, organization_id: int, session=None
@ -49,14 +49,7 @@ class OrganizationUsageClient(BaseDBClient):
self, organization_id: int, session, commit: bool
) -> OrganizationUsageCycleModel:
"""Internal implementation for get_or_create_current_cycle."""
# Get organization to determine quota type
org_result = await session.execute(
select(OrganizationModel).where(OrganizationModel.id == organization_id)
)
org = org_result.scalar_one()
# Calculate current period
period_start, period_end = self._calculate_current_period(org)
period_start, period_end = self._calculate_current_period()
# Try to get existing cycle
cycle_result = await session.execute(
@ -78,7 +71,8 @@ class OrganizationUsageClient(BaseDBClient):
organization_id=organization_id,
period_start=period_start,
period_end=period_end,
quota_dograh_tokens=org.quota_dograh_tokens,
# Deprecated non-null column retained for historical schema compatibility.
quota_dograh_tokens=0,
)
# Handle concurrent inserts gracefully
stmt = stmt.on_conflict_do_nothing(
@ -102,95 +96,9 @@ class OrganizationUsageClient(BaseDBClient):
)
return cycle_result.scalar_one()
async def check_and_reserve_quota(
self, organization_id: int, estimated_tokens: int = 0
) -> bool:
"""
Check if organization has sufficient quota and optionally reserve tokens.
Returns True if quota is available, False otherwise.
This method is fully atomic and safe for concurrent access from multiple processes.
"""
async with self.async_session() as session:
# Get organization
org_result = await session.execute(
select(OrganizationModel).where(OrganizationModel.id == organization_id)
)
org = org_result.scalar_one_or_none()
if not org or not org.quota_enabled:
# No quota enforcement if not enabled
return True
# Get or create current cycle within the same session/transaction
cycle = await self._get_or_create_current_cycle_impl(
organization_id, session, commit=False
)
# Atomic check and update with row-level lock
result = await session.execute(
select(OrganizationUsageCycleModel)
.where(
and_(
OrganizationUsageCycleModel.id == cycle.id,
OrganizationUsageCycleModel.used_dograh_tokens
+ estimated_tokens
<= OrganizationUsageCycleModel.quota_dograh_tokens,
)
)
.with_for_update(skip_locked=False)
)
cycle_locked = result.scalar_one_or_none()
if cycle_locked:
# Update the usage atomically
cycle_locked.used_dograh_tokens += estimated_tokens
await session.commit()
return True
return False
async def update_usage_after_run(
self,
organization_id: int,
actual_tokens: float,
duration_seconds: float = 0,
charge_usd: float | None = None,
) -> None:
"""Update usage after a workflow run completes with actual token count and duration.
This method is fully atomic and safe for concurrent access from multiple processes.
"""
async with self.async_session() as session:
# Get or create current cycle within the same session/transaction
cycle = await self._get_or_create_current_cycle_impl(
organization_id, session, commit=False
)
# Acquire a row-level lock for atomic update
result = await session.execute(
select(OrganizationUsageCycleModel)
.where(OrganizationUsageCycleModel.id == cycle.id)
.with_for_update(skip_locked=False)
)
cycle_locked = result.scalar_one()
# Update usage atomically
cycle_locked.used_dograh_tokens += actual_tokens
cycle_locked.total_duration_seconds += int(round(duration_seconds))
# Update USD amount if provided
if charge_usd is not None:
if cycle_locked.used_amount_usd is None:
cycle_locked.used_amount_usd = 0
cycle_locked.used_amount_usd += charge_usd
await session.commit()
async def get_current_usage(self, organization_id: int) -> dict:
"""Get current period usage information."""
"""Get current reporting-period usage information."""
async with self.async_session() as session:
# Get organization
org_result = await session.execute(
select(OrganizationModel).where(OrganizationModel.id == organization_id)
)
@ -201,42 +109,19 @@ class OrganizationUsageClient(BaseDBClient):
organization_id, session, commit=False
)
# Calculate next refresh date
if org.quota_type == "monthly":
next_refresh = cycle.period_end + relativedelta(days=1)
else: # annual
next_refresh = cycle.period_end + relativedelta(days=1)
result = {
"period_start": cycle.period_start.isoformat(),
"period_end": cycle.period_end.isoformat(),
"used_dograh_tokens": cycle.used_dograh_tokens,
"quota_dograh_tokens": cycle.quota_dograh_tokens,
"percentage_used": (
round(
(cycle.used_dograh_tokens / cycle.quota_dograh_tokens) * 100, 2
)
if cycle.quota_dograh_tokens > 0
else 0
),
"next_refresh_date": next_refresh.date().isoformat(),
"quota_enabled": org.quota_enabled,
"total_duration_seconds": cycle.total_duration_seconds,
}
# Add USD fields if organization has pricing
if org.price_per_second_usd is not None:
result["used_amount_usd"] = cycle.used_amount_usd or 0
result["quota_amount_usd"] = cycle.quota_amount_usd
result["currency"] = "USD"
result["price_per_second_usd"] = org.price_per_second_usd
# Calculate percentage based on USD if available
if cycle.quota_amount_usd and cycle.quota_amount_usd > 0:
result["percentage_used"] = round(
((cycle.used_amount_usd or 0) / cycle.quota_amount_usd) * 100, 2
)
return result
async def get_usage_history(
@ -256,7 +141,7 @@ class OrganizationUsageClient(BaseDBClient):
.join(UserModel, WorkflowModel.user_id == UserModel.id)
.where(
UserModel.selected_organization_id == organization_id,
WorkflowRunModel.cost_info.isnot(None),
WorkflowRunModel.usage_info.isnot(None),
)
.order_by(WorkflowRunModel.created_at.desc())
)
@ -309,19 +194,8 @@ class OrganizationUsageClient(BaseDBClient):
total_tokens = 0
total_duration_seconds = 0
for run in runs:
if run.cost_info:
# Try to get dograh_token_usage first (new format)
dograh_tokens = run.cost_info.get("dograh_token_usage", 0)
# If not present, calculate from total_cost_usd (old format)
if dograh_tokens == 0 and "total_cost_usd" in run.cost_info:
dograh_tokens = round(
float(run.cost_info["total_cost_usd"]) * 100, 2
)
# Get call duration
call_duration = run.cost_info.get("call_duration_seconds", 0)
else:
dograh_tokens = 0
call_duration = 0
dograh_tokens = 0
call_duration = (run.usage_info or {}).get("call_duration_seconds", 0)
total_tokens += dograh_tokens
total_duration_seconds += int(round(call_duration))
@ -395,13 +269,14 @@ class OrganizationUsageClient(BaseDBClient):
WorkflowRunModel.initial_context,
WorkflowRunModel.gathered_context,
WorkflowRunModel.cost_info,
WorkflowRunModel.usage_info,
WorkflowRunModel.public_access_token,
)
.join(WorkflowModel, WorkflowRunModel.workflow_id == WorkflowModel.id)
.join(UserModel, WorkflowModel.user_id == UserModel.id)
.where(
UserModel.selected_organization_id == organization_id,
WorkflowRunModel.cost_info.isnot(None),
WorkflowRunModel.usage_info.isnot(None),
)
.order_by(WorkflowRunModel.created_at.desc())
)
@ -473,11 +348,11 @@ class OrganizationUsageClient(BaseDBClient):
)
config_obj = config_result.scalar_one_or_none()
if config_obj and config_obj.configuration:
user_config = EffectiveAIModelConfiguration.model_validate(
effective_config = EffectiveAIModelConfiguration.model_validate(
config_obj.configuration
)
if user_config.timezone and user_timezone == "UTC":
user_timezone = user_config.timezone
if effective_config.timezone and user_timezone == "UTC":
user_timezone = effective_config.timezone
# Validate timezone string
try:
@ -496,7 +371,7 @@ class OrganizationUsageClient(BaseDBClient):
select(
date_expr.label("date"),
func.sum(
WorkflowRunModel.cost_info["call_duration_seconds"].as_float()
WorkflowRunModel.usage_info["call_duration_seconds"].as_float()
).label("total_seconds"),
func.count(WorkflowRunModel.id).label("call_count"),
)
@ -545,83 +420,11 @@ class OrganizationUsageClient(BaseDBClient):
"currency": "USD",
}
async def update_organization_quota(
self,
organization_id: int,
quota_type: str,
quota_dograh_tokens: int,
quota_reset_day: Optional[int] = None,
quota_start_date: Optional[datetime] = None,
) -> OrganizationModel:
"""Update organization quota settings."""
async with self.async_session() as session:
result = await session.execute(
select(OrganizationModel).where(OrganizationModel.id == organization_id)
)
org = result.scalar_one()
org.quota_type = quota_type
org.quota_dograh_tokens = quota_dograh_tokens
org.quota_enabled = True
if quota_type == "monthly" and quota_reset_day:
org.quota_reset_day = quota_reset_day
elif quota_type == "annual" and quota_start_date:
org.quota_start_date = quota_start_date
await session.commit()
await session.refresh(org)
return org
def _calculate_current_period(
self, org: OrganizationModel
) -> tuple[datetime, datetime]:
"""Calculate the current billing period based on organization settings."""
def _calculate_current_period(self) -> tuple[datetime, datetime]:
"""Calculate the current calendar-month reporting period."""
now = datetime.now(timezone.utc)
if org.quota_type == "monthly":
# Find the start of the current billing month
reset_day = org.quota_reset_day
# Handle month boundaries
if now.day >= reset_day:
period_start = now.replace(
day=reset_day, hour=0, minute=0, second=0, microsecond=0
)
else:
# Previous month
period_start = (now - relativedelta(months=1)).replace(
day=reset_day, hour=0, minute=0, second=0, microsecond=0
)
# End is one month later minus 1 second
period_end = (
period_start + relativedelta(months=1) - relativedelta(seconds=1)
)
else: # annual
if not org.quota_start_date:
# Default to calendar year
period_start = now.replace(
month=1, day=1, hour=0, minute=0, second=0, microsecond=0
)
period_end = (
period_start + relativedelta(years=1) - relativedelta(seconds=1)
)
else:
# Find current annual period
start_date = org.quota_start_date.replace(tzinfo=timezone.utc)
years_diff = now.year - start_date.year
# Adjust for whether we've passed the anniversary
if now.month < start_date.month or (
now.month == start_date.month and now.day < start_date.day
):
years_diff -= 1
period_start = start_date + relativedelta(years=years_diff)
period_end = (
period_start + relativedelta(years=1) - relativedelta(seconds=1)
)
period_start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
period_end = period_start + relativedelta(months=1) - relativedelta(seconds=1)
return period_start, period_end

View file

@ -8,7 +8,7 @@ from sqlalchemy.future import select
from api.db.base_client import BaseDBClient
from api.db.models import UserConfigurationModel, UserModel
from api.schemas.user_configuration import EffectiveAIModelConfiguration
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
class UserClient(BaseDBClient):

View file

@ -16,6 +16,7 @@ from api.db.models import (
)
from api.enums import CallType, StorageBackend
from api.schemas.workflow import WorkflowRunResponseSchema
from api.services.workflow.run_usage_response import format_public_cost_info
class WorkflowRunClient(BaseDBClient):
@ -312,26 +313,9 @@ class WorkflowRunClient(BaseDBClient):
"is_completed": run.is_completed,
"recording_url": run.recording_url,
"transcript_url": run.transcript_url,
"cost_info": {
"dograh_token_usage": (
run.cost_info.get("dograh_token_usage")
if run.cost_info
and "dograh_token_usage" in run.cost_info
else round(
float(run.cost_info.get("total_cost_usd", 0)) * 100,
2,
)
if run.cost_info and "total_cost_usd" in run.cost_info
else 0
),
"call_duration_seconds": int(
round(run.cost_info.get("call_duration_seconds") or 0)
)
if run.cost_info
else None,
}
if run.cost_info
else None,
"cost_info": format_public_cost_info(
run.cost_info, run.usage_info
),
"definition_id": run.definition_id,
"initial_context": run.initial_context,
"gathered_context": run.gathered_context,

View file

@ -384,7 +384,7 @@ async def search_chunks(
user_id=user.id,
organization_id=user.selected_organization_id,
)
user_config = resolved_config.effective
effective_config = resolved_config.effective
embeddings_api_key = None
embeddings_model = None
embeddings_provider = None
@ -392,17 +392,17 @@ async def search_chunks(
embeddings_endpoint = None
embeddings_api_version = None
if user_config.embeddings:
embeddings_api_key = user_config.embeddings.api_key
embeddings_model = user_config.embeddings.model
embeddings_provider = getattr(user_config.embeddings, "provider", None)
embeddings_endpoint = getattr(user_config.embeddings, "endpoint", None)
if effective_config.embeddings:
embeddings_api_key = effective_config.embeddings.api_key
embeddings_model = effective_config.embeddings.model
embeddings_provider = getattr(effective_config.embeddings, "provider", None)
embeddings_endpoint = getattr(effective_config.embeddings, "endpoint", None)
embeddings_base_url = apply_managed_embeddings_base_url(
provider=embeddings_provider,
base_url=getattr(user_config.embeddings, "base_url", None),
base_url=getattr(effective_config.embeddings, "base_url", None),
)
embeddings_api_version = getattr(
user_config.embeddings, "api_version", None
effective_config.embeddings, "api_version", None
)
# Initialize embedding service based on provider

View file

@ -5,7 +5,11 @@ from loguru import logger
from pydantic import BaseModel
from sqlalchemy.exc import IntegrityError
from api.constants import DEFAULT_CAMPAIGN_RETRY_CONFIG, DEFAULT_ORG_CONCURRENCY_LIMIT
from api.constants import (
DEFAULT_CAMPAIGN_RETRY_CONFIG,
DEFAULT_ORG_CONCURRENCY_LIMIT,
DEPLOYMENT_MODE,
)
from api.db import db_client
from api.db.models import UserModel
from api.db.telephony_configuration_client import TelephonyConfigurationInUseError
@ -55,6 +59,11 @@ from api.services.configuration.registry import (
ServiceProviders,
ServiceType,
)
from api.services.mps_billing import ensure_hosted_mps_billing_account_v2
from api.services.organization_context import (
OrganizationContextResponse,
get_organization_context,
)
from api.services.organization_preferences import (
get_organization_preferences,
upsert_organization_preferences,
@ -129,6 +138,12 @@ class TelephonyConfigWarningsResponse(BaseModel):
telnyx_missing_webhook_public_key_count: int
@router.get("/context", response_model=OrganizationContextResponse)
async def get_current_organization_context(user: UserModel = Depends(get_user)):
"""Return organization-scoped configuration signals owned by Dograh."""
return await get_organization_context(user)
@router.get(
"/telephony-providers/metadata",
response_model=TelephonyProvidersMetadataResponse,
@ -349,6 +364,23 @@ async def migrate_model_configuration_v2(
except ValueError as exc:
raise HTTPException(status_code=422, detail=exc.args[0])
if DEPLOYMENT_MODE != "oss":
try:
await ensure_hosted_mps_billing_account_v2(
organization_id,
created_by=str(user.provider_id),
)
except Exception as exc:
logger.error(
"Failed to initialize MPS billing v2 account for organization {}: {}",
organization_id,
exc,
)
raise HTTPException(
status_code=502,
detail="Failed to initialize MPS billing v2 account",
)
await upsert_organization_ai_model_configuration_v2(
organization_id,
configuration,

View file

@ -1,16 +1,16 @@
import json
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Literal, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import StreamingResponse
from loguru import logger
from pydantic import BaseModel, Field
from api.constants import DEPLOYMENT_MODE
from api.constants import DEPLOYMENT_MODE, UI_APP_URL
from api.db import db_client
from api.db.models import UserModel
from api.services.auth.depends import get_user
from api.services.auth.depends import get_user, get_user_with_selected_organization
from api.services.mps_service_key_client import mps_service_key_client
from api.services.reports import generate_usage_runs_report_csv
from api.utils.artifacts import artifact_url
@ -22,14 +22,8 @@ class CurrentUsageResponse(BaseModel):
period_start: str
period_end: str
used_dograh_tokens: float
quota_dograh_tokens: int
percentage_used: float
next_refresh_date: str
quota_enabled: bool
total_duration_seconds: int
# New USD fields
used_amount_usd: Optional[float] = None
quota_amount_usd: Optional[float] = None
currency: Optional[str] = None
price_per_second_usd: Optional[float] = None
@ -40,6 +34,61 @@ class MPSCreditsResponse(BaseModel):
total_quota: float
class MPSCreditPurchaseUrlResponse(BaseModel):
checkout_url: str
class MPSBillingAccountResponse(BaseModel):
id: int
organization_id: int
billing_mode: str
cached_balance_credits: float
currency: str
class MPSCreditLedgerEntryResponse(BaseModel):
id: int
entry_type: str
origin: Optional[str] = None
credits_delta: float
balance_after: float
amount_minor: Optional[int] = None
amount_currency: Optional[str] = None
payment_order_id: Optional[int] = None
metric_code: Optional[str] = None
correlation_id: Optional[str] = None
aggregation_key: Optional[str] = None
usage_event_id: Optional[int] = None
workflow_run_id: Optional[int] = None
workflow_id: Optional[int] = None
billable_quantity: Optional[float] = None
quantity_unit: Optional[str] = None
metadata: Dict[str, Any] = Field(default_factory=dict)
created_at: str
class MPSBillingCreditsResponse(BaseModel):
billing_version: Literal["legacy", "v2"]
total_credits_used: float = 0.0
remaining_credits: float = 0.0
total_quota: float = 0.0
account: Optional[MPSBillingAccountResponse] = None
ledger_entries: List[MPSCreditLedgerEntryResponse] = Field(default_factory=list)
total_count: int = 0
page: int = 1
limit: int = 50
total_pages: int = 0
def _optional_int(value: Any) -> Optional[int]:
if value is None:
return None
try:
return int(value)
except (TypeError, ValueError):
return None
class WorkflowRunUsageResponse(BaseModel):
id: int
workflow_id: int
@ -97,7 +146,7 @@ class DailyUsageBreakdownResponse(BaseModel):
@router.get("/usage/current-period", response_model=CurrentUsageResponse)
async def get_current_period_usage(user: UserModel = Depends(get_user)):
"""Get current billing period usage for the user's organization."""
"""Get current reporting-period usage for the user's organization."""
if not user.selected_organization_id:
raise HTTPException(status_code=400, detail="No organization selected")
@ -142,6 +191,202 @@ async def get_mps_credits(user: UserModel = Depends(get_user)):
raise HTTPException(status_code=500, detail=str(e))
async def _get_mps_billing_account_status(
user: UserModel, organization_id: int
) -> Optional[dict]:
return await mps_service_key_client.get_billing_account_status(
organization_id=organization_id,
created_by=str(user.provider_id),
)
def _is_mps_billing_v2(account: Optional[dict]) -> bool:
return bool(account and account.get("billing_mode") == "v2")
async def _legacy_mps_credits_response(user: UserModel) -> MPSBillingCreditsResponse:
if DEPLOYMENT_MODE == "oss":
usage = await mps_service_key_client.get_usage_by_created_by(
str(user.provider_id)
)
else:
if not user.selected_organization_id:
raise HTTPException(status_code=400, detail="No organization selected")
usage = await mps_service_key_client.get_usage_by_organization(
user.selected_organization_id
)
total_used = float(usage.get("total_credits_used", 0.0))
total_remaining = float(usage.get("remaining_credits", 0.0))
return MPSBillingCreditsResponse(
billing_version="legacy",
total_credits_used=total_used,
remaining_credits=total_remaining,
total_quota=total_used + total_remaining,
)
@router.get("/billing/credits", response_model=MPSBillingCreditsResponse)
async def get_billing_credits(
page: int = Query(1, ge=1),
limit: int = Query(50, ge=1, le=100),
user: UserModel = Depends(get_user),
):
"""Return legacy MPS credits or paginated v2 billing ledger details for the org."""
try:
if DEPLOYMENT_MODE == "oss" or not user.selected_organization_id:
return await _legacy_mps_credits_response(user)
organization_id = user.selected_organization_id
account_status = await _get_mps_billing_account_status(user, organization_id)
if not _is_mps_billing_v2(account_status):
return await _legacy_mps_credits_response(user)
ledger = await mps_service_key_client.get_credit_ledger(
organization_id=organization_id,
page=page,
limit=limit,
created_by=str(user.provider_id),
)
account = ledger.get("account") or {}
ledger_entries = ledger.get("ledger_entries") or []
total_count = int(ledger.get("total_count") or len(ledger_entries))
response_limit = int(ledger.get("limit") or limit)
total_pages = int(
ledger.get("total_pages")
or ((total_count + response_limit - 1) // response_limit)
)
workflow_ids_by_run_id: dict[int, int] = {}
workflow_run_ids = {
workflow_run_id
for entry in ledger_entries
if (workflow_run_id := _optional_int(entry.get("workflow_run_id")))
is not None
}
for workflow_run_id in workflow_run_ids:
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
if (
workflow_run
and workflow_run.workflow
and workflow_run.workflow.organization_id == organization_id
):
workflow_ids_by_run_id[workflow_run_id] = workflow_run.workflow_id
balance = float(account.get("cached_balance_credits") or 0.0)
total_debits = sum(
abs(float(entry.get("credits_delta") or 0.0))
for entry in ledger_entries
if float(entry.get("credits_delta") or 0.0) < 0
)
if ledger.get("total_debits_credits") is not None:
total_debits = float(ledger["total_debits_credits"])
return MPSBillingCreditsResponse(
billing_version="v2",
total_credits_used=total_debits,
remaining_credits=balance,
total_quota=balance + total_debits,
account=MPSBillingAccountResponse(
id=int(account["id"]),
organization_id=int(account["organization_id"]),
billing_mode=str(account["billing_mode"]),
cached_balance_credits=balance,
currency=str(account.get("currency") or "USD"),
),
ledger_entries=[
MPSCreditLedgerEntryResponse(
id=int(entry["id"]),
entry_type=str(entry["entry_type"]),
origin=entry.get("origin"),
credits_delta=float(entry.get("credits_delta") or 0.0),
balance_after=float(entry.get("balance_after") or 0.0),
amount_minor=entry.get("amount_minor"),
amount_currency=entry.get("amount_currency"),
payment_order_id=entry.get("payment_order_id"),
metric_code=entry.get("metric_code"),
correlation_id=entry.get("correlation_id"),
aggregation_key=entry.get("aggregation_key"),
usage_event_id=_optional_int(entry.get("usage_event_id")),
workflow_run_id=_optional_int(entry.get("workflow_run_id")),
workflow_id=workflow_ids_by_run_id.get(
_optional_int(entry.get("workflow_run_id"))
)
if entry.get("workflow_run_id") is not None
else None,
billable_quantity=float(entry["billable_quantity"])
if entry.get("billable_quantity") is not None
else None,
quantity_unit=entry.get("quantity_unit"),
metadata=entry.get("metadata") or {},
created_at=str(entry["created_at"]),
)
for entry in ledger_entries
],
total_count=total_count,
page=int(ledger.get("page") or page),
limit=response_limit,
total_pages=total_pages,
)
except HTTPException:
raise
except Exception as exc:
logger.error(f"Failed to fetch billing credits: {exc}")
raise HTTPException(status_code=500, detail=str(exc))
@router.post(
"/usage/mps-credits/purchase-url",
response_model=MPSCreditPurchaseUrlResponse,
)
async def create_mps_credit_purchase_url(
user: UserModel = Depends(get_user_with_selected_organization),
):
"""Create a checkout URL for organizations using Dograh-managed MPS v2."""
if DEPLOYMENT_MODE == "oss":
raise HTTPException(
status_code=404,
detail="Credit purchases are not available in OSS mode",
)
organization_id = user.selected_organization_id
assert organization_id is not None
account_status = await _get_mps_billing_account_status(user, organization_id)
if not _is_mps_billing_v2(account_status):
raise HTTPException(
status_code=403,
detail=(
"Credit purchases are available only for organizations using billing v2"
),
)
try:
session = await mps_service_key_client.create_credit_purchase_url(
organization_id=organization_id,
created_by=str(user.provider_id),
return_url=f"{UI_APP_URL.rstrip('/')}/billing",
billing_details={
"source": "dograh_billing",
"dograh_user_id": str(user.id),
"dograh_provider_id": str(user.provider_id),
},
)
except Exception as exc:
logger.error(f"Failed to create MPS credit purchase URL: {exc}")
raise HTTPException(
status_code=502,
detail="Failed to create credit purchase URL",
)
checkout_url = session.get("checkout_url")
if not checkout_url:
logger.error(f"MPS checkout session response missing checkout_url: {session}")
raise HTTPException(
status_code=502,
detail="MPS checkout session response missing checkout_url",
)
return MPSCreditPurchaseUrlResponse(checkout_url=checkout_url)
FILTERS_DESCRIPTION = """\
JSON-encoded array of filter objects. Each object has the shape:

View file

@ -41,12 +41,15 @@ from api.services.configuration.resolve import (
)
from api.services.mps_service_key_client import mps_service_key_client
from api.services.posthog_client import capture_event
from api.services.pricing.run_usage_response import format_public_usage_info
from api.services.reports import generate_workflow_report_csv
from api.services.storage import storage_fs
from api.services.workflow.dto import ReactFlowDTO, sanitize_workflow_definition
from api.services.workflow.duplicate import duplicate_workflow
from api.services.workflow.errors import ItemKind, WorkflowError
from api.services.workflow.run_usage_response import (
format_public_cost_info,
format_public_usage_info,
)
from api.services.workflow.trigger_paths import (
TriggerPathIssue,
ensure_trigger_paths,
@ -1053,13 +1056,15 @@ async def update_workflow(
user_id=user.id,
organization_id=user.selected_organization_id,
)
user_config = resolved_config.effective
effective_config = resolved_config.effective
try:
enriched_overrides = enrich_overrides_with_api_keys(
workflow_configurations["model_overrides"],
user_config,
effective_config,
)
effective = resolve_effective_config(
effective_config, enriched_overrides
)
effective = resolve_effective_config(user_config, enriched_overrides)
if resolved_config.source == "organization_v2":
v2_override = convert_legacy_ai_model_configuration_to_v2(effective)
await UserConfigurationValidator().validate(
@ -1264,22 +1269,7 @@ async def get_workflow_run(
"transcript_public_url": artifact_url(public_access_token, "transcript"),
"recording_public_url": artifact_url(public_access_token, "recording"),
"public_access_token": public_access_token,
"cost_info": {
"dograh_token_usage": (
run.cost_info.get("dograh_token_usage")
if run.cost_info and "dograh_token_usage" in run.cost_info
else round(float(run.cost_info.get("total_cost_usd", 0)) * 100, 2)
if run.cost_info and "total_cost_usd" in run.cost_info
else 0
),
"call_duration_seconds": int(
round(run.cost_info.get("call_duration_seconds"))
)
if run.cost_info and run.cost_info.get("call_duration_seconds") is not None
else None,
}
if run.cost_info
else None,
"cost_info": format_public_cost_info(run.cost_info, run.usage_info),
"usage_info": format_public_usage_info(run.usage_info),
"created_at": run.created_at,
"definition_id": run.definition_id,

View file

@ -1,10 +1,10 @@
from __future__ import annotations
from datetime import datetime
from typing import Literal
from pydantic import BaseModel, Field, model_validator
from api.schemas.user_configuration import EffectiveAIModelConfiguration
from api.services.configuration.registry import (
DograhEmbeddingsConfiguration,
DograhLLMService,
@ -23,6 +23,29 @@ DOGRAH_DEFAULT_VOICE = "default"
DOGRAH_DEFAULT_LANGUAGE = "multi"
class EffectiveAIModelConfiguration(BaseModel):
llm: LLMConfig | None = None
stt: STTConfig | None = None
tts: TTSConfig | None = None
embeddings: EmbeddingsConfig | None = None
realtime: RealtimeConfig | None = None
is_realtime: bool = False
managed_service_version: int | None = None
test_phone_number: str | None = None
timezone: str | None = None
last_validated_at: datetime | None = None
@model_validator(mode="before")
@classmethod
def strip_incomplete_realtime_when_disabled(cls, data):
"""Skip realtime validation when is_realtime is False and api_key is missing."""
if isinstance(data, dict) and not data.get("is_realtime", False):
realtime = data.get("realtime")
if isinstance(realtime, dict) and not realtime.get("api_key"):
data.pop("realtime", None)
return data
class DograhManagedAIModelConfiguration(BaseModel):
api_key: str
voice: str = DOGRAH_DEFAULT_VOICE
@ -160,6 +183,7 @@ def _compile_dograh_configuration(
model="default",
),
is_realtime=False,
managed_service_version=2,
)

View file

@ -1,33 +0,0 @@
from datetime import datetime
from pydantic import BaseModel, model_validator
from api.services.configuration.registry import (
EmbeddingsConfig,
LLMConfig,
RealtimeConfig,
STTConfig,
TTSConfig,
)
class EffectiveAIModelConfiguration(BaseModel):
llm: LLMConfig | None = None
stt: STTConfig | None = None
tts: TTSConfig | None = None
embeddings: EmbeddingsConfig | None = None
realtime: RealtimeConfig | None = None
is_realtime: bool = False
test_phone_number: str | None = None
timezone: str | None = None
last_validated_at: datetime | None = None
@model_validator(mode="before")
@classmethod
def strip_incomplete_realtime_when_disabled(cls, data):
"""Skip realtime validation when is_realtime is False and api_key is missing."""
if isinstance(data, dict) and not data.get("is_realtime", False):
realtime = data.get("realtime")
if isinstance(realtime, dict) and not realtime.get("api_key"):
data.pop("realtime", None)
return data

View file

@ -9,9 +9,10 @@ from api.constants import AUTH_PROVIDER, DOGRAH_MPS_SECRET_KEY, MPS_API_URL
from api.db import db_client
from api.db.models import UserModel
from api.enums import PostHogEvent
from api.schemas.user_configuration import EffectiveAIModelConfiguration
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
from api.services.auth.stack_auth import stackauth
from api.services.configuration.registry import ServiceProviders
from api.services.mps_billing import ensure_hosted_mps_billing_account_v2
from api.services.posthog_client import capture_event
from api.utils.auth import decode_jwt_token
@ -110,6 +111,19 @@ async def get_user(
# This prevents race conditions where multiple concurrent requests
# might try to create configurations
if org_was_created:
try:
await ensure_hosted_mps_billing_account_v2(
organization.id,
created_by=str(stack_user["id"]),
)
except Exception:
logger.warning(
"Failed to initialize hosted MPS billing account for "
"organization {}",
organization.id,
exc_info=True,
)
existing_cfg = await db_client.get_user_configurations(user_model.id)
if not (existing_cfg.llm or existing_cfg.tts or existing_cfg.stt):
mps_config = await create_user_configuration_with_mps_key(
@ -232,7 +246,7 @@ async def create_user_configuration_with_mps_key(
response = await client.post(
f"{MPS_API_URL}/api/v1/service-keys/",
json={
"name": f"Default Dograh Model Service Key",
"name": "Default Dograh Model Service Key",
"description": "Auto-generated key for OSS user",
"expires_in_days": 7, # Short-lived for OSS
"created_by": user_provider_id,
@ -250,7 +264,7 @@ async def create_user_configuration_with_mps_key(
response = await client.post(
f"{MPS_API_URL}/api/v1/service-keys/",
json={
"name": f"Default Dograh Model Service Key",
"name": "Default Dograh Model Service Key",
"description": f"Auto-generated key for organization {organization_id}",
"organization_id": organization_id,
"expires_in_days": 90, # Longer-lived for authenticated users
@ -285,8 +299,8 @@ async def create_user_configuration_with_mps_key(
"model": "default",
},
}
user_config = EffectiveAIModelConfiguration(**configuration)
return user_config
effective_config = EffectiveAIModelConfiguration(**configuration)
return effective_config
else:
logger.warning(
f"Failed to get MPS service key: {response.status_code} - {response.text}"

View file

@ -21,10 +21,10 @@ from api.schemas.ai_model_configuration import (
BYOKPipelineAIModelConfiguration,
BYOKRealtimeAIModelConfiguration,
DograhManagedAIModelConfiguration,
EffectiveAIModelConfiguration,
OrganizationAIModelConfigurationV2,
compile_ai_model_configuration_v2,
)
from api.schemas.user_configuration import EffectiveAIModelConfiguration
from api.services.configuration.masking import (
SERVICE_SECRET_FIELDS,
contains_masked_key,

View file

@ -8,7 +8,7 @@ from groq import Groq
# from pyneuphonic import Neuphonic
# except ImportError:
# Neuphonic = None
from api.schemas.user_configuration import (
from api.schemas.ai_model_configuration import (
EffectiveAIModelConfiguration,
)
from api.services.configuration.registry import ServiceConfig, ServiceProviders
@ -75,21 +75,21 @@ class UserConfigurationValidator:
status_list = []
status_list.extend(self._validate_service(configuration.llm, "llm"))
status_list.extend(self._validate_service(configuration.stt, "stt"))
status_list.extend(self._validate_service(configuration.tts, "tts"))
# Embeddings is optional - only validate if configured
status_list.extend(
self._validate_service(
configuration.embeddings, "embeddings", required=False
)
)
# Realtime is optional - only validate if is_realtime is enabled
if configuration.is_realtime:
status_list.extend(
self._validate_service(
configuration.realtime, "realtime", required=True
)
)
else:
status_list.extend(self._validate_service(configuration.stt, "stt"))
status_list.extend(self._validate_service(configuration.tts, "tts"))
# Embeddings is optional - only validate if configured
status_list.extend(
self._validate_service(
configuration.embeddings, "embeddings", required=False
)
)
if status_list:
raise ValueError(status_list)

View file

@ -12,7 +12,7 @@ The rules are simple:
import copy
from typing import Any, Dict, Optional
from api.schemas.user_configuration import EffectiveAIModelConfiguration
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
from api.services.configuration.registry import ServiceConfig
from api.services.integrations import get_node_secret_fields

View file

@ -7,7 +7,7 @@ stored, while honouring masked API keys.
import copy
from typing import Dict
from api.schemas.user_configuration import EffectiveAIModelConfiguration
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
from api.services.configuration.masking import (
MODEL_OVERRIDE_FIELDS,
SERVICE_SECRET_FIELDS,

View file

@ -4,7 +4,7 @@ from __future__ import annotations
import copy
from api.schemas.user_configuration import EffectiveAIModelConfiguration
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
from api.services.configuration.registry import (
REGISTRY,
ServiceType,

View file

@ -38,6 +38,7 @@ class OpenAIEmbeddingService(BaseEmbeddingService):
api_key: Optional[str] = None,
model_id: str = DEFAULT_MODEL_ID,
base_url: Optional[str] = None,
default_headers: Optional[Dict[str, str]] = None,
):
"""Initialize the OpenAI embedding service.
@ -60,6 +61,8 @@ class OpenAIEmbeddingService(BaseEmbeddingService):
field_name="base_url",
)
client_kwargs["base_url"] = base_url
if default_headers:
client_kwargs["default_headers"] = default_headers
self.client = AsyncOpenAI(**client_kwargs)
logger.info(f"OpenAI embedding service initialized with model: {model_id}")
else:

View file

@ -0,0 +1,98 @@
from __future__ import annotations
from typing import Any
from loguru import logger
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
from api.services.configuration.registry import ServiceProviders
from api.services.mps_service_key_client import mps_service_key_client
MPS_CORRELATION_ID_CONTEXT_KEY = "mps_correlation_id"
def uses_managed_model_services_v2(
ai_model_config: EffectiveAIModelConfiguration | None,
) -> bool:
if (
ai_model_config is None
or getattr(ai_model_config, "managed_service_version", None) != 2
):
return False
return any(
_is_dograh_service(getattr(ai_model_config, section_name, None))
for section_name in ("llm", "tts", "stt", "embeddings")
)
def get_mps_correlation_id(initial_context: dict[str, Any] | None) -> str | None:
if not initial_context:
return None
correlation_id = initial_context.get(MPS_CORRELATION_ID_CONTEXT_KEY)
if correlation_id is None:
return None
return str(correlation_id)
async def ensure_mps_correlation_id(
*,
ai_model_config: EffectiveAIModelConfiguration,
workflow_run_id: int,
initial_context: dict[str, Any] | None,
) -> str | None:
existing = get_mps_correlation_id(initial_context)
if existing:
return existing
if not uses_managed_model_services_v2(ai_model_config):
return None
service_key = _get_dograh_service_api_key(ai_model_config)
if not service_key:
raise ValueError(
"Managed model services v2 requires a Dograh service key before the run starts."
)
response = await mps_service_key_client.create_correlation_id(
service_key=service_key,
workflow_run_id=workflow_run_id,
)
correlation_id = response.get("correlation_id")
if not correlation_id:
raise ValueError("MPS correlation-id response did not include correlation_id")
correlation_id = str(correlation_id)
logger.info(
"Minted MPS correlation id {} for workflow run {}",
correlation_id,
workflow_run_id,
)
return correlation_id
def _is_dograh_service(service: Any) -> bool:
provider = getattr(service, "provider", None)
return (
provider == ServiceProviders.DOGRAH or provider == ServiceProviders.DOGRAH.value
)
def _get_dograh_service_api_key(
ai_model_config: EffectiveAIModelConfiguration,
) -> str | None:
for section_name in ("llm", "tts", "stt", "embeddings"):
service = getattr(ai_model_config, section_name, None)
if not _is_dograh_service(service):
continue
if hasattr(service, "get_all_api_keys"):
keys = service.get_all_api_keys()
if keys:
return keys[0]
api_key = getattr(service, "api_key", None)
if isinstance(api_key, str) and api_key:
return api_key
return None

View file

@ -0,0 +1,23 @@
from typing import Optional
from api.constants import DEPLOYMENT_MODE
from api.services.mps_service_key_client import mps_service_key_client
async def ensure_hosted_mps_billing_account_v2(
organization_id: int,
*,
created_by: Optional[str] = None,
) -> Optional[dict]:
"""Ensure hosted orgs have an MPS billing v2 account.
OSS deployments use legacy per-key quota accounting and do not create MPS
billing accounts.
"""
if DEPLOYMENT_MODE == "oss":
return None
return await mps_service_key_client.ensure_billing_account_v2(
organization_id=organization_id,
created_by=created_by,
)

View file

@ -4,6 +4,7 @@ This client communicates with the Model Proxy Service (MPS) for service key mana
Service keys are stored and managed entirely in MPS, not in the local database.
"""
import asyncio
from typing import List, Optional
import httpx
@ -353,6 +354,234 @@ class MPSServiceKeyClient:
response=response,
)
async def create_credit_purchase_url(
self,
organization_id: int,
created_by: Optional[str] = None,
return_url: Optional[str] = None,
billing_details: Optional[dict] = None,
) -> dict:
"""Create a short-lived MPS checkout URL for adding organization credits."""
payload = {
"created_by": created_by,
"return_url": return_url,
"billing_details": billing_details or {},
}
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.post(
f"{self.base_url}/api/v1/billing/accounts/{organization_id}/checkout-sessions",
json=payload,
headers=self._get_headers(
organization_id=organization_id,
created_by=created_by,
),
)
if response.status_code == 200:
return response.json()
logger.error(
"Failed to create MPS credit purchase URL: "
f"{response.status_code} - {response.text}"
)
raise httpx.HTTPStatusError(
f"Failed to create MPS credit purchase URL: {response.text}",
request=response.request,
response=response,
)
async def get_credit_ledger(
self,
organization_id: int,
page: int = 1,
limit: int = 50,
created_by: Optional[str] = None,
) -> dict:
"""Get the MPS v2 billing account balance and recent credit ledger."""
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.get(
f"{self.base_url}/api/v1/billing/accounts/{organization_id}/ledger",
params={"page": page, "limit": limit},
headers=self._get_headers(
organization_id=organization_id,
created_by=created_by,
),
)
if response.status_code == 200:
return response.json()
logger.error(
"Failed to get MPS credit ledger: "
f"{response.status_code} - {response.text}"
)
raise httpx.HTTPStatusError(
f"Failed to get MPS credit ledger: {response.text}",
request=response.request,
response=response,
)
async def get_billing_account_status(
self,
organization_id: int,
created_by: Optional[str] = None,
) -> Optional[dict]:
"""Get an existing MPS v2 billing account without creating one."""
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.get(
f"{self.base_url}/api/v1/billing/accounts/{organization_id}/status",
headers=self._get_headers(
organization_id=organization_id,
created_by=created_by,
),
)
if response.status_code == 200:
return response.json()
logger.error(
"Failed to get MPS billing account status: "
f"{response.status_code} - {response.text}"
)
raise httpx.HTTPStatusError(
f"Failed to get MPS billing account status: {response.text}",
request=response.request,
response=response,
)
async def ensure_billing_account_v2(
self,
organization_id: int,
created_by: Optional[str] = None,
) -> dict:
"""Create or return the MPS v2 billing account for an organization."""
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.get(
f"{self.base_url}/api/v1/billing/accounts/{organization_id}/balance",
headers=self._get_headers(
organization_id=organization_id,
created_by=created_by,
),
)
if response.status_code == 200:
return response.json()
logger.error(
"Failed to ensure MPS billing account v2: "
f"{response.status_code} - {response.text}"
)
raise httpx.HTTPStatusError(
f"Failed to ensure MPS billing account v2: {response.text}",
request=response.request,
response=response,
)
async def create_correlation_id(
self,
*,
service_key: str,
workflow_run_id: int | None = None,
) -> dict:
"""Mint a server-generated correlation ID for managed model services."""
payload: dict[str, int] = {}
if workflow_run_id is not None:
payload["workflow_run_id"] = workflow_run_id
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.post(
f"{self.base_url}/api/v1/service-keys/correlation-id/self",
json=payload,
headers={
"Authorization": f"Bearer {service_key}",
"Content-Type": "application/json",
},
)
if response.status_code == 200:
return response.json()
logger.error(
"Failed to create correlation ID: "
f"{response.status_code} - {response.text}"
)
raise httpx.HTTPStatusError(
f"Failed to create correlation ID: {response.text}",
request=response.request,
response=response,
)
async def report_platform_usage(
self,
*,
organization_id: int,
correlation_id: Optional[str] = None,
duration_seconds: Optional[float] = None,
workflow_run_id: int | None = None,
metadata: Optional[dict] = None,
max_attempts: int = 3,
) -> dict:
"""Report hosted Dograh platform usage for a completed workflow run."""
if DEPLOYMENT_MODE == "oss":
raise ValueError("OSS deployments must not report platform usage to MPS")
if not correlation_id and duration_seconds is None:
raise ValueError(
"Platform usage reports require correlation_id or duration_seconds"
)
payload: dict = {
"metadata": metadata or {},
}
if correlation_id:
payload["correlation_id"] = correlation_id
if duration_seconds is not None:
payload["duration_seconds"] = duration_seconds
if workflow_run_id is not None:
payload["workflow_run_id"] = workflow_run_id
max_attempts = max(1, max_attempts)
last_response: httpx.Response | None = None
async with httpx.AsyncClient(timeout=self.timeout) as client:
for attempt in range(1, max_attempts + 1):
response = await client.post(
(
f"{self.base_url}/api/v1/billing/accounts/"
f"{organization_id}/platform-usage"
),
json=payload,
headers=self._get_headers(organization_id=organization_id),
)
last_response = response
if response.status_code == 200:
return response.json()
should_retry = (
response.status_code == 409
and "usage_not_ready" in response.text
and attempt < max_attempts
)
if should_retry:
await asyncio.sleep(attempt)
continue
logger.error(
"Failed to report platform usage: "
f"{response.status_code} - {response.text}"
)
raise httpx.HTTPStatusError(
f"Failed to report platform usage: {response.text}",
request=response.request,
response=response,
)
raise httpx.HTTPStatusError(
"Failed to report platform usage",
request=last_response.request,
response=last_response,
)
async def transcribe_audio(
self,
audio_data: bytes,

View file

@ -0,0 +1,50 @@
from typing import Literal, Optional
from pydantic import BaseModel
from api.db import db_client
from api.db.models import UserModel
from api.services.configuration.ai_model_configuration import (
get_resolved_ai_model_configuration,
)
class OrganizationModelServicesContext(BaseModel):
config_source: Literal["organization_v2", "legacy_user_v1", "empty"]
has_model_configuration_v2: bool
managed_service_version: Optional[int] = None
uses_managed_service_v2: bool
class OrganizationContextResponse(BaseModel):
organization_id: Optional[int] = None
organization_provider_id: Optional[str] = None
model_services: OrganizationModelServicesContext
async def get_organization_context(user: UserModel) -> OrganizationContextResponse:
organization_id = user.selected_organization_id
organization = (
await db_client.get_organization_by_id(organization_id)
if organization_id
else None
)
resolved = await get_resolved_ai_model_configuration(
user_id=user.id,
organization_id=organization_id,
)
managed_service_version = resolved.effective.managed_service_version
return OrganizationContextResponse(
organization_id=organization_id,
organization_provider_id=organization.provider_id if organization else None,
model_services=OrganizationModelServicesContext(
config_source=resolved.source,
has_model_configuration_v2=resolved.source == "organization_v2",
managed_service_version=managed_service_version,
uses_managed_service_v2=(
resolved.source == "organization_v2" and managed_service_version == 2
),
),
)

View file

@ -162,15 +162,13 @@ async def run_pipeline_telephony(
workflow_id: Workflow being executed.
workflow_run_id: Workflow run row.
user_id: Owner of the workflow.
call_id: Provider call identifier (stored in cost_info for billing).
call_id: Provider call identifier.
transport_kwargs: Provider-specific kwargs forwarded to the transport
factory (e.g. stream_sid + call_sid for Twilio).
"""
logger.debug(f"Running {provider_name} pipeline for workflow_run {workflow_run_id}")
set_current_run_id(workflow_run_id)
await db_client.update_workflow_run(workflow_run_id, cost_info={"call_id": call_id})
workflow = await db_client.get_workflow(workflow_id, user_id)
if workflow:
set_current_org_id(workflow.organization_id)
@ -340,7 +338,7 @@ async def _run_pipeline(
if workflow_run.is_completed:
raise HTTPException(status_code=400, detail="Workflow run already completed")
merged_call_context_vars = workflow_run.initial_context
merged_call_context_vars = dict(workflow_run.initial_context or {})
# If there is some extra call_context_vars, fold them in. Persistence
# happens once below, after runtime_configuration is also resolved.
if call_context_vars:
@ -398,6 +396,19 @@ async def _run_pipeline(
else:
user_config = resolved_user_config
from api.services.managed_model_services import (
MPS_CORRELATION_ID_CONTEXT_KEY,
ensure_mps_correlation_id,
)
mps_correlation_id = await ensure_mps_correlation_id(
ai_model_config=user_config,
workflow_run_id=workflow_run_id,
initial_context=merged_call_context_vars,
)
if mps_correlation_id:
merged_call_context_vars[MPS_CORRELATION_ID_CONTEXT_KEY] = mps_correlation_id
# Detect realtime mode (speech-to-speech services like OpenAI Realtime, Gemini Live)
is_realtime = user_config.is_realtime and user_config.realtime is not None
@ -409,11 +420,23 @@ async def _run_pipeline(
# Realtime services don't implement run_inference, so create a
# separate text LLM for variable extraction and other out-of-band
# inference calls.
inference_llm = create_llm_service(user_config)
inference_llm = create_llm_service(
user_config,
correlation_id=mps_correlation_id,
)
else:
stt = create_stt_service(user_config, audio_config, keyterms=keyterms)
tts = create_tts_service(user_config, audio_config)
llm = create_llm_service(user_config)
stt = create_stt_service(
user_config,
audio_config,
keyterms=keyterms,
correlation_id=mps_correlation_id,
)
tts = create_tts_service(
user_config,
audio_config,
correlation_id=mps_correlation_id,
)
llm = create_llm_service(user_config, correlation_id=mps_correlation_id)
inference_llm = None
# Stamp the providers/models actually resolved for this run onto
@ -695,7 +718,10 @@ async def _run_pipeline(
# Create a separate LLM instance for the voicemail sub-pipeline
# (can't share with main pipeline as it would mess up frame linking)
if voicemail_config.get("use_workflow_llm", True):
voicemail_llm = create_llm_service(user_config)
voicemail_llm = create_llm_service(
user_config,
correlation_id=mps_correlation_id,
)
else:
voicemail_llm = create_llm_service_from_provider(
provider=voicemail_config.get("provider", "openai"),

View file

@ -78,7 +78,10 @@ def _validate_runtime_service_url(url: str, field_name: str) -> None:
def create_stt_service(
user_config, audio_config: "AudioConfig", keyterms: list[str] | None = None
user_config,
audio_config: "AudioConfig",
keyterms: list[str] | None = None,
correlation_id: str | None = None,
):
"""Create and return appropriate STT service based on user configuration
@ -160,6 +163,7 @@ def create_stt_service(
return DograhSTTService(
base_url=base_url,
api_key=user_config.stt.api_key,
correlation_id=correlation_id,
settings=DograhSTTSettings(
model=user_config.stt.model,
language=language,
@ -286,7 +290,9 @@ def create_stt_service(
)
def create_tts_service(user_config, audio_config: "AudioConfig"):
def create_tts_service(
user_config, audio_config: "AudioConfig", correlation_id: str | None = None
):
"""Create and return appropriate TTS service based on user configuration
Args:
@ -404,6 +410,7 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
return DograhTTSService(
base_url=base_url,
api_key=user_config.tts.api_key,
correlation_id=correlation_id,
settings=DograhTTSSettings(
model=user_config.tts.model,
voice=user_config.tts.voice,
@ -564,6 +571,7 @@ def create_llm_service_from_provider(
model: str,
api_key: str | None,
*,
correlation_id: str | None = None,
base_url: str | None = None,
endpoint: str | None = None,
aws_access_key: str | None = None,
@ -637,6 +645,7 @@ def create_llm_service_from_provider(
return DograhLLMService(
base_url=f"{MPS_API_URL}/api/v1/llm",
api_key=api_key,
correlation_id=correlation_id,
settings=OpenAILLMSettings(model=model),
)
elif provider == ServiceProviders.AWS_BEDROCK.value:
@ -851,7 +860,7 @@ def create_realtime_llm_service(user_config, audio_config: "AudioConfig"):
)
def create_llm_service(user_config):
def create_llm_service(user_config, correlation_id: str | None = None):
"""Create and return appropriate LLM service based on user configuration."""
provider = user_config.llm.provider
model = user_config.llm.model
@ -880,4 +889,10 @@ def create_llm_service(user_config):
elif provider == ServiceProviders.SARVAM.value:
kwargs["temperature"] = user_config.llm.temperature
return create_llm_service_from_provider(provider, model, api_key, **kwargs)
return create_llm_service_from_provider(
provider,
model,
api_key,
correlation_id=correlation_id,
**kwargs,
)

View file

@ -1,76 +0,0 @@
# Pricing Module
This module contains pricing models and registries for different AI services used in workflow cost calculations.
## Structure
```
pricing/
├── __init__.py # Main module exports
├── models.py # Base pricing model classes
├── llm.py # LLM pricing configurations
├── tts.py # TTS pricing configurations
├── stt.py # STT pricing configurations
├── registry.py # Combined pricing registry
└── README.md # This file
```
## Pricing Models
### TokenPricingModel
Used for LLM services that charge based on tokens:
- `prompt_token_price`: Cost per prompt token
- `completion_token_price`: Cost per completion token
- `cache_read_discount`: Discount for cache read tokens (default 50%)
- `cache_creation_multiplier`: Premium for cache creation tokens (default 25%)
### CharacterPricingModel
Used for TTS services that charge based on character count:
- `character_price`: Cost per character
### TimePricingModel
Used for STT services that charge based on time:
- `second_price`: Cost per second
## Adding New Pricing
### Adding a New LLM Model
Edit `llm.py` and add the model to the appropriate provider:
```python
ServiceProviders.OPENAI: {
"new-model": TokenPricingModel(
prompt_token_price=Decimal("2.00") / 1000000,
completion_token_price=Decimal("8.00") / 1000000,
),
# ... existing models
}
```
### Adding a New Provider
1. Add pricing configurations to the appropriate service file (llm.py, tts.py, stt.py)
2. The registry will automatically include them
### Adding a New Service Type
1. Create a new pricing file (e.g., `image.py`)
2. Define the pricing models
3. Import and add to `registry.py`
## Usage
The pricing registry is automatically imported and used by the cost calculator:
```python
from api.services.pricing import PRICING_REGISTRY
from api.services.workflow.cost_calculator import cost_calculator
# The cost calculator uses the pricing registry automatically
result = cost_calculator.calculate_total_cost(usage_info)
```
## Maintenance
- Update pricing when providers change their rates
- All prices should use `Decimal` for precision
- Include comments with current pricing from provider documentation
- Test changes with existing test suite

View file

@ -1,9 +0,0 @@
"""
Pricing module for workflow cost calculation.
This module contains pricing models and registries for different AI services.
"""
from .registry import PRICING_REGISTRY
__all__ = ["PRICING_REGISTRY"]

View file

@ -1,228 +0,0 @@
"""
Cost Calculator for Workflow Runs
This module provides a comprehensive cost calculation system for workflow runs based on usage metrics
from different AI service providers (OpenAI, Groq, Deepgram, etc.).
Features:
- Token-based pricing for LLM services with cache optimization support
- Character-based pricing for TTS services
- Time-based pricing for STT services
- Configurable pricing models that can be updated
- Support for multiple providers and models
- Automatic provider inference from model names
- JSON serialization support for database storage
Usage:
from api.tasks.cost_calculator import cost_calculator
usage_info = {
"llm": {
"processor_name|||gpt-4o": {
"prompt_tokens": 1000,
"completion_tokens": 500,
"total_tokens": 1500,
"cache_read_input_tokens": 0,
"cache_creation_input_tokens": 0
}
},
"tts": {
"processor_name|||aura-2-helena-en": 2000 # character count
}
}
cost_breakdown = cost_calculator.calculate_total_cost(usage_info)
print(f"Total cost: ${cost_breakdown['total']:.6f}")
"""
from decimal import Decimal
from typing import Any, Dict, Optional, Tuple
from api.services.configuration.registry import ServiceProviders
from api.services.pricing import PRICING_REGISTRY
from api.services.pricing.models import (
PricingModel,
)
class CostCalculator:
"""Main cost calculator class"""
def __init__(self, pricing_registry: Dict = None):
self.pricing_registry = pricing_registry or PRICING_REGISTRY
def get_pricing_model(
self, service_type: str, provider: str, model: str
) -> Optional[PricingModel]:
"""Get pricing model for a specific service, provider, and model"""
try:
service_pricing = self.pricing_registry.get(service_type, {})
# Try to get pricing for the specific provider
provider_pricing = service_pricing.get(provider, {})
pricing_model = provider_pricing.get(model) or provider_pricing.get(
"default"
)
if pricing_model:
return pricing_model
# If not found, try the "default" provider for this service type
default_provider_pricing = service_pricing.get("default", {})
return default_provider_pricing.get(model) or default_provider_pricing.get(
"default"
)
except (KeyError, AttributeError):
return None
def calculate_llm_cost(
self, provider: str, model: str, usage: Dict[str, int]
) -> Decimal:
"""Calculate cost for LLM usage"""
pricing_model = self.get_pricing_model("llm", provider, model)
if not pricing_model:
return Decimal("0")
return pricing_model.calculate_cost(usage)
def calculate_tts_cost(
self, provider: str, model: str, character_count: int
) -> Decimal:
"""Calculate cost for TTS usage"""
pricing_model = self.get_pricing_model("tts", provider, model)
if not pricing_model:
return Decimal("0")
return pricing_model.calculate_cost(character_count)
def calculate_stt_cost(self, provider: str, model: str, seconds: float) -> Decimal:
"""Calculate cost for STT usage"""
pricing_model = self.get_pricing_model("stt", provider, model)
if not pricing_model:
return Decimal("0")
return pricing_model.calculate_cost(seconds)
def calculate_total_cost(self, usage_info: Dict) -> Dict[str, Any]:
llm_cost_total = Decimal("0")
tts_cost_total = Decimal("0")
stt_cost_total = Decimal("0")
# Calculate LLM costs
llm_usage = usage_info.get("llm", {})
for key, usage in llm_usage.items():
processor, model = self._parse_key(key)
# Try to determine provider from processor name or model
provider = self._infer_provider_from_model(model, "llm")
cost = self.calculate_llm_cost(provider, model, usage)
llm_cost_total += cost
# Calculate TTS costs
tts_usage = usage_info.get("tts", {})
for key, character_count in tts_usage.items():
processor, model = self._parse_key(key)
# Handle the case where model is "None" - infer from processor
if model.lower() in ["none", "null", ""]:
provider = self._infer_provider_from_processor(processor, "tts")
model = "default" # Use default model for the provider
else:
provider = self._infer_provider_from_model(model, "tts")
cost = self.calculate_tts_cost(provider, model, character_count)
tts_cost_total += cost
# Calculate STT costs from explicit stt usage
stt_usage = usage_info.get("stt", {})
for key, seconds in stt_usage.items():
processor, model = self._parse_key(key)
provider = self._infer_provider_from_model(model, "stt")
cost = self.calculate_stt_cost(provider, model, seconds)
stt_cost_total += cost
total_cost = llm_cost_total + tts_cost_total + stt_cost_total
return {
"llm_cost": float(llm_cost_total),
"tts_cost": float(tts_cost_total),
"stt_cost": float(stt_cost_total),
"total": float(total_cost),
}
def _parse_key(self, key) -> Tuple[str, str]:
"""Parse key which is in format 'processor|||model'"""
if isinstance(key, str) and "|||" in key:
parts = key.split("|||", 1)
return parts[0], parts[1]
else:
# Fallback for backwards compatibility or malformed keys
return str(key), "unknown"
def _infer_provider_from_model(self, model: str, service_type: str) -> str:
"""Infer provider from model name"""
if not model:
return "unknown"
model_lower = model.lower()
# OpenAI models
if any(keyword in model_lower for keyword in ["gpt", "whisper", "openai"]):
return ServiceProviders.OPENAI
# Groq models
if any(keyword in model_lower for keyword in ["groq"]):
return ServiceProviders.GROQ
# Elevenlabs models
if any(keyword in model_lower for keyword in ["eleven"]):
return ServiceProviders.ELEVENLABS
# Deepgram models
if any(
keyword in model_lower
for keyword in ["deepgram", "nova", "phonecall", "general"]
):
return ServiceProviders.DEEPGRAM
# Default to first available provider for the service type
service_providers = self.pricing_registry.get(service_type, {})
if service_providers:
return list(service_providers.keys())[0]
return "unknown"
def _infer_provider_from_processor(self, processor: str, service_type: str) -> str:
"""Infer provider from processor name"""
if not processor:
return "unknown"
processor_lower = processor.lower()
# OpenAI processors
if any(keyword in processor_lower for keyword in ["openai", "gpt"]):
return ServiceProviders.OPENAI
# Groq processors
if any(keyword in processor_lower for keyword in ["groq"]):
return ServiceProviders.GROQ
# Deepgram processors
if any(keyword in processor_lower for keyword in ["deepgram"]):
return ServiceProviders.DEEPGRAM
# Default to first available provider for the service type
service_providers = self.pricing_registry.get(service_type, {})
if service_providers:
return list(service_providers.keys())[0]
return "unknown"
def update_pricing(
self, service_type: str, provider: str, model: str, pricing_model: PricingModel
):
"""Update pricing for a specific service/provider/model combination"""
if service_type not in self.pricing_registry:
self.pricing_registry[service_type] = {}
if provider not in self.pricing_registry[service_type]:
self.pricing_registry[service_type][provider] = {}
self.pricing_registry[service_type][provider][model] = pricing_model
# Global cost calculator instance
cost_calculator = CostCalculator()

View file

@ -1,44 +0,0 @@
"""
Embeddings pricing models for different providers.
Prices are per token for embedding models.
"""
from decimal import Decimal
from typing import Dict
from api.services.configuration.registry import ServiceProviders
from .models import PricingModel
class EmbeddingPricingModel(PricingModel):
"""Pricing model for token-based embedding services."""
def __init__(self, token_price: Decimal):
"""Initialize with price per token.
Args:
token_price: Cost per token for embedding
"""
self.token_price = token_price
def calculate_cost(self, token_count: int) -> Decimal:
"""Calculate cost for embedding token usage."""
return Decimal(token_count) * self.token_price
# Embeddings pricing registry
EMBEDDINGS_PRICING: Dict[str, Dict[str, EmbeddingPricingModel]] = {
ServiceProviders.OPENAI: {
"text-embedding-3-small": EmbeddingPricingModel(
token_price=Decimal("0.02") / 1_000_000, # $0.02 per 1M tokens
),
"text-embedding-3-large": EmbeddingPricingModel(
token_price=Decimal("0.13") / 1_000_000, # $0.13 per 1M tokens
),
"text-embedding-ada-002": EmbeddingPricingModel(
token_price=Decimal("0.10") / 1_000_000, # $0.10 per 1M tokens (legacy)
),
},
}

View file

@ -1,143 +0,0 @@
"""
LLM pricing models for different providers.
Prices are per 1000 tokens for most models, with some newer models priced per million tokens.
"""
from decimal import Decimal
from typing import Dict
from api.services.configuration.registry import ServiceProviders
from .models import TokenPricingModel
# LLM pricing registry
LLM_PRICING: Dict[str, Dict[str, TokenPricingModel]] = {
ServiceProviders.OPENAI: {
"gpt-3.5-turbo": TokenPricingModel(
prompt_token_price=Decimal("0.0015") / 1000, # $0.0015 per 1K tokens
completion_token_price=Decimal("0.002") / 1000, # $0.002 per 1K tokens
),
"gpt-4": TokenPricingModel(
prompt_token_price=Decimal("0.03") / 1000, # $0.03 per 1K tokens
completion_token_price=Decimal("0.06") / 1000, # $0.06 per 1K tokens
),
"gpt-4.1": TokenPricingModel(
prompt_token_price=Decimal("2.00") / 1000000, # $2.00 per 1M tokens
completion_token_price=Decimal("8.00") / 1000000, # $8.00 per 1M tokens
),
"gpt-4.1-mini": TokenPricingModel(
prompt_token_price=Decimal("0.40") / 1000000, # $0.40 per 1M tokens
completion_token_price=Decimal("1.60") / 1000000, # $1.60 per 1M tokens
),
"gpt-4.1-nano": TokenPricingModel(
prompt_token_price=Decimal("0.10") / 1000000, # $0.10 per 1M tokens
completion_token_price=Decimal("0.40") / 1000000, # $0.40 per 1M tokens
),
"gpt-4.5-preview": TokenPricingModel(
prompt_token_price=Decimal("75.00") / 1000000, # $75.00 per 1M tokens
completion_token_price=Decimal("150.00") / 1000000, # $150.00 per 1M tokens
),
"gpt-4o": TokenPricingModel(
prompt_token_price=Decimal("2.50") / 1000000, # $2.50 per 1M tokens - FIXED
completion_token_price=Decimal("10.00")
/ 1000000, # $10.00 per 1M tokens - FIXED
),
"gpt-4o-audio-preview": TokenPricingModel(
prompt_token_price=Decimal("2.50") / 1000000, # $2.50 per 1M tokens
completion_token_price=Decimal("10.00") / 1000000, # $10.00 per 1M tokens
),
"gpt-4o-realtime-preview": TokenPricingModel(
prompt_token_price=Decimal("5.00") / 1000000, # $5.00 per 1M tokens
completion_token_price=Decimal("20.00") / 1000000, # $20.00 per 1M tokens
),
"gpt-4o-mini": TokenPricingModel(
prompt_token_price=Decimal("0.15") / 1000000, # $0.15 per 1M tokens
completion_token_price=Decimal("0.60") / 1000000, # $0.60 per 1M tokens
),
"gpt-4o-mini-audio-preview": TokenPricingModel(
prompt_token_price=Decimal("0.15") / 1000000, # $0.15 per 1M tokens
completion_token_price=Decimal("0.60") / 1000000, # $0.60 per 1M tokens
),
"gpt-4o-mini-realtime-preview": TokenPricingModel(
prompt_token_price=Decimal("0.60") / 1000000, # $0.60 per 1M tokens
completion_token_price=Decimal("2.40") / 1000000, # $2.40 per 1M tokens
),
"gpt-4o-search-preview": TokenPricingModel(
prompt_token_price=Decimal("2.50") / 1000000, # $2.50 per 1M tokens
completion_token_price=Decimal("10.00") / 1000000, # $10.00 per 1M tokens
),
"gpt-4o-mini-search-preview": TokenPricingModel(
prompt_token_price=Decimal("0.15") / 1000000, # $0.15 per 1M tokens
completion_token_price=Decimal("0.60") / 1000000, # $0.60 per 1M tokens
),
"o1": TokenPricingModel(
prompt_token_price=Decimal("15.00") / 1000000, # $15.00 per 1M tokens
completion_token_price=Decimal("60.00") / 1000000, # $60.00 per 1M tokens
),
"o1-pro": TokenPricingModel(
prompt_token_price=Decimal("150.00") / 1000000, # $150.00 per 1M tokens
completion_token_price=Decimal("600.00") / 1000000, # $600.00 per 1M tokens
),
"o1-mini": TokenPricingModel(
prompt_token_price=Decimal("1.10") / 1000000, # $1.10 per 1M tokens
completion_token_price=Decimal("4.40") / 1000000, # $4.40 per 1M tokens
),
"o3": TokenPricingModel(
prompt_token_price=Decimal("10.00") / 1000000, # $10.00 per 1M tokens
completion_token_price=Decimal("40.00") / 1000000, # $40.00 per 1M tokens
),
"o3-mini": TokenPricingModel(
prompt_token_price=Decimal("1.10") / 1000000, # $1.10 per 1M tokens
completion_token_price=Decimal("4.40") / 1000000, # $4.40 per 1M tokens
),
"o4-mini": TokenPricingModel(
prompt_token_price=Decimal("1.10") / 1000000, # $1.10 per 1M tokens
completion_token_price=Decimal("4.40") / 1000000, # $4.40 per 1M tokens
),
"computer-use-preview": TokenPricingModel(
prompt_token_price=Decimal("3.00") / 1000000, # $3.00 per 1M tokens
completion_token_price=Decimal("12.00") / 1000000, # $12.00 per 1M tokens
),
"gpt-image-1": TokenPricingModel(
prompt_token_price=Decimal("5.00") / 1000000, # $5.00 per 1M tokens
completion_token_price=Decimal("0") / 1000000, # No output pricing shown
),
"codex-mini-latest": TokenPricingModel(
prompt_token_price=Decimal("1.50") / 1000000, # $1.50 per 1M tokens
completion_token_price=Decimal("6.00") / 1000000, # $6.00 per 1M tokens
),
# Transcription models
"gpt-4o-transcribe": TokenPricingModel(
prompt_token_price=Decimal("2.50") / 1000000, # $2.50 per 1M tokens
completion_token_price=Decimal("10.00") / 1000000, # $10.00 per 1M tokens
),
"gpt-4o-mini-transcribe": TokenPricingModel(
prompt_token_price=Decimal("1.25") / 1000000, # $1.25 per 1M tokens
completion_token_price=Decimal("5.00") / 1000000, # $5.00 per 1M tokens
),
# TTS models with token-based pricing
"gpt-4o-mini-tts": TokenPricingModel(
prompt_token_price=Decimal("0.60") / 1000000, # $0.60 per 1M tokens
completion_token_price=Decimal("0")
/ 1000000, # No completion tokens for TTS
),
},
ServiceProviders.GROQ: {
"llama-3.3-70b-versatile": TokenPricingModel(
prompt_token_price=Decimal("0.00059") / 1000, # $0.00059 per 1K tokens
completion_token_price=Decimal("0.00079") / 1000, # $0.00079 per 1K tokens
),
"deepseek-r1-distill-llama-70b": TokenPricingModel(
prompt_token_price=Decimal("0.00059") / 1000, # Assuming similar pricing
completion_token_price=Decimal("0.00079") / 1000,
),
},
ServiceProviders.AZURE: {
"gpt-4.1-mini": TokenPricingModel(
prompt_token_price=Decimal("0.44") / 1000000, # $0.40 per 1M tokens
completion_token_price=Decimal("8.80")
/ 1000000, # $1.60 per 1M tokens if using data zone
)
},
}

View file

@ -1,89 +0,0 @@
"""
Base pricing models for different service types.
"""
from decimal import Decimal
from enum import Enum
from typing import Any, Dict
class CostType(Enum):
LLM_TOKENS = "llm_tokens"
TTS_CHARACTERS = "tts_characters"
STT_SECONDS = "stt_seconds"
class PricingModel:
"""Base class for pricing models"""
def calculate_cost(self, usage: Any) -> Decimal:
"""Calculate cost based on usage"""
raise NotImplementedError
class TokenPricingModel(PricingModel):
"""Pricing model for token-based services (LLM)"""
def __init__(
self,
prompt_token_price: Decimal,
completion_token_price: Decimal,
cache_read_discount: Decimal = Decimal("0.5"), # 50% discount for cache reads
cache_creation_multiplier: Decimal = Decimal(
"1.25"
), # 25% premium for cache creation
):
self.prompt_token_price = prompt_token_price
self.completion_token_price = completion_token_price
self.cache_read_discount = cache_read_discount
self.cache_creation_multiplier = cache_creation_multiplier
def calculate_cost(self, usage: Dict[str, int]) -> Decimal:
"""Calculate cost for LLM token usage"""
prompt_tokens = usage.get("prompt_tokens", 0)
completion_tokens = usage.get("completion_tokens", 0)
cache_read_tokens = usage.get("cache_read_input_tokens") or 0
cache_creation_tokens = usage.get("cache_creation_input_tokens") or 0
# Base cost
prompt_cost = Decimal(prompt_tokens) * self.prompt_token_price
completion_cost = Decimal(completion_tokens) * self.completion_token_price
# Cache adjustments
cache_read_savings = (
Decimal(cache_read_tokens)
* self.prompt_token_price
* self.cache_read_discount
)
cache_creation_premium = (
Decimal(cache_creation_tokens)
* self.prompt_token_price
* (self.cache_creation_multiplier - 1)
)
total_cost = (
prompt_cost + completion_cost - cache_read_savings + cache_creation_premium
)
return max(total_cost, Decimal("0")) # Ensure non-negative
class CharacterPricingModel(PricingModel):
"""Pricing model for character-based services (TTS)"""
def __init__(self, character_price: Decimal):
self.character_price = character_price
def calculate_cost(self, character_count: int) -> Decimal:
"""Calculate cost for TTS character usage"""
return Decimal(character_count) * self.character_price
class TimePricingModel(PricingModel):
"""Pricing model for time-based services (STT)"""
def __init__(self, second_price: Decimal):
self.second_price = second_price
def calculate_cost(self, seconds: float) -> Decimal:
"""Calculate cost for STT time usage"""
return Decimal(str(seconds)) * self.second_price

View file

@ -1,18 +0,0 @@
"""
Main pricing registry that combines all service type pricing models.
"""
from typing import Dict
from .embeddings import EMBEDDINGS_PRICING
from .llm import LLM_PRICING
from .stt import STT_PRICING
from .tts import TTS_PRICING
# Combined pricing registry for all service types
PRICING_REGISTRY: Dict = {
"llm": LLM_PRICING,
"tts": TTS_PRICING,
"stt": STT_PRICING,
"embeddings": EMBEDDINGS_PRICING,
}

View file

@ -1,13 +0,0 @@
"""Format workflow run usage for public API responses."""
def format_public_usage_info(usage_info: dict | None) -> dict | None:
if not usage_info:
return None
return {
"llm": usage_info.get("llm") or {},
"tts": usage_info.get("tts") or {},
"stt": usage_info.get("stt") or {},
"call_duration_seconds": usage_info.get("call_duration_seconds"),
}

View file

@ -1,26 +0,0 @@
"""
STT (Speech-to-Text) pricing models for different providers.
Prices are per second for STT services.
"""
from decimal import Decimal
from typing import Dict
from api.services.configuration.registry import ServiceProviders
from .models import TimePricingModel
# STT pricing registry
STT_PRICING: Dict[str, Dict[str, TimePricingModel]] = {
ServiceProviders.DEEPGRAM: {
"nova-3-general": TimePricingModel(Decimal("0.0077") / 60),
"nova-2": TimePricingModel(Decimal("0.0058") / 60),
"default": TimePricingModel(Decimal("0.0077") / 60),
},
ServiceProviders.OPENAI: {
"gpt-4o-transcribe": TimePricingModel(Decimal("0.015") / 60),
"default": TimePricingModel(Decimal("0.015") / 60),
},
"default": {"default": TimePricingModel(Decimal("0.0077") / 60)},
}

View file

@ -1,30 +0,0 @@
"""
TTS (Text-to-Speech) pricing models for different providers.
Prices are per character for TTS services.
"""
from decimal import Decimal
from typing import Dict
from api.services.configuration.registry import ServiceProviders
from .models import CharacterPricingModel
# TTS pricing registry
TTS_PRICING: Dict[str, Dict[str, CharacterPricingModel]] = {
ServiceProviders.OPENAI: {
"gpt-4o-mini-tts": CharacterPricingModel(Decimal("0.6") / 1_00_00_000),
"default": CharacterPricingModel(Decimal("0.6") / 1_00_00_000),
},
ServiceProviders.DEEPGRAM: {
"aura-2": CharacterPricingModel(Decimal("0.030") / 1_000),
"aura-1": CharacterPricingModel(Decimal("0.015") / 1_000),
"default": CharacterPricingModel(Decimal("0.030") / 1_000),
},
ServiceProviders.ELEVENLABS: {
# 6400 usd per 250*1e6 characters
"default": CharacterPricingModel(Decimal("0.0256") / 1_000)
},
"default": {"default": CharacterPricingModel(Decimal("0.030") / 1_000)},
}

View file

@ -1,230 +0,0 @@
from decimal import Decimal
from loguru import logger
from api.db import db_client
from api.enums import WorkflowRunMode
from api.services.pricing.cost_calculator import cost_calculator
from api.services.telephony.factory import get_telephony_provider_for_run
async def _fetch_telephony_cost(workflow_run) -> dict | None:
"""Fetch telephony call cost. Returns a dict with cost_usd and provider_name, or None."""
if (
workflow_run.mode
not in [WorkflowRunMode.TWILIO.value, WorkflowRunMode.VONAGE.value]
or not workflow_run.cost_info
):
return None
call_id = workflow_run.cost_info.get("call_id")
if not call_id:
logger.warning(f"call_id not found in cost_info")
return None
provider_name = workflow_run.mode.lower() if workflow_run.mode else ""
workflow = await db_client.get_workflow_by_id(workflow_run.workflow_id)
if not workflow:
logger.warning("Workflow not found for workflow run")
raise Exception("Workflow not found")
provider = await get_telephony_provider_for_run(
workflow_run, workflow.organization_id
)
call_cost_info = await provider.get_call_cost(call_id)
if call_cost_info.get("status") == "error":
logger.error(
f"Failed to fetch {provider_name} call cost: {call_cost_info.get('error')}"
)
return None
cost_usd = call_cost_info.get("cost_usd", 0.0)
logger.info(
f"{provider_name.title()} call cost: ${cost_usd:.6f} USD for call {call_id}"
)
return {"cost_usd": cost_usd, "provider_name": provider_name}
async def _update_organization_usage(
org, dograh_tokens: float, duration_seconds: float, charge_usd: float | None
) -> None:
"""Update organization usage after a workflow run."""
org_id = org.id
await db_client.update_usage_after_run(
org_id, dograh_tokens, duration_seconds, charge_usd
)
if charge_usd is not None:
logger.info(
f"Updated organization usage with ${charge_usd:.2f} USD ({dograh_tokens} Dograh Tokens) and {duration_seconds}s duration for org {org_id}"
)
else:
logger.info(
f"Updated organization usage with {dograh_tokens} Dograh Tokens and {duration_seconds}s duration for org {org_id}"
)
async def _get_pricing_organization(workflow_run):
workflow = getattr(workflow_run, "workflow", None)
organization_id = getattr(workflow, "organization_id", None)
if organization_id is None and workflow and workflow.user:
organization_id = workflow.user.selected_organization_id
if organization_id is None:
return None
return await db_client.get_organization_by_id(organization_id)
async def _build_usage_cost_snapshot(
usage_info: dict | None,
*,
workflow_run=None,
include_telephony_cost: bool = False,
organization=None,
calculated_at: str | None = None,
) -> dict | None:
if not usage_info:
logger.warning("No usage info available for workflow run")
return None
cost_breakdown = cost_calculator.calculate_total_cost(usage_info)
if include_telephony_cost and workflow_run is not None:
try:
telephony_cost = await _fetch_telephony_cost(workflow_run)
if telephony_cost:
telephony_cost_usd = telephony_cost["cost_usd"]
provider_name = telephony_cost["provider_name"]
cost_breakdown["telephony_call"] = telephony_cost_usd
cost_breakdown[f"{provider_name}_call"] = telephony_cost_usd
cost_breakdown["total"] = (
float(cost_breakdown["total"]) + telephony_cost_usd
)
except Exception as e:
logger.error(f"Failed to fetch telephony call cost: {e}")
# Don't fail the whole cost calculation if telephony API fails
total_cost_usd = Decimal(str(cost_breakdown["total"]))
dograh_tokens = float(total_cost_usd * Decimal("100"))
if organization is None and workflow_run is not None:
organization = await _get_pricing_organization(workflow_run)
charge_usd = None
if organization and organization.price_per_second_usd:
duration_seconds = usage_info.get("call_duration_seconds", 0)
charge_usd = float(
Decimal(str(duration_seconds))
* Decimal(str(organization.price_per_second_usd))
)
cost_info = {
"cost_breakdown": cost_breakdown,
"total_cost_usd": float(total_cost_usd),
"dograh_token_usage": dograh_tokens,
"calculated_at": calculated_at
or (workflow_run.created_at.isoformat() if workflow_run is not None else None),
"call_duration_seconds": usage_info.get("call_duration_seconds", 0),
}
if charge_usd is not None:
cost_info["charge_usd"] = charge_usd
cost_info["price_per_second_usd"] = organization.price_per_second_usd
return cost_info
async def build_workflow_run_cost_info(workflow_run) -> dict | None:
cost_info = await _build_usage_cost_snapshot(
workflow_run.usage_info,
workflow_run=workflow_run,
include_telephony_cost=True,
calculated_at=workflow_run.created_at.isoformat(),
)
if cost_info is None:
return None
return {
**(workflow_run.cost_info or {}),
**cost_info,
}
async def save_workflow_run_cost_info(
workflow_run_id: int, cost_info: dict | None
) -> None:
if cost_info is None:
return
await db_client.update_workflow_run(run_id=workflow_run_id, cost_info=cost_info)
async def apply_workflow_run_usage_to_organization(
workflow_run, cost_info: dict | None
) -> None:
if cost_info is None:
return
org = await _get_pricing_organization(workflow_run)
if not org:
return
await _update_organization_usage(
org,
float(cost_info.get("dograh_token_usage") or 0),
float(cost_info.get("call_duration_seconds") or 0),
cost_info.get("charge_usd"),
)
async def apply_usage_delta_to_organization(
workflow_run, usage_info: dict | None
) -> dict | None:
org = await _get_pricing_organization(workflow_run)
if not org:
return None
cost_info = await _build_usage_cost_snapshot(usage_info, organization=org)
if cost_info is None:
return None
await _update_organization_usage(
org,
float(cost_info.get("dograh_token_usage") or 0),
float(cost_info.get("call_duration_seconds") or 0),
cost_info.get("charge_usd"),
)
return cost_info
async def calculate_workflow_run_cost(workflow_run_id: int):
logger.debug("Calculating cost for workflow run")
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
if not workflow_run:
logger.warning("Workflow run not found")
return
try:
cost_info = await build_workflow_run_cost_info(workflow_run)
if cost_info is None:
return
await save_workflow_run_cost_info(workflow_run_id, cost_info)
try:
await apply_workflow_run_usage_to_organization(workflow_run, cost_info)
except Exception as e:
org = await _get_pricing_organization(workflow_run)
if org:
logger.error(
f"Failed to update organization usage for org {org.id}: {e}"
)
else:
logger.error(f"Failed to update organization usage: {e}")
# Don't fail the whole cost calculation if usage update fails
logger.info(
f"Calculated cost for workflow run: ${cost_info['total_cost_usd']:.6f} USD ({cost_info['dograh_token_usage']} Dograh Tokens)"
)
except Exception as e:
logger.error(f"Error calculating cost for workflow run: {e}")
raise

View file

@ -53,7 +53,7 @@ def build_run_report_csv(runs: List[Any]) -> io.StringIO:
for run in runs:
initial = run.initial_context or {}
gathered = run.gathered_context or {}
cost = run.cost_info or {}
usage = run.usage_info or {}
call_tags = gathered.get("call_tags", [])
if isinstance(call_tags, list):
@ -67,7 +67,7 @@ def build_run_report_csv(runs: List[Any]) -> io.StringIO:
run.created_at.isoformat() if run.created_at else "",
initial.get("phone_number", ""),
gathered.get("mapped_call_disposition", ""),
cost.get("call_duration_seconds", ""),
usage.get("call_duration_seconds", ""),
]
extracted = gathered.get("extracted_variables", {})

View file

@ -66,34 +66,6 @@ async def handle_vonage_events(
logger.error(f"[run {workflow_run_id}] Workflow run not found")
return {"status": "error", "message": "Workflow run not found"}
# For a completed call that includes cost info, capture it immediately
if event_data.get("status") == "completed":
# Vonage sometimes includes price info in the webhook
if "price" in event_data or "rate" in event_data:
try:
if workflow_run.cost_info:
# Store immediate cost info if available
cost_info = workflow_run.cost_info.copy()
if "price" in event_data:
cost_info["vonage_webhook_price"] = float(event_data["price"])
if "rate" in event_data:
cost_info["vonage_webhook_rate"] = float(event_data["rate"])
if "duration" in event_data:
cost_info["vonage_webhook_duration"] = int(
event_data["duration"]
)
await db_client.update_workflow_run(
run_id=workflow_run_id, cost_info=cost_info
)
logger.info(
f"[run {workflow_run_id}] Captured Vonage cost info from webhook"
)
except Exception as e:
logger.error(
f"[run {workflow_run_id}] Failed to capture Vonage cost from webhook: {e}"
)
# Get workflow and provider
workflow = await db_client.get_workflow_by_id(workflow_run.workflow_id)
if not workflow:

View file

@ -35,6 +35,7 @@ import asyncio
from loguru import logger
from api.services.managed_model_services import MPS_CORRELATION_ID_CONTEXT_KEY
from api.services.workflow import pipecat_engine_callbacks as engine_callbacks
from api.services.workflow.mcp_tool_session import McpToolSession
from api.services.workflow.pipecat_engine_context_composer import (
@ -382,6 +383,9 @@ class PipecatEngine:
embeddings_provider=self._embeddings_provider,
embeddings_endpoint=self._embeddings_endpoint,
embeddings_api_version=self._embeddings_api_version,
correlation_id=self._call_context_vars.get(
MPS_CORRELATION_ID_CONTEXT_KEY
),
tracing_context=self._get_otel_context(),
)

View file

@ -0,0 +1,41 @@
"""Format workflow run usage for public API responses."""
def format_public_usage_info(usage_info: dict | None) -> dict | None:
if not usage_info:
return None
return {
"llm": usage_info.get("llm") or {},
"tts": usage_info.get("tts") or {},
"stt": usage_info.get("stt") or {},
"call_duration_seconds": usage_info.get("call_duration_seconds"),
}
def format_public_cost_info(
cost_info: dict | None, usage_info: dict | None
) -> dict | None:
"""Return the legacy response shape without doing local cost accounting."""
duration = None
if usage_info and usage_info.get("call_duration_seconds") is not None:
duration = int(round(usage_info.get("call_duration_seconds") or 0))
elif cost_info and cost_info.get("call_duration_seconds") is not None:
duration = int(round(cost_info.get("call_duration_seconds") or 0))
dograh_token_usage = 0
if cost_info:
if "dograh_token_usage" in cost_info:
dograh_token_usage = cost_info.get("dograh_token_usage") or 0
elif "total_cost_usd" in cost_info:
dograh_token_usage = round(
float(cost_info.get("total_cost_usd", 0)) * 100, 2
)
if duration is None and dograh_token_usage == 0:
return None
return {
"dograh_token_usage": dograh_token_usage,
"call_duration_seconds": duration,
}

View file

@ -421,7 +421,19 @@ async def execute_text_chat_pending_turn(
if user_config.llm is None:
raise ValueError("Text chat requires an LLM configuration")
llm = create_llm_service(user_config)
from api.services.managed_model_services import (
MPS_CORRELATION_ID_CONTEXT_KEY,
ensure_mps_correlation_id,
)
base_initial_context = dict(workflow_run.initial_context or {})
mps_correlation_id = await ensure_mps_correlation_id(
ai_model_config=user_config,
workflow_run_id=workflow_run_id,
initial_context=base_initial_context,
)
llm = create_llm_service(user_config, correlation_id=mps_correlation_id)
inference_llm = llm
runtime_configuration = {
@ -429,9 +441,15 @@ async def execute_text_chat_pending_turn(
"llm_model": user_config.llm.model,
}
initial_context = {
**(workflow_run.initial_context or {}),
**base_initial_context,
"runtime_configuration": runtime_configuration,
}
if mps_correlation_id:
initial_context[MPS_CORRELATION_ID_CONTEXT_KEY] = mps_correlation_id
await db_client.update_workflow_run(
workflow_run_id,
initial_context=initial_context,
)
workflow_graph = WorkflowGraph(
ReactFlowDTO.model_validate(run_definition.workflow_json)

View file

@ -4,17 +4,11 @@ from datetime import UTC, datetime
from typing import Any
from uuid import uuid4
from loguru import logger
from api.db import db_client
from api.db.models import WorkflowRunTextSessionModel
from api.db.workflow_run_text_session_client import (
WorkflowRunTextSessionRevisionConflictError,
)
from api.services.pricing.workflow_run_cost import (
apply_usage_delta_to_organization,
build_workflow_run_cost_info,
)
from api.services.workflow.text_chat_logs import (
build_text_chat_realtime_feedback_events,
)
@ -261,20 +255,6 @@ async def execute_pending_text_chat_turn(
state=execution.state,
is_completed=execution.is_completed,
)
workflow_run = await db_client.get_workflow_run_by_id(run_id)
if workflow_run:
try:
# Apply the per-turn delta so org usage tracks cumulative run cost
# without replaying the full session totals on every turn.
await apply_usage_delta_to_organization(workflow_run, execution.usage)
except Exception as e:
logger.error(
f"Failed to update organization usage for text chat run {run_id}: {e}"
)
cost_info = await build_workflow_run_cost_info(workflow_run)
if cost_info is not None:
await db_client.update_workflow_run(run_id, cost_info=cost_info)
return await _reload_text_chat_session(run_id)

View file

@ -29,6 +29,7 @@ async def retrieve_from_knowledge_base(
embeddings_provider: Optional[str] = None,
embeddings_endpoint: Optional[str] = None,
embeddings_api_version: Optional[str] = None,
correlation_id: Optional[str] = None,
tracing_context=None,
) -> Dict[str, Any]:
"""Retrieve relevant information from the knowledge base using vector similarity search.
@ -75,6 +76,7 @@ async def retrieve_from_knowledge_base(
embeddings_provider,
embeddings_endpoint,
embeddings_api_version,
correlation_id,
)
# Create span with parent context
@ -115,6 +117,7 @@ async def retrieve_from_knowledge_base(
embeddings_provider,
embeddings_endpoint,
embeddings_api_version,
correlation_id,
)
# Add result metadata to span
@ -192,6 +195,7 @@ async def retrieve_from_knowledge_base(
embeddings_provider,
embeddings_endpoint,
embeddings_api_version,
correlation_id,
)
else:
# Tracing is disabled - perform retrieval without tracing
@ -206,6 +210,7 @@ async def retrieve_from_knowledge_base(
embeddings_provider,
embeddings_endpoint,
embeddings_api_version,
correlation_id,
)
@ -220,6 +225,7 @@ async def _perform_retrieval(
embeddings_provider: Optional[str] = None,
embeddings_endpoint: Optional[str] = None,
embeddings_api_version: Optional[str] = None,
correlation_id: Optional[str] = None,
) -> Dict[str, Any]:
"""Internal function to perform the actual retrieval operation.
@ -272,11 +278,20 @@ async def _perform_retrieval(
api_version=embeddings_api_version or "2024-02-15-preview",
)
else:
default_headers = None
if (
embeddings_provider == ServiceProviders.DOGRAH.value
and correlation_id
):
default_headers = {
"X-Dograh-Correlation-Id": correlation_id,
}
embedding_service = OpenAIEmbeddingService(
db_client=db_client,
api_key=embeddings_api_key,
model_id=embeddings_model or "text-embedding-3-small",
base_url=embeddings_base_url,
default_headers=default_headers,
)
results = await embedding_service.search_similar_chunks(

View file

@ -0,0 +1,111 @@
"""Workflow-run billing hooks.
Dograh does not rate or deduct credits locally. MPS owns credit accounting.
For hosted deployments, Dograh reports completed platform usage to MPS.
When a server-minted MPS correlation id exists, MPS uses model-service usage
as the canonical duration. Otherwise Dograh reports the completed run duration.
"""
from typing import Any
from loguru import logger
from api.constants import DEPLOYMENT_MODE
from api.db import db_client
from api.services.managed_model_services import get_mps_correlation_id
from api.services.mps_service_key_client import mps_service_key_client
def _workflow_run_organization_id(workflow_run) -> int | None:
workflow = getattr(workflow_run, "workflow", None)
return getattr(workflow, "organization_id", None)
def _duration_seconds_from_usage_info(workflow_run) -> float | None:
usage_info: dict[str, Any] = getattr(workflow_run, "usage_info", None) or {}
duration = usage_info.get("call_duration_seconds")
try:
duration_seconds = float(duration)
except (TypeError, ValueError):
return None
return duration_seconds if duration_seconds > 0 else None
async def _organization_uses_mps_billing_v2(organization_id: int) -> bool:
account = await mps_service_key_client.get_billing_account_status(
organization_id=organization_id
)
return bool(account and account.get("billing_mode") == "v2")
async def report_workflow_run_platform_usage(workflow_run) -> None:
"""Report hosted platform usage for a completed workflow run to MPS."""
if DEPLOYMENT_MODE == "oss":
return
if not getattr(workflow_run, "is_completed", False):
return
organization_id = _workflow_run_organization_id(workflow_run)
if organization_id is None:
logger.warning(
"Skipping platform usage report for workflow run {}: no organization_id",
workflow_run.id,
)
return
correlation_id = get_mps_correlation_id(
getattr(workflow_run, "initial_context", None)
)
duration_seconds = (
None if correlation_id else _duration_seconds_from_usage_info(workflow_run)
)
if not correlation_id and duration_seconds is None:
logger.warning(
"Skipping platform usage report for workflow run {}: no billable duration",
workflow_run.id,
)
return
try:
if not await _organization_uses_mps_billing_v2(organization_id):
return
result = await mps_service_key_client.report_platform_usage(
organization_id=organization_id,
correlation_id=correlation_id,
duration_seconds=duration_seconds,
workflow_run_id=workflow_run.id,
metadata={
"source": "workflow_run_completion",
"workflow_id": getattr(workflow_run, "workflow_id", None),
"duration_source": (
"mps_correlation" if correlation_id else "dograh_usage_info"
),
},
)
logger.info(
"Reported platform usage for workflow run {} to MPS: {}",
workflow_run.id,
result,
)
except Exception as e:
logger.error(
"Failed to report platform usage for workflow run {}: {}",
workflow_run.id,
e,
)
async def report_completed_workflow_run_platform_usage(workflow_run_id: int) -> None:
"""Load a completed workflow run and report platform usage to MPS."""
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
if not workflow_run:
logger.warning(
"Skipping platform usage report: workflow run {} not found",
workflow_run_id,
)
return
await report_workflow_run_platform_usage(workflow_run)

View file

@ -45,10 +45,8 @@ from api.tasks.campaign_tasks import (
)
from api.tasks.knowledge_base_processing import process_knowledge_base_document
from api.tasks.run_integrations import run_integrations_post_workflow_run
from api.tasks.s3_upload import (
process_workflow_completion,
upload_voicemail_audio_to_s3,
)
from api.tasks.s3_upload import upload_voicemail_audio_to_s3
from api.tasks.workflow_completion import process_workflow_completion
class WorkerSettings:

View file

@ -166,18 +166,22 @@ async def process_knowledge_base_document(
user_id=document.created_by,
organization_id=document.organization_id,
)
user_config = resolved_config.effective
if user_config.embeddings:
embeddings_provider = getattr(user_config.embeddings, "provider", None)
embeddings_api_key = user_config.embeddings.api_key
embeddings_model = user_config.embeddings.model
effective_config = resolved_config.effective
if effective_config.embeddings:
embeddings_provider = getattr(
effective_config.embeddings, "provider", None
)
embeddings_api_key = effective_config.embeddings.api_key
embeddings_model = effective_config.embeddings.model
embeddings_base_url = apply_managed_embeddings_base_url(
provider=embeddings_provider,
base_url=getattr(user_config.embeddings, "base_url", None),
base_url=getattr(effective_config.embeddings, "base_url", None),
)
embeddings_endpoint = getattr(
effective_config.embeddings, "endpoint", None
)
embeddings_endpoint = getattr(user_config.embeddings, "endpoint", None)
embeddings_api_version = getattr(
user_config.embeddings, "api_version", None
effective_config.embeddings, "api_version", None
)
logger.info(
f"Using user embeddings config: provider={embeddings_provider}, "

View file

@ -1,13 +1,9 @@
import os
from typing import Optional
from loguru import logger
from pipecat.utils.run_context import set_current_run_id
from api.db import db_client
from api.services.pricing.workflow_run_cost import calculate_workflow_run_cost
from api.services.storage import get_current_storage_backend, storage_fs
from api.tasks.run_integrations import run_integrations_post_workflow_run
from api.services.storage import storage_fs
async def upload_voicemail_audio_to_s3(
@ -69,110 +65,3 @@ async def upload_voicemail_audio_to_s3(
logger.warning(
f"Failed to clean up temp voicemail audio file {temp_file_path}: {e}"
)
async def process_workflow_completion(
_ctx,
workflow_run_id: int,
audio_temp_path: Optional[str] = None,
transcript_temp_path: Optional[str] = None,
):
"""Process workflow completion: upload artifacts and run integrations.
This task combines audio upload, transcript upload, and webhook integrations
into a single sequential task to ensure integrations run after uploads complete.
Args:
_ctx: ARQ context (unused)
workflow_run_id: The workflow run ID
audio_temp_path: Optional path to temp audio file
transcript_temp_path: Optional path to temp transcript file
"""
run_id = str(workflow_run_id)
set_current_run_id(run_id)
logger.info(f"Processing workflow completion for run {workflow_run_id}")
storage_backend = get_current_storage_backend()
# Step 1: Upload audio if provided
if audio_temp_path:
try:
if os.path.exists(audio_temp_path):
file_size = os.path.getsize(audio_temp_path)
logger.debug(f"Audio file size: {file_size} bytes")
recording_url = f"recordings/{workflow_run_id}.wav"
logger.info(
f"Uploading audio to {storage_backend.name} - workflow_run_id: {workflow_run_id}"
)
await storage_fs.aupload_file(audio_temp_path, recording_url)
await db_client.update_workflow_run(
run_id=workflow_run_id,
recording_url=recording_url,
storage_backend=storage_backend.value,
)
logger.info(f"Successfully uploaded audio: {recording_url}")
else:
logger.warning(f"Audio temp file not found: {audio_temp_path}")
except Exception as e:
logger.error(f"Error uploading audio for workflow {workflow_run_id}: {e}")
finally:
if audio_temp_path and os.path.exists(audio_temp_path):
try:
os.remove(audio_temp_path)
logger.debug(f"Cleaned up temp audio file: {audio_temp_path}")
except Exception as e:
logger.warning(f"Failed to clean up temp audio file: {e}")
# Step 2: Upload transcript if provided
if transcript_temp_path:
try:
if os.path.exists(transcript_temp_path):
file_size = os.path.getsize(transcript_temp_path)
logger.debug(f"Transcript file size: {file_size} bytes")
transcript_url = f"transcripts/{workflow_run_id}.txt"
logger.info(
f"Uploading transcript to {storage_backend.name} - workflow_run_id: {workflow_run_id}"
)
await storage_fs.aupload_file(transcript_temp_path, transcript_url)
await db_client.update_workflow_run(
run_id=workflow_run_id,
transcript_url=transcript_url,
storage_backend=storage_backend.value,
)
logger.info(f"Successfully uploaded transcript: {transcript_url}")
else:
logger.warning(
f"Transcript temp file not found: {transcript_temp_path}"
)
except Exception as e:
logger.error(
f"Error uploading transcript for workflow {workflow_run_id}: {e}"
)
finally:
if transcript_temp_path and os.path.exists(transcript_temp_path):
try:
os.remove(transcript_temp_path)
logger.debug(
f"Cleaned up temp transcript file: {transcript_temp_path}"
)
except Exception as e:
logger.warning(f"Failed to clean up temp transcript file: {e}")
# Step 3: Run integrations including QA analysis (after uploads are complete)
try:
await run_integrations_post_workflow_run(_ctx, workflow_run_id)
except Exception as e:
logger.error(f"Error running integrations for workflow {workflow_run_id}: {e}")
# Step 4: Calculate cost after integrations (so QA token usage is included)
try:
await calculate_workflow_run_cost(workflow_run_id)
except Exception as e:
logger.error(f"Error calculating cost for workflow {workflow_run_id}: {e}")
logger.info(f"Completed workflow completion processing for run {workflow_run_id}")

View file

@ -0,0 +1,121 @@
import os
from typing import Optional
from loguru import logger
from pipecat.utils.run_context import set_current_run_id
from api.db import db_client
from api.services.storage import get_current_storage_backend, storage_fs
from api.services.workflow_run_billing import (
report_completed_workflow_run_platform_usage,
)
from api.tasks.run_integrations import run_integrations_post_workflow_run
async def process_workflow_completion(
_ctx,
workflow_run_id: int,
audio_temp_path: Optional[str] = None,
transcript_temp_path: Optional[str] = None,
):
"""Process workflow completion: upload artifacts and run integrations.
This task combines audio upload, transcript upload, and webhook integrations
into a single sequential task to ensure integrations run after uploads complete.
Args:
_ctx: ARQ context (unused)
workflow_run_id: The workflow run ID
audio_temp_path: Optional path to temp audio file
transcript_temp_path: Optional path to temp transcript file
"""
run_id = str(workflow_run_id)
set_current_run_id(run_id)
logger.info(f"Processing workflow completion for run {workflow_run_id}")
storage_backend = get_current_storage_backend()
# Step 1: Upload audio if provided
if audio_temp_path:
try:
if os.path.exists(audio_temp_path):
file_size = os.path.getsize(audio_temp_path)
logger.debug(f"Audio file size: {file_size} bytes")
recording_url = f"recordings/{workflow_run_id}.wav"
logger.info(
f"Uploading audio to {storage_backend.name} - workflow_run_id: {workflow_run_id}"
)
await storage_fs.aupload_file(audio_temp_path, recording_url)
await db_client.update_workflow_run(
run_id=workflow_run_id,
recording_url=recording_url,
storage_backend=storage_backend.value,
)
logger.info(f"Successfully uploaded audio: {recording_url}")
else:
logger.warning(f"Audio temp file not found: {audio_temp_path}")
except Exception as e:
logger.error(f"Error uploading audio for workflow {workflow_run_id}: {e}")
finally:
if audio_temp_path and os.path.exists(audio_temp_path):
try:
os.remove(audio_temp_path)
logger.debug(f"Cleaned up temp audio file: {audio_temp_path}")
except Exception as e:
logger.warning(f"Failed to clean up temp audio file: {e}")
# Step 2: Upload transcript if provided
if transcript_temp_path:
try:
if os.path.exists(transcript_temp_path):
file_size = os.path.getsize(transcript_temp_path)
logger.debug(f"Transcript file size: {file_size} bytes")
transcript_url = f"transcripts/{workflow_run_id}.txt"
logger.info(
f"Uploading transcript to {storage_backend.name} - workflow_run_id: {workflow_run_id}"
)
await storage_fs.aupload_file(transcript_temp_path, transcript_url)
await db_client.update_workflow_run(
run_id=workflow_run_id,
transcript_url=transcript_url,
storage_backend=storage_backend.value,
)
logger.info(f"Successfully uploaded transcript: {transcript_url}")
else:
logger.warning(
f"Transcript temp file not found: {transcript_temp_path}"
)
except Exception as e:
logger.error(
f"Error uploading transcript for workflow {workflow_run_id}: {e}"
)
finally:
if transcript_temp_path and os.path.exists(transcript_temp_path):
try:
os.remove(transcript_temp_path)
logger.debug(
f"Cleaned up temp transcript file: {transcript_temp_path}"
)
except Exception as e:
logger.warning(f"Failed to clean up temp transcript file: {e}")
# Step 3: Run integrations including QA analysis (after uploads are complete)
try:
await run_integrations_post_workflow_run(_ctx, workflow_run_id)
except Exception as e:
logger.error(f"Error running integrations for workflow {workflow_run_id}: {e}")
# Step 4: Notify MPS after completion. MPS owns credit accounting.
try:
await report_completed_workflow_run_platform_usage(workflow_run_id)
except Exception as e:
logger.error(
f"Error reporting platform usage for workflow {workflow_run_id}: {e}"
)
logger.info(f"Completed workflow completion processing for run {workflow_run_id}")

View file

@ -203,7 +203,7 @@ async def create_workflow_run_rows(
Returns:
Tuple of (workflow_run, user, workflow).
"""
from api.schemas.user_configuration import EffectiveAIModelConfiguration
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
org = OrganizationModel(provider_id=f"test-org-{provider_id_suffix}")
async_session.add(org)

View file

@ -1,12 +1,16 @@
from types import SimpleNamespace
from unittest.mock import AsyncMock
import pytest
from pydantic import ValidationError
from api.schemas.ai_model_configuration import (
DograhManagedAIModelConfiguration,
EffectiveAIModelConfiguration,
OrganizationAIModelConfigurationResponse,
OrganizationAIModelConfigurationV2,
compile_ai_model_configuration_v2,
)
from api.schemas.user_configuration import EffectiveAIModelConfiguration
from api.services.configuration.ai_model_configuration import (
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY,
check_for_masked_keys_in_ai_model_configuration_v2,
@ -15,6 +19,7 @@ from api.services.configuration.ai_model_configuration import (
merge_ai_model_configuration_v2_secrets,
migrate_workflow_configuration_model_override_to_v2,
)
from api.services.configuration.check_validity import UserConfigurationValidator
from api.services.configuration.masking import mask_key
from api.services.configuration.registry import (
DeepgramSTTConfiguration,
@ -22,6 +27,8 @@ from api.services.configuration.registry import (
DograhSTTService,
DograhTTSService,
ElevenlabsTTSConfiguration,
GoogleLLMService,
GoogleRealtimeLLMConfiguration,
OpenAIEmbeddingsConfiguration,
OpenAILLMService,
)
@ -49,6 +56,7 @@ def test_dograh_v2_compiles_to_effective_managed_pipeline_with_embeddings():
assert effective.stt.language == "multi"
assert effective.embeddings.provider == "dograh"
assert effective.embeddings.model == "default"
assert effective.managed_service_version == 2
def test_dograh_v2_rejects_non_predefined_speed():
@ -92,6 +100,67 @@ def test_byok_v2_rejects_dograh_provider():
)
@pytest.mark.asyncio
async def test_byok_realtime_validator_does_not_require_stt_or_tts():
config = OrganizationAIModelConfigurationV2.model_validate(
{
"mode": "byok",
"byok": {
"mode": "realtime",
"realtime": {
"realtime": {
"provider": "google_realtime",
"api_key": "google-realtime-key",
"model": "gemini-3.1-flash-live-preview",
"voice": "Puck",
"language": "en",
},
"llm": {
"provider": "google",
"api_key": "google-llm-key",
"model": "gemini-2.0-flash",
},
},
},
}
)
effective = compile_ai_model_configuration_v2(config)
assert effective.is_realtime is True
assert effective.stt is None
assert effective.tts is None
assert await UserConfigurationValidator().validate(effective) == {
"status": [{"model": "all", "message": "ok"}]
}
@pytest.mark.asyncio
async def test_pipeline_validator_requires_stt_and_tts_when_not_realtime():
effective = EffectiveAIModelConfiguration(
llm=GoogleLLMService(
provider="google",
api_key="google-llm-key",
model="gemini-2.0-flash",
),
realtime=GoogleRealtimeLLMConfiguration(
provider="google_realtime",
api_key="google-realtime-key",
model="gemini-3.1-flash-live-preview",
voice="Puck",
language="en",
),
is_realtime=False,
)
with pytest.raises(ValueError) as exc_info:
await UserConfigurationValidator().validate(effective)
assert exc_info.value.args[0] == [
{"model": "stt", "message": "API key is missing"},
{"model": "tts", "message": "API key is missing"},
]
def test_masked_dograh_key_is_preserved_when_saving_same_mode():
existing = OrganizationAIModelConfigurationV2(
mode="dograh",
@ -293,3 +362,98 @@ def test_workflow_model_override_migration_removes_invalid_v1_override_marker():
assert changed is True
assert "model_overrides" not in migrated
assert migrated["ambient_noise_configuration"] == {"enabled": False}
@pytest.mark.asyncio
async def test_migrate_model_configuration_v2_initializes_hosted_mps_billing(
monkeypatch,
):
from api.routes import organization as organization_routes
legacy = EffectiveAIModelConfiguration(
llm=DograhLLMService(
provider="dograh",
api_key=["mps-secret"],
model="default",
),
tts=DograhTTSService(
provider="dograh",
api_key=["mps-secret"],
model="default",
voice="default",
),
stt=DograhSTTService(
provider="dograh",
api_key=["mps-secret"],
model="default",
),
)
expected_response = OrganizationAIModelConfigurationResponse(
configuration={"version": 2, "mode": "dograh"},
effective_configuration={},
source="organization_v2",
)
class FakeValidator:
async def validate(self, *args, **kwargs):
return {"status": [{"model": "all", "message": "ok"}]}
ensure_billing = AsyncMock(return_value={"billing_mode": "v2"})
upsert = AsyncMock()
migrate_workflows = AsyncMock()
monkeypatch.setattr(organization_routes, "DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
organization_routes,
"get_organization_ai_model_configuration_v2",
AsyncMock(return_value=None),
)
monkeypatch.setattr(
organization_routes.db_client,
"get_user_configurations",
AsyncMock(return_value=legacy),
)
monkeypatch.setattr(
organization_routes,
"UserConfigurationValidator",
lambda: FakeValidator(),
)
monkeypatch.setattr(
organization_routes,
"ensure_hosted_mps_billing_account_v2",
ensure_billing,
)
monkeypatch.setattr(
organization_routes,
"upsert_organization_ai_model_configuration_v2",
upsert,
)
monkeypatch.setattr(
organization_routes,
"migrate_workflow_model_configurations_to_v2",
migrate_workflows,
)
monkeypatch.setattr(
organization_routes,
"_model_configuration_v2_response",
AsyncMock(return_value=expected_response),
)
user = SimpleNamespace(
id=7,
provider_id="provider-123",
selected_organization_id=42,
)
response = await organization_routes.migrate_model_configuration_v2(
force=False,
user=user,
)
ensure_billing.assert_awaited_once_with(42, created_by="provider-123")
upsert.assert_awaited_once()
migrate_workflows.assert_awaited_once_with(
organization_id=42,
fallback_user_config=legacy,
)
assert response == expected_response

View file

@ -0,0 +1,68 @@
from types import SimpleNamespace
from unittest.mock import AsyncMock
import pytest
from api.services.auth import depends as auth_depends
@pytest.mark.asyncio
async def test_get_user_initializes_hosted_mps_billing_for_new_org(monkeypatch):
stack_user = {
"id": "stack-user-1",
"selected_team_id": "team-1",
"primary_email_verified": False,
}
user = SimpleNamespace(
id=7,
email=None,
provider_id="stack-user-1",
selected_organization_id=None,
)
organization = SimpleNamespace(id=42)
existing_config = SimpleNamespace(llm=object(), tts=None, stt=None)
ensure_billing = AsyncMock(return_value={"billing_mode": "v2"})
monkeypatch.setattr(auth_depends, "AUTH_PROVIDER", "stack")
monkeypatch.setattr(
auth_depends.stackauth,
"get_user",
AsyncMock(return_value=stack_user),
)
monkeypatch.setattr(
auth_depends.db_client,
"get_or_create_user_by_provider_id",
AsyncMock(return_value=(user, False)),
)
monkeypatch.setattr(
auth_depends.db_client,
"get_or_create_organization_by_provider_id",
AsyncMock(return_value=(organization, True)),
)
monkeypatch.setattr(
auth_depends.db_client,
"add_user_to_organization",
AsyncMock(),
)
monkeypatch.setattr(
auth_depends.db_client,
"update_user_selected_organization",
AsyncMock(),
)
monkeypatch.setattr(
auth_depends.db_client,
"get_user_configurations",
AsyncMock(return_value=existing_config),
)
monkeypatch.setattr(
auth_depends,
"ensure_hosted_mps_billing_account_v2",
ensure_billing,
)
result = await auth_depends.get_user(authorization="Bearer token")
assert result is user
assert result.selected_organization_id == 42
ensure_billing.assert_awaited_once_with(42, created_by="stack-user-1")

View file

@ -1,31 +0,0 @@
from api.services.pricing.cost_calculator import cost_calculator
def test_cost_calculator():
"""Test function to verify cost calculation works"""
sample_usage = {
"llm": {
"OpenAILLMService#0|||gpt-4.1-mini": {
"prompt_tokens": 45380,
"completion_tokens": 496,
"total_tokens": 45876,
"cache_read_input_tokens": 0,
"cache_creation_input_tokens": 0,
}
},
"tts": {"ElevenLabsTTSService#0|||eleven_flash_v2_5": 2399},
"stt": {"DeepgramSTTService#0|||nova-3-general": 177.21536946296692},
"call_duration_seconds": 179,
}
result = cost_calculator.calculate_total_cost(sample_usage)
assert result["llm_cost"] == 45380 * 0.40 / 1_000_000 + 496 * 1.60 / 1_000_000
assert result["tts_cost"] == 2399 * 0.0256 / 1_000
assert result["stt_cost"] == 177.21536946296692 / 60 * 0.0077
assert (
abs(
result["total"]
- (result["llm_cost"] + result["tts_cost"] + result["stt_cost"])
)
< 1e-10
)

View file

@ -0,0 +1,110 @@
import json
import pytest
from openai._types import NOT_GIVEN as OPENAI_NOT_GIVEN
from pipecat.frames.frames import TTSStartedFrame
from pipecat.services.dograh.llm import DograhLLMService
from pipecat.services.dograh.stt import DograhSTTService
from pipecat.services.dograh.tts import DograhTTSService
from pipecat.services.openai.base_llm import OpenAILLMSettings
from websockets.protocol import State
class _FakeWebSocket:
def __init__(self):
self.state = State.OPEN
self.messages: list[dict] = []
async def send(self, message: str) -> None:
self.messages.append(json.loads(message))
async def close(self, *args, **kwargs) -> None:
self.state = State.CLOSED
def test_dograh_llm_uses_explicit_mps_correlation_id():
service = DograhLLMService(
api_key="mps-secret",
correlation_id="mps-corr-123",
settings=OpenAILLMSettings(model="default"),
)
service._start_metadata = {"workflow_run_id": 99}
params = service.build_chat_completion_params(
{
"messages": [],
"tools": OPENAI_NOT_GIVEN,
"tool_choice": OPENAI_NOT_GIVEN,
}
)
assert params["metadata"]["correlation_id"] == "mps-corr-123"
assert params["metadata"]["mps_billing_version"] == "2"
@pytest.mark.asyncio
async def test_dograh_stt_config_uses_explicit_mps_correlation_id(monkeypatch):
fake_ws = _FakeWebSocket()
async def fake_connect(url, additional_headers):
return fake_ws
monkeypatch.setattr(
"pipecat.services.dograh.stt.websocket_connect",
fake_connect,
)
service = DograhSTTService(
api_key="mps-secret",
correlation_id="mps-corr-123",
sample_rate=16000,
)
service._start_metadata = {"workflow_run_id": 99}
await service._connect_websocket()
assert fake_ws.messages[0]["type"] == "config"
assert fake_ws.messages[0]["correlation_id"] == "mps-corr-123"
assert fake_ws.messages[0]["mps_billing_version"] == "2"
@pytest.mark.asyncio
async def test_dograh_tts_messages_use_explicit_mps_correlation_id(monkeypatch):
fake_ws = _FakeWebSocket()
async def fake_connect(url, additional_headers):
return fake_ws
monkeypatch.setattr(
"pipecat.services.dograh.tts.websocket_connect",
fake_connect,
)
service = DograhTTSService(
api_key="mps-secret",
correlation_id="mps-corr-123",
sample_rate=24000,
)
service._start_metadata = {"workflow_run_id": 99}
await service._connect_websocket()
assert fake_ws.messages[0]["type"] == "config"
assert fake_ws.messages[0]["correlation_id"] == "mps-corr-123"
assert fake_ws.messages[0]["mps_billing_version"] == "2"
async def _noop(*args, **kwargs):
return None
service.audio_context_available = lambda context_id: False
service.create_audio_context = _noop
service.start_ttfb_metrics = _noop
service.start_tts_usage_metrics = _noop
frames = []
async for frame in service.run_tts("hello", "ctx-1"):
frames.append(frame)
assert isinstance(frames[0], TTSStartedFrame)
assert fake_ws.messages[1]["type"] == "create_context"
assert fake_ws.messages[1]["correlation_id"] == "mps-corr-123"
assert fake_ws.messages[1]["mps_billing_version"] == "2"

View file

@ -7,7 +7,7 @@ from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.xai.realtime import events
from api.schemas.user_configuration import EffectiveAIModelConfiguration
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
from api.services.configuration.registry import GrokRealtimeLLMConfiguration
from api.services.pipecat.realtime.grok_realtime import (
DograhGrokRealtimeLLMService,
@ -120,7 +120,7 @@ async def test_completed_input_transcription_is_broadcast_as_finalized():
def test_factory_creates_dograh_grok_realtime_service():
user_config = EffectiveAIModelConfiguration(
effective_config = EffectiveAIModelConfiguration(
is_realtime=True,
realtime=GrokRealtimeLLMConfiguration(
provider="grok_realtime",
@ -131,7 +131,7 @@ def test_factory_creates_dograh_grok_realtime_service():
)
service = create_realtime_llm_service(
user_config,
effective_config,
audio_config=SimpleNamespace(),
)

View file

@ -5,7 +5,7 @@ from fastapi import FastAPI
from fastapi.testclient import TestClient
from api.routes.user import router
from api.schemas.user_configuration import EffectiveAIModelConfiguration
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
from api.services.auth.depends import get_user
from api.services.configuration.masking import mask_key
from api.services.configuration.registry import (

View file

@ -87,3 +87,317 @@ async def test_check_service_key_usage_uses_bearer_self_usage(monkeypatch):
"Content-Type": "application/json",
},
)
@pytest.mark.asyncio
async def test_create_correlation_id_uses_bearer_auth(monkeypatch):
calls = []
class FakeAsyncClient:
def __init__(self, timeout):
self.timeout = timeout
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return None
async def post(self, url, json, headers):
calls.append(("POST", url, json, headers))
return _Response(200, {"correlation_id": "mps-corr-123"})
monkeypatch.setattr(
"api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient
)
client = MPSServiceKeyClient()
assert await client.create_correlation_id(
service_key="mps_sk_paid",
workflow_run_id=42,
) == {"correlation_id": "mps-corr-123"}
assert calls == [
(
"POST",
f"{client.base_url}/api/v1/service-keys/correlation-id/self",
{"workflow_run_id": 42},
{
"Authorization": "Bearer mps_sk_paid",
"Content-Type": "application/json",
},
)
]
@pytest.mark.asyncio
async def test_get_billing_account_status_uses_hosted_org_auth(monkeypatch):
calls = []
class FakeAsyncClient:
def __init__(self, timeout):
self.timeout = timeout
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return None
async def get(self, url, headers):
calls.append(("GET", url, headers))
return _Response(200, {"organization_id": 42, "billing_mode": "v2"})
monkeypatch.setattr(
"api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient
)
monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
"api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret"
)
client = MPSServiceKeyClient()
assert await client.get_billing_account_status(organization_id=42) == {
"organization_id": 42,
"billing_mode": "v2",
}
assert calls == [
(
"GET",
f"{client.base_url}/api/v1/billing/accounts/42/status",
{
"Content-Type": "application/json",
"X-Secret-Key": "mps-secret",
"X-Organization-Id": "42",
},
)
]
@pytest.mark.asyncio
async def test_ensure_billing_account_v2_uses_balance_endpoint(monkeypatch):
calls = []
class FakeAsyncClient:
def __init__(self, timeout):
self.timeout = timeout
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return None
async def get(self, url, headers):
calls.append(("GET", url, headers))
return _Response(
200,
{
"id": 7,
"organization_id": 42,
"billing_mode": "v2",
"cached_balance_credits": "0.0000",
"currency": "USD",
},
)
monkeypatch.setattr(
"api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient
)
monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
"api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret"
)
client = MPSServiceKeyClient()
assert await client.ensure_billing_account_v2(
organization_id=42,
created_by="provider-123",
) == {
"id": 7,
"organization_id": 42,
"billing_mode": "v2",
"cached_balance_credits": "0.0000",
"currency": "USD",
}
assert calls == [
(
"GET",
f"{client.base_url}/api/v1/billing/accounts/42/balance",
{
"Content-Type": "application/json",
"X-Secret-Key": "mps-secret",
"X-Organization-Id": "42",
},
)
]
@pytest.mark.asyncio
async def test_get_credit_ledger_sends_page_and_limit(monkeypatch):
calls = []
class FakeAsyncClient:
def __init__(self, timeout):
self.timeout = timeout
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return None
async def get(self, url, params, headers):
calls.append(("GET", url, params, headers))
return _Response(
200,
{
"account": {"organization_id": 42},
"ledger_entries": [],
"total_count": 0,
"page": 3,
"limit": 25,
"total_pages": 0,
},
)
monkeypatch.setattr(
"api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient
)
monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
"api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret"
)
client = MPSServiceKeyClient()
assert await client.get_credit_ledger(
organization_id=42,
page=3,
limit=25,
) == {
"account": {"organization_id": 42},
"ledger_entries": [],
"total_count": 0,
"page": 3,
"limit": 25,
"total_pages": 0,
}
assert calls == [
(
"GET",
f"{client.base_url}/api/v1/billing/accounts/42/ledger",
{"page": 3, "limit": 25},
{
"Content-Type": "application/json",
"X-Secret-Key": "mps-secret",
"X-Organization-Id": "42",
},
)
]
@pytest.mark.asyncio
async def test_report_platform_usage_uses_hosted_secret_auth(monkeypatch):
calls = []
class FakeAsyncClient:
def __init__(self, timeout):
self.timeout = timeout
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return None
async def post(self, url, json, headers):
calls.append(("POST", url, json, headers))
return _Response(200, {"metered": True})
monkeypatch.setattr(
"api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient
)
monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
"api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret"
)
client = MPSServiceKeyClient()
assert await client.report_platform_usage(
organization_id=42,
correlation_id="mps-corr-123",
workflow_run_id=123,
metadata={"source": "workflow_run_completion"},
) == {"metered": True}
assert calls == [
(
"POST",
f"{client.base_url}/api/v1/billing/accounts/42/platform-usage",
{
"correlation_id": "mps-corr-123",
"workflow_run_id": 123,
"metadata": {"source": "workflow_run_completion"},
},
{
"Content-Type": "application/json",
"X-Secret-Key": "mps-secret",
"X-Organization-Id": "42",
},
)
]
@pytest.mark.asyncio
async def test_report_platform_usage_sends_duration_without_correlation(monkeypatch):
calls = []
class FakeAsyncClient:
def __init__(self, timeout):
self.timeout = timeout
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return None
async def post(self, url, json, headers):
calls.append(("POST", url, json, headers))
return _Response(200, {"metered": True})
monkeypatch.setattr(
"api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient
)
monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
"api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret"
)
client = MPSServiceKeyClient()
assert await client.report_platform_usage(
organization_id=42,
duration_seconds=87.0,
workflow_run_id=123,
metadata={"source": "workflow_run_completion"},
) == {"metered": True}
assert calls == [
(
"POST",
f"{client.base_url}/api/v1/billing/accounts/42/platform-usage",
{
"duration_seconds": 87.0,
"workflow_run_id": 123,
"metadata": {"source": "workflow_run_completion"},
},
{
"Content-Type": "application/json",
"X-Secret-Key": "mps-secret",
"X-Organization-Id": "42",
},
)
]

View file

@ -0,0 +1,99 @@
from types import SimpleNamespace
from unittest.mock import AsyncMock
import pytest
from api.routes import organization_usage
def test_is_mps_billing_v2_depends_only_on_account_mode():
assert organization_usage._is_mps_billing_v2({"billing_mode": "v2"}) is True
assert organization_usage._is_mps_billing_v2({"billing_mode": "v1"}) is False
assert organization_usage._is_mps_billing_v2({"billing_mode": "shadow"}) is False
assert organization_usage._is_mps_billing_v2(None) is False
@pytest.mark.asyncio
async def test_get_mps_billing_account_status_uses_user_provider_id(monkeypatch):
get_status = AsyncMock(return_value={"billing_mode": "v2"})
monkeypatch.setattr(
organization_usage.mps_service_key_client,
"get_billing_account_status",
get_status,
)
user = SimpleNamespace(provider_id="provider-123")
assert await organization_usage._get_mps_billing_account_status(user, 42) == {
"billing_mode": "v2"
}
get_status.assert_awaited_once_with(
organization_id=42,
created_by="provider-123",
)
@pytest.mark.asyncio
async def test_get_billing_credits_pages_v2_ledger(monkeypatch):
monkeypatch.setattr(organization_usage, "DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
organization_usage,
"_get_mps_billing_account_status",
AsyncMock(return_value={"billing_mode": "v2"}),
)
get_ledger = AsyncMock(
return_value={
"account": {
"id": 7,
"organization_id": 42,
"billing_mode": "v2",
"cached_balance_credits": 250,
"currency": "USD",
},
"ledger_entries": [
{
"id": 99,
"entry_type": "grant",
"origin": "account_creation",
"credits_delta": 250,
"balance_after": 250,
"created_at": "2026-06-12T00:00:00Z",
}
],
"total_debits_credits": 75,
"total_count": 101,
"page": 3,
"limit": 25,
"total_pages": 5,
}
)
monkeypatch.setattr(
organization_usage.mps_service_key_client,
"get_credit_ledger",
get_ledger,
)
user = SimpleNamespace(
provider_id="provider-123",
selected_organization_id=42,
)
response = await organization_usage.get_billing_credits(
page=3,
limit=25,
user=user,
)
get_ledger.assert_awaited_once_with(
organization_id=42,
page=3,
limit=25,
created_by="provider-123",
)
assert response.billing_version == "v2"
assert response.total_credits_used == 75
assert response.total_count == 101
assert response.page == 3
assert response.limit == 25
assert response.total_pages == 5
assert response.ledger_entries[0].id == 99

View file

@ -9,7 +9,7 @@ Module under test: api.services.configuration.resolve
import pytest
from api.schemas.user_configuration import EffectiveAIModelConfiguration
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
from api.services.configuration.masking import (
contains_masked_key,
mask_workflow_configurations,

View file

@ -1,4 +1,4 @@
from api.services.pricing.run_usage_response import format_public_usage_info
from api.services.workflow.run_usage_response import format_public_usage_info
def test_format_public_usage_info():

View file

@ -10,7 +10,7 @@ from pipecat.processors.frame_processor import FrameDirection
from websockets.exceptions import ConnectionClosedError
from websockets.frames import Close
from api.schemas.user_configuration import EffectiveAIModelConfiguration
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
from api.services.configuration.registry import UltravoxRealtimeLLMConfiguration
from api.services.pipecat.realtime.ultravox_realtime import (
_RESUMPTION_USER_MESSAGE,
@ -430,7 +430,7 @@ async def test_receive_messages_reports_unexpected_websocket_close():
def test_factory_creates_dograh_ultravox_realtime_service():
user_config = EffectiveAIModelConfiguration(
effective_config = EffectiveAIModelConfiguration(
is_realtime=True,
realtime=UltravoxRealtimeLLMConfiguration(
provider="ultravox_realtime",
@ -441,7 +441,7 @@ def test_factory_creates_dograh_ultravox_realtime_service():
)
service = create_realtime_llm_service(
user_config,
effective_config,
audio_config=SimpleNamespace(),
)

View file

@ -0,0 +1,212 @@
from types import SimpleNamespace
from unittest.mock import AsyncMock
import pytest
from api.services import workflow_run_billing as workflow_run_billing_mod
from api.services.workflow_run_billing import (
report_completed_workflow_run_platform_usage,
report_workflow_run_platform_usage,
)
def _make_workflow_run():
return SimpleNamespace(
id=123,
workflow_id=456,
is_completed=True,
initial_context={"mps_correlation_id": "mps-corr-123"},
usage_info={"call_duration_seconds": 87},
workflow=SimpleNamespace(
organization_id=42,
user=SimpleNamespace(selected_organization_id=42),
),
)
@pytest.mark.asyncio
async def test_report_workflow_run_platform_usage_reports_hosted_completion(
monkeypatch,
):
workflow_run = _make_workflow_run()
get_status = AsyncMock(return_value={"billing_mode": "v2"})
report_usage = AsyncMock(return_value={"metered": True})
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"get_billing_account_status",
get_status,
)
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"report_platform_usage",
report_usage,
)
await report_workflow_run_platform_usage(workflow_run)
report_usage.assert_awaited_once_with(
organization_id=42,
correlation_id="mps-corr-123",
duration_seconds=None,
workflow_run_id=workflow_run.id,
metadata={
"source": "workflow_run_completion",
"workflow_id": workflow_run.workflow_id,
"duration_source": "mps_correlation",
},
)
@pytest.mark.asyncio
async def test_report_workflow_run_platform_usage_reports_duration_without_correlation(
monkeypatch,
):
workflow_run = _make_workflow_run()
workflow_run.initial_context = {}
get_status = AsyncMock(return_value={"billing_mode": "v2"})
report_usage = AsyncMock(return_value={"metered": True})
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"get_billing_account_status",
get_status,
)
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"report_platform_usage",
report_usage,
)
await report_workflow_run_platform_usage(workflow_run)
report_usage.assert_awaited_once_with(
organization_id=42,
correlation_id=None,
duration_seconds=87.0,
workflow_run_id=workflow_run.id,
metadata={
"source": "workflow_run_completion",
"workflow_id": workflow_run.workflow_id,
"duration_source": "dograh_usage_info",
},
)
@pytest.mark.asyncio
async def test_report_workflow_run_platform_usage_skips_non_v2_account(monkeypatch):
workflow_run = _make_workflow_run()
get_status = AsyncMock(return_value={"billing_mode": "v1"})
report_usage = AsyncMock()
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"get_billing_account_status",
get_status,
)
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"report_platform_usage",
report_usage,
)
await report_workflow_run_platform_usage(workflow_run)
get_status.assert_awaited_once_with(organization_id=42)
report_usage.assert_not_awaited()
@pytest.mark.asyncio
async def test_report_workflow_run_platform_usage_skips_missing_duration_without_correlation(
monkeypatch,
):
workflow_run = _make_workflow_run()
workflow_run.initial_context = {}
workflow_run.usage_info = {}
get_status = AsyncMock(return_value={"billing_mode": "v2"})
report_usage = AsyncMock()
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"get_billing_account_status",
get_status,
)
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"report_platform_usage",
report_usage,
)
await report_workflow_run_platform_usage(workflow_run)
get_status.assert_not_awaited()
report_usage.assert_not_awaited()
@pytest.mark.asyncio
async def test_report_workflow_run_platform_usage_skips_oss(monkeypatch):
workflow_run = _make_workflow_run()
report_usage = AsyncMock()
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "oss")
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"report_platform_usage",
report_usage,
)
await report_workflow_run_platform_usage(workflow_run)
report_usage.assert_not_awaited()
@pytest.mark.asyncio
async def test_report_workflow_run_platform_usage_skips_incomplete(monkeypatch):
workflow_run = _make_workflow_run()
workflow_run.is_completed = False
report_usage = AsyncMock()
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"report_platform_usage",
report_usage,
)
await report_workflow_run_platform_usage(workflow_run)
report_usage.assert_not_awaited()
@pytest.mark.asyncio
async def test_report_completed_workflow_run_platform_usage_loads_run(monkeypatch):
workflow_run = _make_workflow_run()
get_run = AsyncMock(return_value=workflow_run)
get_status = AsyncMock(return_value={"billing_mode": "v2"})
report_usage = AsyncMock(return_value={"metered": True})
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
workflow_run_billing_mod.db_client,
"get_workflow_run_by_id",
get_run,
)
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"get_billing_account_status",
get_status,
)
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"report_platform_usage",
report_usage,
)
await report_completed_workflow_run_platform_usage(workflow_run.id)
get_run.assert_awaited_once_with(workflow_run.id)
report_usage.assert_awaited_once()

View file

@ -1,181 +0,0 @@
from datetime import UTC, datetime
from types import SimpleNamespace
from unittest.mock import AsyncMock
import pytest
from api.services.pricing import workflow_run_cost as workflow_run_cost_mod
from api.services.pricing.workflow_run_cost import (
apply_usage_delta_to_organization,
build_workflow_run_cost_info,
calculate_workflow_run_cost,
)
def _make_workflow_run():
return SimpleNamespace(
id=123,
workflow_id=456,
mode="textchat",
created_at=datetime.now(UTC),
usage_info={
"llm": {},
"tts": {},
"stt": {},
"call_duration_seconds": 7,
},
cost_info={},
workflow=SimpleNamespace(
organization_id=42,
user=SimpleNamespace(selected_organization_id=42),
),
)
@pytest.mark.asyncio
async def test_build_workflow_run_cost_info_does_not_update_org_usage(monkeypatch):
workflow_run = _make_workflow_run()
get_org = AsyncMock(return_value=SimpleNamespace(id=42, price_per_second_usd=1.5))
update_usage = AsyncMock()
monkeypatch.setattr(
workflow_run_cost_mod.db_client, "get_organization_by_id", get_org
)
monkeypatch.setattr(
workflow_run_cost_mod.db_client, "update_usage_after_run", update_usage
)
cost_info = await build_workflow_run_cost_info(workflow_run)
assert cost_info is not None
assert cost_info["call_duration_seconds"] == 7
assert "cost_breakdown" in cost_info
assert "dograh_token_usage" in cost_info
assert cost_info["charge_usd"] == 10.5
update_usage.assert_not_called()
@pytest.mark.asyncio
async def test_calculate_workflow_run_cost_keeps_org_usage_side_effect_in_wrapper(
monkeypatch,
):
workflow_run = _make_workflow_run()
get_org = AsyncMock(return_value=SimpleNamespace(id=42, price_per_second_usd=None))
update_run = AsyncMock()
update_usage = AsyncMock()
monkeypatch.setattr(
workflow_run_cost_mod.db_client,
"get_workflow_run_by_id",
AsyncMock(return_value=workflow_run),
)
monkeypatch.setattr(
workflow_run_cost_mod.db_client, "get_organization_by_id", get_org
)
monkeypatch.setattr(
workflow_run_cost_mod.db_client, "update_workflow_run", update_run
)
monkeypatch.setattr(
workflow_run_cost_mod.db_client, "update_usage_after_run", update_usage
)
await calculate_workflow_run_cost(workflow_run.id)
update_run.assert_awaited_once()
saved_kwargs = update_run.await_args.kwargs
assert saved_kwargs["run_id"] == workflow_run.id
assert "cost_breakdown" in saved_kwargs["cost_info"]
update_usage.assert_awaited_once()
@pytest.mark.asyncio
async def test_apply_usage_delta_to_organization_uses_incremental_costs(
monkeypatch,
):
workflow_run = _make_workflow_run()
workflow_run.cost_info = {"call_id": "preserve-me"}
usage_delta_one = {
"llm": {
"OpenAILLMService#0|||gpt-4.1-mini": {
"prompt_tokens": 1_000,
"completion_tokens": 100,
"total_tokens": 1_100,
"cache_read_input_tokens": 0,
"cache_creation_input_tokens": 0,
}
},
"tts": {},
"stt": {},
"call_duration_seconds": 3,
}
usage_delta_two = {
"llm": {
"OpenAILLMService#0|||gpt-4.1-mini": {
"prompt_tokens": 2_000,
"completion_tokens": 50,
"total_tokens": 2_050,
"cache_read_input_tokens": 0,
"cache_creation_input_tokens": 0,
}
},
"tts": {},
"stt": {},
"call_duration_seconds": 4,
}
merged_usage = {
"llm": {
"OpenAILLMService#0|||gpt-4.1-mini": {
"prompt_tokens": 3_000,
"completion_tokens": 150,
"total_tokens": 3_150,
"cache_read_input_tokens": 0,
"cache_creation_input_tokens": 0,
}
},
"tts": {},
"stt": {},
"call_duration_seconds": 7,
}
get_org = AsyncMock(return_value=SimpleNamespace(id=42, price_per_second_usd=1.5))
update_usage = AsyncMock()
monkeypatch.setattr(
workflow_run_cost_mod.db_client, "get_organization_by_id", get_org
)
monkeypatch.setattr(
workflow_run_cost_mod.db_client, "update_usage_after_run", update_usage
)
first_delta = await apply_usage_delta_to_organization(workflow_run, usage_delta_one)
second_delta = await apply_usage_delta_to_organization(
workflow_run, usage_delta_two
)
total_workflow_run = SimpleNamespace(**workflow_run.__dict__)
total_workflow_run.usage_info = merged_usage
total_cost = await build_workflow_run_cost_info(total_workflow_run)
assert first_delta is not None
assert second_delta is not None
assert total_cost is not None
assert update_usage.await_count == 2
assert update_usage.await_args_list[0].args == (
42,
first_delta["dograh_token_usage"],
3.0,
first_delta["charge_usd"],
)
assert update_usage.await_args_list[1].args == (
42,
second_delta["dograh_token_usage"],
4.0,
second_delta["charge_usd"],
)
assert (
first_delta["dograh_token_usage"] + second_delta["dograh_token_usage"]
) == pytest.approx(total_cost["dograh_token_usage"])
assert (
first_delta["charge_usd"] + second_delta["charge_usd"]
== total_cost["charge_usd"]
)

View file

@ -4,7 +4,7 @@ from unittest.mock import AsyncMock, patch
import pytest
from api.db.models import OrganizationModel, UserModel
from api.schemas.user_configuration import EffectiveAIModelConfiguration
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
from api.tests.integrations._run_pipeline_helpers import USER_CONFIGURATION
from pipecat.tests import MockLLMService
@ -176,11 +176,7 @@ async def test_text_chat_session_creation_executes_initial_assistant_turn(
assert "Start" in (created["gathered_context"] or {}).get("nodes_visited", [])
workflow_run = await db_session.get_workflow_run_by_id(created["workflow_run_id"])
assert workflow_run is not None
assert workflow_run.cost_info[
"call_duration_seconds"
] == workflow_run.usage_info.get("call_duration_seconds", 0)
assert "cost_breakdown" in workflow_run.cost_info
assert "dograh_token_usage" in workflow_run.cost_info
assert "call_duration_seconds" in workflow_run.usage_info
assert _log_texts(run_payload["logs"], "rtf-bot-text") == [
"Hello from the workflow tester."
]
@ -296,11 +292,7 @@ async def test_text_chat_message_executes_assistant_turn(
assert "Start" in (payload["gathered_context"] or {}).get("nodes_visited", [])
workflow_run = await db_session.get_workflow_run_by_id(created["workflow_run_id"])
assert workflow_run is not None
assert workflow_run.cost_info[
"call_duration_seconds"
] == workflow_run.usage_info.get("call_duration_seconds", 0)
assert "cost_breakdown" in workflow_run.cost_info
assert "dograh_token_usage" in workflow_run.cost_info
assert "call_duration_seconds" in workflow_run.usage_info
assert _log_texts(run_payload["logs"], "rtf-user-transcription") == ["Hi there"]
assert _log_texts(run_payload["logs"], "rtf-bot-text") == [
"Welcome to the workflow tester.",

@ -1 +1 @@
Subproject commit 228324a146a6765c6b8d610963bc80d7bc8cb9f7
Subproject commit 0d64dc6e0e3e6b3c46cc66373e34b4f54f980268

416
ui/src/app/billing/page.tsx Normal file
View file

@ -0,0 +1,416 @@
"use client";
import {
ChevronLeft,
ChevronRight,
CircleDollarSign,
CreditCard,
RefreshCw,
} from "lucide-react";
import Link from "next/link";
import { useRouter, useSearchParams } from "next/navigation";
import { useCallback, useEffect, useMemo, useState } from "react";
import { toast } from "sonner";
import { createMpsCreditPurchaseUrlApiV1OrganizationsUsageMpsCreditsPurchaseUrlPost, getBillingCreditsApiV1OrganizationsBillingCreditsGet } from "@/client/sdk.gen";
import type { MpsBillingCreditsResponse, MpsCreditLedgerEntryResponse } from "@/client/types.gen";
import { Badge } from "@/components/ui/badge";
import { Button } from "@/components/ui/button";
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card";
import { Progress } from "@/components/ui/progress";
import { Skeleton } from "@/components/ui/skeleton";
import {
Table,
TableBody,
TableCell,
TableHead,
TableHeader,
TableRow,
} from "@/components/ui/table";
import { useAppConfig } from "@/context/AppConfigContext";
import { useAuth } from "@/lib/auth";
const LEDGER_PAGE_SIZE = 50;
const formatCredits = (value: number | null | undefined) => (
(value ?? 0).toLocaleString(undefined, {
maximumFractionDigits: 2,
minimumFractionDigits: 0,
})
);
const formatAmount = (amountMinor?: number | null, currency?: string | null) => {
if (amountMinor == null) {
return "-";
}
return new Intl.NumberFormat(undefined, {
style: "currency",
currency: currency || "USD",
}).format(amountMinor / 100);
};
const formatDate = (value: string) => (
new Date(value).toLocaleString(undefined, {
month: "short",
day: "numeric",
year: "numeric",
hour: "2-digit",
minute: "2-digit",
})
);
const metricLabels: Record<string, string> = {
voice_minutes: "Voice usage",
platform_usage: "Platform usage",
};
const formatTitleCase = (value: string | null | undefined) => (
value ? value.replaceAll("_", " ").replace(/\b\w/g, (letter) => letter.toUpperCase()) : "-"
);
const getLedgerEntryLabel = (entry: MpsCreditLedgerEntryResponse) => {
if (entry.metric_code) {
return metricLabels[entry.metric_code] ?? formatTitleCase(entry.metric_code);
}
if (entry.entry_type === "grant") {
return "Credit grant";
}
if (entry.entry_type === "purchase") {
return "Credit purchase";
}
return formatTitleCase(entry.entry_type);
};
const formatBillableQuantity = (entry: MpsCreditLedgerEntryResponse) => {
if (entry.billable_quantity == null || !entry.quantity_unit) {
return null;
}
const unit = entry.quantity_unit === "minute" ? "min" : entry.quantity_unit;
return `${formatCredits(entry.billable_quantity)} ${unit}`;
};
const getRunHref = (entry: MpsCreditLedgerEntryResponse) => {
if (!entry.workflow_id || !entry.workflow_run_id) {
return null;
}
return `/workflow/${entry.workflow_id}/run/${entry.workflow_run_id}`;
};
const getPageFromSearchParams = (
searchParams: { get: (name: string) => string | null },
) => {
const pageParam = searchParams.get("page");
const page = pageParam ? Number.parseInt(pageParam, 10) : 1;
return Number.isFinite(page) && page > 0 ? page : 1;
};
export default function BillingPage() {
const router = useRouter();
const searchParams = useSearchParams();
const auth = useAuth();
const { config } = useAppConfig();
const [credits, setCredits] = useState<MpsBillingCreditsResponse | null>(null);
const [loading, setLoading] = useState(true);
const [refreshing, setRefreshing] = useState(false);
const [purchasing, setPurchasing] = useState(false);
const [currentPage, setCurrentPage] = useState(
() => getPageFromSearchParams(searchParams),
);
const isBillingV2 = credits?.billing_version === "v2";
const canPurchaseCredits = isBillingV2 && config?.deploymentMode !== "oss";
const totalQuota = credits?.total_quota ?? 0;
const remainingCredits = credits?.remaining_credits ?? 0;
const usedCredits = credits?.total_credits_used ?? 0;
const usagePercent = totalQuota > 0 ? Math.min(100, Math.round((usedCredits / totalQuota) * 100)) : 0;
const ledgerEntries = useMemo(() => credits?.ledger_entries ?? [], [credits?.ledger_entries]);
const ledgerPage = credits?.page ?? currentPage;
const ledgerTotalCount = credits?.total_count ?? ledgerEntries.length;
const ledgerTotalPages = credits?.total_pages ?? 0;
const fetchCredits = useCallback(async (
page: number,
{ silent = false }: { silent?: boolean } = {},
) => {
if (auth.loading) {
return;
}
if (!auth.isAuthenticated) {
setLoading(false);
return;
}
if (silent) {
setRefreshing(true);
} else {
setLoading(true);
}
try {
const response = await getBillingCreditsApiV1OrganizationsBillingCreditsGet({
query: { page, limit: LEDGER_PAGE_SIZE },
});
if (response.error) {
throw new Error("Failed to fetch billing credits");
}
setCredits(response.data ?? null);
} catch (error) {
console.error("Failed to fetch billing credits:", error);
toast.error("Failed to fetch billing credits");
} finally {
setLoading(false);
setRefreshing(false);
}
}, [auth.isAuthenticated, auth.loading]);
useEffect(() => {
const nextPage = getPageFromSearchParams(searchParams);
setCurrentPage((previousPage) => (
previousPage === nextPage ? previousPage : nextPage
));
}, [searchParams]);
useEffect(() => {
fetchCredits(currentPage);
}, [currentPage, fetchCredits]);
const handleRefresh = () => {
fetchCredits(currentPage, { silent: true });
};
const updateUrlPage = useCallback((page: number) => {
const newParams = new URLSearchParams(searchParams.toString());
if (page > 1) {
newParams.set("page", page.toString());
} else {
newParams.delete("page");
}
const queryString = newParams.toString();
router.push(queryString ? `/billing?${queryString}` : "/billing");
}, [router, searchParams]);
const handlePageChange = (page: number) => {
const nextPage = Math.max(1, page);
setCurrentPage(nextPage);
updateUrlPage(nextPage);
};
const handlePurchaseCredits = async () => {
if (!canPurchaseCredits) {
return;
}
setPurchasing(true);
try {
const response = await createMpsCreditPurchaseUrlApiV1OrganizationsUsageMpsCreditsPurchaseUrlPost();
const checkoutUrl = response.data?.checkout_url;
if (!checkoutUrl) {
throw new Error("Missing checkout URL");
}
window.location.href = checkoutUrl;
} catch (error) {
console.error("Failed to create credit purchase URL:", error);
toast.error("Failed to open checkout");
setPurchasing(false);
}
};
if (loading) {
return (
<div className="container mx-auto p-6 space-y-6">
<div className="space-y-2">
<Skeleton className="h-9 w-40" />
<Skeleton className="h-5 w-96 max-w-full" />
</div>
<div className="grid gap-4 md:grid-cols-2">
<Skeleton className="h-36 rounded-lg" />
<Skeleton className="h-36 rounded-lg" />
</div>
<Skeleton className="h-80 rounded-lg" />
</div>
);
}
return (
<div className="container mx-auto p-6 space-y-6">
<div className="flex flex-col gap-4 md:flex-row md:items-start md:justify-between">
<div>
<h1 className="text-3xl font-bold mb-2">Billing</h1>
<p className="text-muted-foreground">
Credits, balance, and account usage for your organization.
</p>
</div>
<div className="flex items-center gap-2">
<Button variant="outline" onClick={handleRefresh} disabled={refreshing}>
<RefreshCw className={`h-4 w-4 mr-2 ${refreshing ? "animate-spin" : ""}`} />
Refresh
</Button>
{canPurchaseCredits && (
<Button onClick={handlePurchaseCredits} disabled={purchasing}>
<CreditCard className="h-4 w-4 mr-2" />
{purchasing ? "Opening..." : "Add Credits"}
</Button>
)}
</div>
</div>
<div className="grid gap-4 md:grid-cols-2">
<Card>
<CardHeader className="pb-2">
<CardDescription>{isBillingV2 ? "Credit balance" : "Credits remaining"}</CardDescription>
<CardTitle className="flex items-center gap-2 text-3xl">
<CircleDollarSign className="h-6 w-6 text-muted-foreground" />
{formatCredits(remainingCredits)}
</CardTitle>
</CardHeader>
<CardContent>
<p className="text-sm text-muted-foreground">1 credit = 1 cent</p>
</CardContent>
</Card>
<Card>
<CardHeader className="pb-2">
<CardDescription>Credits used</CardDescription>
<CardTitle className="text-3xl">{formatCredits(usedCredits)}</CardTitle>
</CardHeader>
<CardContent>
<p className="text-sm text-muted-foreground">
{isBillingV2 ? "Total ledger debits" : "Current allocation usage"}
</p>
</CardContent>
</Card>
</div>
{isBillingV2 ? (
<Card>
<CardHeader>
<CardTitle>Credit Ledger</CardTitle>
<CardDescription>Recent grants, purchases, and usage debits.</CardDescription>
</CardHeader>
<CardContent>
{ledgerEntries.length > 0 ? (
<div className="bg-card border rounded-lg overflow-x-auto shadow-sm">
<Table>
<TableHeader>
<TableRow className="bg-muted/50">
<TableHead>Date</TableHead>
<TableHead>Activity</TableHead>
<TableHead>Origin</TableHead>
<TableHead>Run</TableHead>
<TableHead className="text-right">Delta</TableHead>
<TableHead className="text-right">Balance</TableHead>
<TableHead className="text-right">Amount</TableHead>
</TableRow>
</TableHeader>
<TableBody>
{ledgerEntries.map((entry) => {
const delta = entry.credits_delta ?? 0;
const runHref = getRunHref(entry);
const billableQuantity = formatBillableQuantity(entry);
return (
<TableRow key={entry.id}>
<TableCell>{formatDate(entry.created_at)}</TableCell>
<TableCell>
<div className="flex flex-col gap-1">
<span className="font-medium">{getLedgerEntryLabel(entry)}</span>
{billableQuantity && (
<span className="text-xs text-muted-foreground">{billableQuantity}</span>
)}
</div>
</TableCell>
<TableCell>
{entry.origin ? (
<Badge variant="secondary">{formatTitleCase(entry.origin)}</Badge>
) : (
"-"
)}
</TableCell>
<TableCell>
{entry.workflow_run_id ? (
runHref ? (
<Link className="font-medium text-primary hover:underline" href={runHref}>
#{entry.workflow_run_id}
</Link>
) : (
<span>#{entry.workflow_run_id}</span>
)
) : (
"-"
)}
</TableCell>
<TableCell className={`text-right font-medium ${delta >= 0 ? "text-green-600" : "text-destructive"}`}>
{delta >= 0 ? "+" : ""}
{formatCredits(delta)}
</TableCell>
<TableCell className="text-right">{formatCredits(entry.balance_after)}</TableCell>
<TableCell className="text-right">
{formatAmount(entry.amount_minor, entry.amount_currency)}
</TableCell>
</TableRow>
);
})}
</TableBody>
</Table>
</div>
) : (
<div className="rounded-lg border border-dashed p-8 text-center text-muted-foreground">
No ledger entries yet
</div>
)}
{ledgerTotalPages > 1 && (
<div className="flex items-center justify-between mt-6">
<p className="text-sm text-muted-foreground">
Page {ledgerPage} of {ledgerTotalPages} ({ledgerTotalCount} total entries)
</p>
<div className="flex gap-2">
<Button
variant="outline"
size="sm"
onClick={() => handlePageChange(ledgerPage - 1)}
disabled={ledgerPage <= 1 || loading || refreshing}
>
<ChevronLeft className="h-4 w-4" />
Previous
</Button>
<Button
variant="outline"
size="sm"
onClick={() => handlePageChange(ledgerPage + 1)}
disabled={ledgerPage >= ledgerTotalPages || loading || refreshing}
>
Next
<ChevronRight className="h-4 w-4" />
</Button>
</div>
</div>
)}
</CardContent>
</Card>
) : (
<Card>
<CardHeader>
<CardTitle>Credit Usage</CardTitle>
</CardHeader>
<CardContent className="space-y-4">
<Progress value={usagePercent} />
<div className="flex justify-between text-sm text-muted-foreground">
<span>{usagePercent}% used</span>
<span>{formatCredits(remainingCredits)} of {formatCredits(totalQuota)} remaining</span>
</div>
</CardContent>
</Card>
)}
</div>
);
}

View file

@ -12,8 +12,8 @@ import SpinLoader from "@/components/SpinLoader";
import { Toaster } from "@/components/ui/sonner";
import { AppConfigProvider } from "@/context/AppConfigContext";
import { OnboardingProvider } from "@/context/OnboardingContext";
import { OrgConfigProvider } from "@/context/OrgConfigContext";
import { TelephonyConfigWarningsProvider } from "@/context/TelephonyConfigWarningsContext";
import { UserConfigProvider } from "@/context/UserConfigContext";
import { AuthProvider } from "@/lib/auth";
@ -65,7 +65,7 @@ export default function RootLayout({
<AuthProvider>
<AppConfigProvider>
<Suspense fallback={<SpinLoader />}>
<UserConfigProvider>
<OrgConfigProvider>
<TelephonyConfigWarningsProvider>
<OnboardingProvider>
<PostHogIdentify />
@ -76,7 +76,7 @@ export default function RootLayout({
<ChatwootWidget />
</OnboardingProvider>
</TelephonyConfigWarningsProvider>
</UserConfigProvider>
</OrgConfigProvider>
</Suspense>
</AppConfigProvider>
</AuthProvider>

View file

@ -2,7 +2,7 @@
import { addDays, format, subDays } from 'date-fns';
import { Calendar, ChevronLeft, ChevronRight, Download } from 'lucide-react';
import { useEffect,useState } from 'react';
import { useEffect, useState } from 'react';
import {
getDailyReportApiV1OrganizationsReportsDailyGet,
@ -201,7 +201,9 @@ export default function ReportsPage() {
<div className="container mx-auto p-6 space-y-6">
{/* Header */}
<div className="flex flex-col sm:flex-row justify-between items-start sm:items-center gap-4">
<h1 className="text-3xl font-bold">Daily Reports</h1>
<div className="space-y-2">
<h1 className="text-3xl font-bold">Daily Reports</h1>
</div>
{/* Date Navigation & Workflow Selector */}
<div className="flex flex-col sm:flex-row gap-4 items-start sm:items-center">

View file

@ -6,8 +6,8 @@ import { useCallback, useEffect, useId, useState } from 'react';
import TimezoneSelect, { type ITimezoneOption } from 'react-timezone-select';
import { toast } from 'sonner';
import { downloadUsageRunsReportApiV1OrganizationsUsageRunsReportGet, getDailyUsageBreakdownApiV1OrganizationsUsageDailyBreakdownGet, getMpsCreditsApiV1OrganizationsUsageMpsCreditsGet, getPreferencesApiV1OrganizationsPreferencesGet, getUsageHistoryApiV1OrganizationsUsageRunsGet, savePreferencesApiV1OrganizationsPreferencesPut } from '@/client/sdk.gen';
import type { DailyUsageBreakdownResponse, MpsCreditsResponse, OrganizationPreferences, UsageHistoryResponse, WorkflowRunUsageResponse } from '@/client/types.gen';
import { downloadUsageRunsReportApiV1OrganizationsUsageRunsReportGet, getDailyUsageBreakdownApiV1OrganizationsUsageDailyBreakdownGet, getPreferencesApiV1OrganizationsPreferencesGet, getUsageHistoryApiV1OrganizationsUsageRunsGet, savePreferencesApiV1OrganizationsPreferencesPut } from '@/client/sdk.gen';
import type { DailyUsageBreakdownResponse, OrganizationPreferences, UsageHistoryResponse, WorkflowRunUsageResponse } from '@/client/types.gen';
import { CallTypeCell } from '@/components/CallTypeCell';
import { DailyUsageTable } from '@/components/DailyUsageTable';
import { FilterBuilder } from '@/components/filters/FilterBuilder';
@ -15,7 +15,6 @@ import { MediaPreviewButton, MediaPreviewDialog } from '@/components/MediaPrevie
import { Badge } from '@/components/ui/badge';
import { Button } from '@/components/ui/button';
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card';
import { Progress } from '@/components/ui/progress';
import {
Table,
TableBody,
@ -39,10 +38,6 @@ export default function UsagePage() {
const { organizationPricing } = useUserConfig();
const auth = useAuth();
// MPS credits state
const [mpsCredits, setMpsCredits] = useState<MpsCreditsResponse | null>(null);
const [isLoadingCredits, setIsLoadingCredits] = useState(true);
// Usage history state
const [usageHistory, setUsageHistory] = useState<UsageHistoryResponse | null>(null);
const [isLoadingHistory, setIsLoadingHistory] = useState(false);
@ -78,21 +73,6 @@ export default function UsagePage() {
const [preferencesLoading, setPreferencesLoading] = useState(true);
const timezoneSelectId = useId(); // Stable ID for react-select to prevent hydration mismatch
// Fetch MPS credits
const fetchMpsCredits = useCallback(async () => {
if (!auth.isAuthenticated) return;
try {
const response = await getMpsCreditsApiV1OrganizationsUsageMpsCreditsGet();
if (response.data) {
setMpsCredits(response.data);
}
} catch (error) {
console.error('Failed to fetch MPS credits:', error);
} finally {
setIsLoadingCredits(false);
}
}, [auth.isAuthenticated]);
// Translate the FilterBuilder state into the query-param shape the
// backend expects. Shared between the listing fetch and the CSV export
// so they stay in lockstep.
@ -251,10 +231,9 @@ export default function UsagePage() {
// Initial load - fetch when auth becomes available
useEffect(() => {
if (auth.isAuthenticated) {
fetchMpsCredits();
fetchUsageHistory(currentPage, appliedFilters);
}
}, [auth.isAuthenticated, currentPage, appliedFilters, fetchUsageHistory, fetchMpsCredits]);
}, [auth.isAuthenticated, currentPage, appliedFilters, fetchUsageHistory]);
// Fetch daily usage when organizationPricing becomes available
useEffect(() => {
@ -428,46 +407,6 @@ export default function UsagePage() {
</div>
</div>
{/* MPS Credits Card */}
<Card className="mb-6">
<CardHeader>
<CardTitle>Dograh Model Credits</CardTitle>
<CardDescription>
These track usage of Dograh models using Dograh Service Keys.
</CardDescription>
</CardHeader>
<CardContent>
{isLoadingCredits ? (
<div className="animate-pulse space-y-4">
<div className="h-4 bg-muted rounded w-1/4"></div>
<div className="h-8 bg-muted rounded"></div>
<div className="h-4 bg-muted rounded w-1/3"></div>
</div>
) : mpsCredits ? (
<div className="space-y-4">
<div className="flex justify-between items-baseline">
<div>
<p className="text-2xl font-bold">
{mpsCredits.total_credits_used.toFixed(2)} <span className="text-lg font-normal text-muted-foreground">/ {mpsCredits.total_quota.toFixed(2)}</span>
</p>
<p className="text-sm text-muted-foreground">Credits Used</p>
</div>
<div className="text-right">
<p className="text-lg font-semibold">{mpsCredits.remaining_credits.toFixed(2)}</p>
<p className="text-sm text-muted-foreground">Remaining</p>
</div>
</div>
{mpsCredits.total_quota > 0 && (
<Progress value={(mpsCredits.total_credits_used / mpsCredits.total_quota) * 100} className="h-3" />
)}
</div>
) : (
<p className="text-muted-foreground">No Dograh service keys configured. Set up a service key in your model configuration to see usage.</p>
)}
</CardContent>
</Card>
{/* Daily Usage Table - Only for paid organizations */}
{organizationPricing?.price_per_second_usd && (
<div className="mb-6">
@ -535,9 +474,9 @@ export default function UsagePage() {
<TableHead className="font-semibold">Disposition</TableHead>
<TableHead className="font-semibold">Date</TableHead>
<TableHead className="font-semibold text-right">Duration</TableHead>
<TableHead className="font-semibold text-right">
{organizationPricing?.price_per_second_usd ? 'Cost (USD)' : 'Tokens'}
</TableHead>
{organizationPricing?.price_per_second_usd && (
<TableHead className="font-semibold text-right">Cost (USD)</TableHead>
)}
<TableHead className="font-semibold">Actions</TableHead>
</TableRow>
</TableHeader>
@ -574,12 +513,14 @@ export default function UsagePage() {
<TableCell className="text-right">
{formatDuration(run.call_duration_seconds)}
</TableCell>
<TableCell className="text-right font-medium">
{organizationPricing?.price_per_second_usd && run.charge_usd !== undefined && run.charge_usd !== null
? `$${run.charge_usd.toFixed(2)}`
: run.dograh_token_usage.toLocaleString()
}
</TableCell>
{organizationPricing?.price_per_second_usd && (
<TableCell className="text-right font-medium">
{run.charge_usd !== undefined && run.charge_usd !== null
? `$${run.charge_usd.toFixed(2)}`
: '-'
}
</TableCell>
)}
<TableCell>
<MediaPreviewButton
recordingUrl={run.recording_url}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -1642,22 +1642,6 @@ export type CurrentUsageResponse = {
* Used Dograh Tokens
*/
used_dograh_tokens: number;
/**
* Quota Dograh Tokens
*/
quota_dograh_tokens: number;
/**
* Percentage Used
*/
percentage_used: number;
/**
* Next Refresh Date
*/
next_refresh_date: string;
/**
* Quota Enabled
*/
quota_enabled: boolean;
/**
* Total Duration Seconds
*/
@ -1666,10 +1650,6 @@ export type CurrentUsageResponse = {
* Used Amount Usd
*/
used_amount_usd?: number | null;
/**
* Quota Amount Usd
*/
quota_amount_usd?: number | null;
/**
* Currency
*/
@ -3107,6 +3087,165 @@ export type LoginRequest = {
password: string;
};
/**
* MPSBillingAccountResponse
*/
export type MpsBillingAccountResponse = {
/**
* Id
*/
id: number;
/**
* Organization Id
*/
organization_id: number;
/**
* Billing Mode
*/
billing_mode: string;
/**
* Cached Balance Credits
*/
cached_balance_credits: number;
/**
* Currency
*/
currency: string;
};
/**
* MPSBillingCreditsResponse
*/
export type MpsBillingCreditsResponse = {
/**
* Billing Version
*/
billing_version: 'legacy' | 'v2';
/**
* Total Credits Used
*/
total_credits_used?: number;
/**
* Remaining Credits
*/
remaining_credits?: number;
/**
* Total Quota
*/
total_quota?: number;
account?: MpsBillingAccountResponse | null;
/**
* Ledger Entries
*/
ledger_entries?: Array<MpsCreditLedgerEntryResponse>;
/**
* Total Count
*/
total_count?: number;
/**
* Page
*/
page?: number;
/**
* Limit
*/
limit?: number;
/**
* Total Pages
*/
total_pages?: number;
};
/**
* MPSCreditLedgerEntryResponse
*/
export type MpsCreditLedgerEntryResponse = {
/**
* Id
*/
id: number;
/**
* Entry Type
*/
entry_type: string;
/**
* Origin
*/
origin?: string | null;
/**
* Credits Delta
*/
credits_delta: number;
/**
* Balance After
*/
balance_after: number;
/**
* Amount Minor
*/
amount_minor?: number | null;
/**
* Amount Currency
*/
amount_currency?: string | null;
/**
* Payment Order Id
*/
payment_order_id?: number | null;
/**
* Metric Code
*/
metric_code?: string | null;
/**
* Correlation Id
*/
correlation_id?: string | null;
/**
* Aggregation Key
*/
aggregation_key?: string | null;
/**
* Usage Event Id
*/
usage_event_id?: number | null;
/**
* Workflow Run Id
*/
workflow_run_id?: number | null;
/**
* Workflow Id
*/
workflow_id?: number | null;
/**
* Billable Quantity
*/
billable_quantity?: number | null;
/**
* Quantity Unit
*/
quantity_unit?: string | null;
/**
* Metadata
*/
metadata?: {
[key: string]: unknown;
};
/**
* Created At
*/
created_at: string;
};
/**
* MPSCreditPurchaseUrlResponse
*/
export type MpsCreditPurchaseUrlResponse = {
/**
* Checkout Url
*/
checkout_url: string;
};
/**
* MPSCreditsResponse
*/
@ -3618,6 +3757,43 @@ export type OrganizationAiModelConfigurationV2 = {
byok?: ByokaiModelConfiguration | null;
};
/**
* OrganizationContextResponse
*/
export type OrganizationContextResponse = {
/**
* Organization Id
*/
organization_id?: number | null;
/**
* Organization Provider Id
*/
organization_provider_id?: string | null;
model_services: OrganizationModelServicesContext;
};
/**
* OrganizationModelServicesContext
*/
export type OrganizationModelServicesContext = {
/**
* Config Source
*/
config_source: 'organization_v2' | 'legacy_user_v1' | 'empty';
/**
* Has Model Configuration V2
*/
has_model_configuration_v2: boolean;
/**
* Managed Service Version
*/
managed_service_version?: number | null;
/**
* Uses Managed Service V2
*/
uses_managed_service_v2: boolean;
};
/**
* OrganizationPreferences
*/
@ -9750,6 +9926,45 @@ export type UnarchiveToolApiV1ToolsToolUuidUnarchivePostResponses = {
export type UnarchiveToolApiV1ToolsToolUuidUnarchivePostResponse = UnarchiveToolApiV1ToolsToolUuidUnarchivePostResponses[keyof UnarchiveToolApiV1ToolsToolUuidUnarchivePostResponses];
export type GetCurrentOrganizationContextApiV1OrganizationsContextGetData = {
body?: never;
headers?: {
/**
* Authorization
*/
authorization?: string | null;
/**
* X-Api-Key
*/
'X-API-Key'?: string | null;
};
path?: never;
query?: never;
url: '/api/v1/organizations/context';
};
export type GetCurrentOrganizationContextApiV1OrganizationsContextGetErrors = {
/**
* Not found
*/
404: unknown;
/**
* Validation Error
*/
422: HttpValidationError;
};
export type GetCurrentOrganizationContextApiV1OrganizationsContextGetError = GetCurrentOrganizationContextApiV1OrganizationsContextGetErrors[keyof GetCurrentOrganizationContextApiV1OrganizationsContextGetErrors];
export type GetCurrentOrganizationContextApiV1OrganizationsContextGetResponses = {
/**
* Successful Response
*/
200: OrganizationContextResponse;
};
export type GetCurrentOrganizationContextApiV1OrganizationsContextGetResponse = GetCurrentOrganizationContextApiV1OrganizationsContextGetResponses[keyof GetCurrentOrganizationContextApiV1OrganizationsContextGetResponses];
export type GetTelephonyProvidersMetadataApiV1OrganizationsTelephonyProvidersMetadataGetData = {
body?: never;
headers?: {
@ -11269,6 +11484,93 @@ export type GetMpsCreditsApiV1OrganizationsUsageMpsCreditsGetResponses = {
export type GetMpsCreditsApiV1OrganizationsUsageMpsCreditsGetResponse = GetMpsCreditsApiV1OrganizationsUsageMpsCreditsGetResponses[keyof GetMpsCreditsApiV1OrganizationsUsageMpsCreditsGetResponses];
export type GetBillingCreditsApiV1OrganizationsBillingCreditsGetData = {
body?: never;
headers?: {
/**
* Authorization
*/
authorization?: string | null;
/**
* X-Api-Key
*/
'X-API-Key'?: string | null;
};
path?: never;
query?: {
/**
* Page
*/
page?: number;
/**
* Limit
*/
limit?: number;
};
url: '/api/v1/organizations/billing/credits';
};
export type GetBillingCreditsApiV1OrganizationsBillingCreditsGetErrors = {
/**
* Not found
*/
404: unknown;
/**
* Validation Error
*/
422: HttpValidationError;
};
export type GetBillingCreditsApiV1OrganizationsBillingCreditsGetError = GetBillingCreditsApiV1OrganizationsBillingCreditsGetErrors[keyof GetBillingCreditsApiV1OrganizationsBillingCreditsGetErrors];
export type GetBillingCreditsApiV1OrganizationsBillingCreditsGetResponses = {
/**
* Successful Response
*/
200: MpsBillingCreditsResponse;
};
export type GetBillingCreditsApiV1OrganizationsBillingCreditsGetResponse = GetBillingCreditsApiV1OrganizationsBillingCreditsGetResponses[keyof GetBillingCreditsApiV1OrganizationsBillingCreditsGetResponses];
export type CreateMpsCreditPurchaseUrlApiV1OrganizationsUsageMpsCreditsPurchaseUrlPostData = {
body?: never;
headers?: {
/**
* Authorization
*/
authorization?: string | null;
/**
* X-Api-Key
*/
'X-API-Key'?: string | null;
};
path?: never;
query?: never;
url: '/api/v1/organizations/usage/mps-credits/purchase-url';
};
export type CreateMpsCreditPurchaseUrlApiV1OrganizationsUsageMpsCreditsPurchaseUrlPostErrors = {
/**
* Not found
*/
404: unknown;
/**
* Validation Error
*/
422: HttpValidationError;
};
export type CreateMpsCreditPurchaseUrlApiV1OrganizationsUsageMpsCreditsPurchaseUrlPostError = CreateMpsCreditPurchaseUrlApiV1OrganizationsUsageMpsCreditsPurchaseUrlPostErrors[keyof CreateMpsCreditPurchaseUrlApiV1OrganizationsUsageMpsCreditsPurchaseUrlPostErrors];
export type CreateMpsCreditPurchaseUrlApiV1OrganizationsUsageMpsCreditsPurchaseUrlPostResponses = {
/**
* Successful Response
*/
200: MpsCreditPurchaseUrlResponse;
};
export type CreateMpsCreditPurchaseUrlApiV1OrganizationsUsageMpsCreditsPurchaseUrlPostResponse = CreateMpsCreditPurchaseUrlApiV1OrganizationsUsageMpsCreditsPurchaseUrlPostResponses[keyof CreateMpsCreditPurchaseUrlApiV1OrganizationsUsageMpsCreditsPurchaseUrlPostResponses];
export type GetUsageHistoryApiV1OrganizationsUsageRunsGetData = {
body?: never;
headers?: {

View file

@ -17,7 +17,7 @@ import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from "@
import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs";
import { LANGUAGE_DISPLAY_NAMES } from "@/constants/languages";
type ModelMode = "dograh" | "byok";
type ModelMode = "realtime" | "dograh" | "byok";
interface DograhDefaults {
voices: string[];
@ -125,24 +125,35 @@ function effectiveConfigToLegacyShape(config: Record<string, unknown> | null): R
};
}
function emptyByokInitialConfig(): Record<string, unknown> {
function emptyByokInitialConfig(isRealtime: boolean): Record<string, unknown> {
return {
is_realtime: false,
is_realtime: isRealtime,
};
}
// The v2 editor surfaces realtime ("Speech to Speech") and pipeline (BYOK) as
// separate tabs, so each tab gets its own initial config. A tab is pre-filled
// only when the saved (or effective) configuration matches that tab's mode;
// otherwise it starts empty so the other tab's data does not leak across.
function getByokInitialConfig(
configuration: Record<string, unknown> | null,
effectiveConfiguration: Record<string, unknown> | null,
wantRealtime: boolean,
): Record<string, unknown> {
const byokConfiguration = byokConfigToLegacyShape(configuration);
if (byokConfiguration) return byokConfiguration;
const matchesTab = (config: Record<string, unknown> | null) =>
config ? Boolean(config.is_realtime) === wantRealtime : false;
if (configuration?.mode === "dograh" || isDograhEffectiveConfig(effectiveConfiguration)) {
return emptyByokInitialConfig();
const byokConfiguration = byokConfigToLegacyShape(configuration);
if (byokConfiguration) {
return matchesTab(byokConfiguration) ? byokConfiguration : emptyByokInitialConfig(wantRealtime);
}
return effectiveConfigToLegacyShape(effectiveConfiguration) || emptyByokInitialConfig();
if (configuration?.mode === "dograh" || isDograhEffectiveConfig(effectiveConfiguration)) {
return emptyByokInitialConfig(wantRealtime);
}
const effective = effectiveConfigToLegacyShape(effectiveConfiguration);
return matchesTab(effective) ? (effective as Record<string, unknown>) : emptyByokInitialConfig(wantRealtime);
}
function buildDograhState(
@ -185,10 +196,12 @@ function preferredMode(
configuration: Record<string, unknown> | null,
effectiveConfiguration: Record<string, unknown> | null,
): ModelMode {
if (configuration?.mode === "dograh" || configuration?.mode === "byok") {
return configuration.mode;
if (configuration?.mode === "dograh") return "dograh";
if (configuration?.mode === "byok") {
return asRecord(configuration.byok)?.mode === "realtime" ? "realtime" : "byok";
}
return isDograhEffectiveConfig(effectiveConfiguration) ? "dograh" : "byok";
if (isDograhEffectiveConfig(effectiveConfiguration)) return "dograh";
return Boolean(effectiveConfiguration?.is_realtime) ? "realtime" : "byok";
}
function hasRequiredApiKey(
@ -249,7 +262,8 @@ export function AIModelConfigurationV2Editor({
speed: defaults.dograh.defaults.speed,
language: defaults.dograh.defaults.language,
}));
const [byokInitialConfig, setByokInitialConfig] = useState<Record<string, unknown> | null>(null);
const [realtimeInitialConfig, setRealtimeInitialConfig] = useState<Record<string, unknown> | null>(null);
const [pipelineInitialConfig, setPipelineInitialConfig] = useState<Record<string, unknown> | null>(null);
const [isSavingDograh, setIsSavingDograh] = useState(false);
const [error, setError] = useState<string | null>(null);
@ -258,7 +272,8 @@ export function AIModelConfigurationV2Editor({
const rawEffectiveConfiguration = asRecord(effectiveConfiguration);
setMode(preferredMode(rawConfiguration, rawEffectiveConfiguration));
setDograh(buildDograhState(defaults, rawConfiguration, rawEffectiveConfiguration));
setByokInitialConfig(getByokInitialConfig(rawConfiguration, rawEffectiveConfiguration));
setRealtimeInitialConfig(getByokInitialConfig(rawConfiguration, rawEffectiveConfiguration, true));
setPipelineInitialConfig(getByokInitialConfig(rawConfiguration, rawEffectiveConfiguration, false));
}, [configuration, defaults, effectiveConfiguration]);
const saveDograhConfiguration = async () => {
@ -322,28 +337,30 @@ export function AIModelConfigurationV2Editor({
)}
<Tabs value={mode} onValueChange={(value) => setMode(value as ModelMode)} className="space-y-6">
<TabsList className="grid w-full grid-cols-2">
<TabsList className="grid w-full grid-cols-3">
<TabsTrigger value="realtime">Speech to Speech</TabsTrigger>
<TabsTrigger value="dograh">Dograh</TabsTrigger>
<TabsTrigger value="byok">BYOK</TabsTrigger>
</TabsList>
<TabsContent value="realtime" className="mt-0">
<p className="mb-4 text-sm text-muted-foreground">
A single speech-to-speech model handles the conversation in realtime (no separate transcriber or voice). An LLM is still required for variable extraction and QA.
</p>
<ServiceConfigurationForm
key={`realtime-${JSON.stringify(realtimeInitialConfig)}`}
mode="global"
forceRealtime
configurationDefaults={defaultsForByok}
initialConfig={realtimeInitialConfig}
submitLabel={submitLabel}
onSave={saveByokConfiguration}
/>
</TabsContent>
<TabsContent value="dograh" className="mt-0">
<div className="rounded-lg border p-5">
<div className="grid gap-4 sm:grid-cols-2">
<div className="space-y-2 sm:col-span-2">
<Label htmlFor="dograh-api-key">API Key</Label>
<div className="relative">
<KeyRound className="pointer-events-none absolute left-3 top-1/2 h-4 w-4 -translate-y-1/2 text-muted-foreground" />
<Input
id="dograh-api-key"
className="pl-9"
value={dograh.api_key}
onChange={(event) => setDograh({ ...dograh, api_key: event.target.value })}
placeholder="Enter API key"
/>
</div>
</div>
<div className="space-y-2">
<Label>Voice</Label>
<Select value={dograh.voice} onValueChange={(voice) => setDograh({ ...dograh, voice })}>
@ -394,6 +411,20 @@ export function AIModelConfigurationV2Editor({
</SelectContent>
</Select>
</div>
<div className="space-y-2 sm:col-span-2">
<Label htmlFor="dograh-api-key">API Key</Label>
<div className="relative">
<KeyRound className="pointer-events-none absolute left-3 top-1/2 h-4 w-4 -translate-y-1/2 text-muted-foreground" />
<Input
id="dograh-api-key"
className="pl-9"
value={dograh.api_key}
onChange={(event) => setDograh({ ...dograh, api_key: event.target.value })}
placeholder="Enter API key"
/>
</div>
</div>
</div>
<Button type="button" className="mt-6 w-full" onClick={saveDograhConfiguration} disabled={isSavingDograh}>
@ -405,10 +436,11 @@ export function AIModelConfigurationV2Editor({
<TabsContent value="byok" className="mt-0">
<ServiceConfigurationForm
key={JSON.stringify(byokInitialConfig)}
key={`byok-${JSON.stringify(pipelineInitialConfig)}`}
mode="global"
forceRealtime={false}
configurationDefaults={defaultsForByok}
initialConfig={byokInitialConfig}
initialConfig={pipelineInitialConfig}
submitLabel={submitLabel}
onSave={saveByokConfiguration}
/>

View file

@ -101,6 +101,13 @@ export interface ServiceConfigurationFormProps {
submitLabel?: string;
configurationDefaults?: ServiceConfigurationDefaults | null;
initialConfig?: Record<string, unknown> | null;
/**
* When set, locks the realtime/pipeline mode to this value and hides the
* in-form toggle. The v2 editor uses this to surface realtime
* ("Speech to Speech") and pipeline (BYOK) as separate top-level tabs.
* Leave undefined to keep the user-controllable toggle (legacy + overrides).
*/
forceRealtime?: boolean;
}
function getProviderDisplayName(
@ -130,10 +137,11 @@ export function ServiceConfigurationForm({
submitLabel,
configurationDefaults,
initialConfig,
forceRealtime,
}: ServiceConfigurationFormProps) {
const [apiError, setApiError] = useState<string | null>(null);
const [isSaving, setIsSaving] = useState(false);
const [isRealtime, setIsRealtime] = useState(false);
const [isRealtime, setIsRealtime] = useState(forceRealtime ?? false);
const { userConfig } = useUserConfig();
const [schemas, setSchemas] = useState<Record<ServiceSegment, Record<string, ProviderSchema>>>({
llm: {},
@ -227,9 +235,9 @@ export function ServiceConfigurationForm({
realtime: realtimeSchemas,
});
// Restore realtime toggle
// Restore realtime toggle (skip when the parent locks the mode)
const configData = configSource as Record<string, unknown> | null;
if (configData?.is_realtime) {
if (forceRealtime === undefined && configData?.is_realtime) {
setIsRealtime(true);
}
@ -867,22 +875,24 @@ export function ServiceConfigurationForm({
return (
<form onSubmit={handleSubmit(onSubmit)}>
{/* Realtime toggle */}
<div className="flex items-center justify-between mb-4 p-4 border rounded-lg">
<div>
<Label htmlFor="realtime-toggle" className="text-sm font-medium">
Realtime Mode
</Label>
<p className="text-xs text-muted-foreground mt-0.5">
Uses a single speech-to-speech model (no separate STT/TTS). An LLM is still required for variable extraction and QA.
</p>
{/* Realtime toggle — hidden when the parent locks the mode (v2 tabs) */}
{forceRealtime === undefined && (
<div className="flex items-center justify-between mb-4 p-4 border rounded-lg">
<div>
<Label htmlFor="realtime-toggle" className="text-sm font-medium">
Realtime Mode
</Label>
<p className="text-xs text-muted-foreground mt-0.5">
Uses a single speech-to-speech model (no separate STT/TTS). An LLM is still required for variable extraction and QA.
</p>
</div>
<Switch
id="realtime-toggle"
checked={isRealtime}
onCheckedChange={setIsRealtime}
/>
</div>
<Switch
id="realtime-toggle"
checked={isRealtime}
onCheckedChange={setIsRealtime}
/>
</div>
)}
<Card>
<CardContent className="pt-6">

View file

@ -136,6 +136,11 @@ const NAV_SECTIONS: SidebarNavSection[] = [
url: "/usage",
icon: TrendingUp,
},
{
title: "Billing",
url: "/billing",
icon: CircleDollarSign,
},
{
title: "Reports",
url: "/reports",

View file

@ -0,0 +1,192 @@
'use client';
import { createContext, ReactNode, useCallback, useContext, useEffect, useRef, useState } from 'react';
import { client } from '@/client/client.gen';
import { getCurrentOrganizationContextApiV1OrganizationsContextGet, getUserConfigurationsApiV1UserConfigurationsUserGet, updateUserConfigurationsApiV1UserConfigurationsUserPut } from '@/client/sdk.gen';
import type { OrganizationContextResponse, UserConfigurationRequestResponseSchema } from '@/client/types.gen';
import { setupAuthInterceptor } from '@/lib/apiClient';
import type { AuthUser } from '@/lib/auth';
import { useAuth } from '@/lib/auth';
interface TeamPermission {
id: string;
}
interface OrganizationPricing {
price_per_second_usd: number | null;
currency: string;
billing_enabled: boolean;
}
interface OrgConfigContextType {
orgContext: OrganizationContextResponse | null;
userConfig: UserConfigurationRequestResponseSchema | null;
saveUserConfig: (userConfig: UserConfigurationRequestResponseSchema) => Promise<void>;
loading: boolean;
error: Error | null;
refreshConfig: () => Promise<void>;
permissions: TeamPermission[];
user: AuthUser | null;
organizationPricing: OrganizationPricing | null;
}
const OrgConfigContext = createContext<OrgConfigContextType | null>(null);
const pricingFromUserConfig = (
userConfig: UserConfigurationRequestResponseSchema,
): OrganizationPricing | null => {
if (!userConfig.organization_pricing) {
return null;
}
return {
price_per_second_usd: userConfig.organization_pricing.price_per_second_usd as number | null,
currency: (userConfig.organization_pricing.currency as string) || 'USD',
billing_enabled: (userConfig.organization_pricing.billing_enabled as boolean) || false,
};
};
export function OrgConfigProvider({ children }: { children: ReactNode }) {
const [orgContext, setOrgContext] = useState<OrganizationContextResponse | null>(null);
const [userConfig, setUserConfig] = useState<UserConfigurationRequestResponseSchema | null>(null);
const [loading, setLoading] = useState(true);
const [error, setError] = useState<Error | null>(null);
const [organizationPricing, setOrganizationPricing] = useState<OrganizationPricing | null>(null);
const [permissions, setPermissions] = useState<TeamPermission[]>([]);
const auth = useAuth();
const authRef = useRef(auth);
authRef.current = auth;
const hasFetchedConfig = useRef(false);
const hasFetchedPermissions = useRef(false);
if (!auth.loading && auth.isAuthenticated) {
setupAuthInterceptor(client, auth.getAccessToken);
}
useEffect(() => {
if (auth.loading || hasFetchedPermissions.current) {
return;
}
hasFetchedPermissions.current = true;
const fetchPermissions = async () => {
const currentAuth = authRef.current;
if (currentAuth.provider === 'stack' && currentAuth.getSelectedTeam && currentAuth.listPermissions) {
const selectedTeam = currentAuth.getSelectedTeam();
if (selectedTeam) {
try {
const perms = await currentAuth.listPermissions(selectedTeam);
setPermissions(Array.isArray(perms) ? perms : []);
} catch {
setPermissions([]);
}
} else {
setPermissions([]);
}
} else {
setPermissions([{ id: 'admin' }]);
}
};
fetchPermissions();
}, [auth.loading, auth.provider]);
const fetchConfig = useCallback(async () => {
const currentAuth = authRef.current;
if (!currentAuth.isAuthenticated) {
return;
}
setLoading(true);
try {
const [orgContextResponse, userConfigResponse] = await Promise.all([
getCurrentOrganizationContextApiV1OrganizationsContextGet(),
getUserConfigurationsApiV1UserConfigurationsUserGet(),
]);
if (orgContextResponse.data) {
setOrgContext(orgContextResponse.data);
}
if (userConfigResponse.data) {
setUserConfig(userConfigResponse.data);
setOrganizationPricing(pricingFromUserConfig(userConfigResponse.data));
}
setError(null);
} catch (err) {
setError(err instanceof Error ? err : new Error('Failed to fetch organization configuration'));
} finally {
setLoading(false);
}
}, []);
useEffect(() => {
if (auth.loading || !auth.isAuthenticated || hasFetchedConfig.current) {
return;
}
hasFetchedConfig.current = true;
fetchConfig();
}, [auth.loading, auth.isAuthenticated, fetchConfig]);
const saveUserConfig = useCallback(async (userConfigRequest: UserConfigurationRequestResponseSchema) => {
if (!authRef.current.isAuthenticated) throw new Error('No authentication available');
const response = await updateUserConfigurationsApiV1UserConfigurationsUserPut({
body: {
...userConfig,
...userConfigRequest,
} as UserConfigurationRequestResponseSchema,
});
if (response.error) {
let msg = 'Failed to save user configuration';
const detail = (response.error as unknown as { detail?: string | { errors: { model: string; message: string }[] } }).detail;
if (typeof detail === 'string') {
msg = detail;
} else if (Array.isArray(detail)) {
msg = detail
.map((e: { model: string; message: string }) => `${e.model}: ${e.message}`)
.join('\n');
}
throw new Error(msg);
}
if (response.data) {
setUserConfig(response.data);
setOrganizationPricing(pricingFromUserConfig(response.data));
}
}, [userConfig]);
const refreshConfig = useCallback(async () => {
await fetchConfig();
}, [fetchConfig]);
return (
<OrgConfigContext.Provider
value={{
orgContext,
userConfig,
saveUserConfig,
loading,
error,
refreshConfig,
permissions,
user: auth.user,
organizationPricing,
}}
>
{children}
</OrgConfigContext.Provider>
);
}
export function useOrgConfig() {
const context = useContext(OrgConfigContext);
if (!context) {
throw new Error('useOrgConfig must be used within an OrgConfigProvider');
}
return context;
}

View file

@ -1,205 +1,3 @@
'use client';
import { createContext, ReactNode, useCallback, useContext, useEffect, useRef, useState } from 'react';
import { client } from '@/client/client.gen';
import { getUserConfigurationsApiV1UserConfigurationsUserGet, updateUserConfigurationsApiV1UserConfigurationsUserPut } from '@/client/sdk.gen';
import type { UserConfigurationRequestResponseSchema } from '@/client/types.gen';
import { setupAuthInterceptor } from '@/lib/apiClient';
import type { AuthUser } from '@/lib/auth';
import { useAuth } from '@/lib/auth';
interface TeamPermission {
id: string;
}
interface OrganizationPricing {
price_per_second_usd: number | null;
currency: string;
billing_enabled: boolean;
}
interface UserConfigContextType {
userConfig: UserConfigurationRequestResponseSchema | null;
saveUserConfig: (userConfig: UserConfigurationRequestResponseSchema) => Promise<void>;
loading: boolean;
error: Error | null;
refreshConfig: () => Promise<void>;
permissions: TeamPermission[];
user: AuthUser | null;
organizationPricing: OrganizationPricing | null;
}
const UserConfigContext = createContext<UserConfigContextType | null>(null);
export function UserConfigProvider({ children }: { children: ReactNode }) {
const [userConfig, setUserConfig] = useState<UserConfigurationRequestResponseSchema | null>(null);
const [loading, setLoading] = useState(true);
const [error, setError] = useState<Error | null>(null);
const [organizationPricing, setOrganizationPricing] = useState<OrganizationPricing | null>(null);
const [permissions, setPermissions] = useState<TeamPermission[]>([]);
const auth = useAuth();
// Store auth functions in refs to avoid dependency issues
const authRef = useRef(auth);
authRef.current = auth;
// Track initialization
const hasFetchedConfig = useRef(false);
const hasFetchedPermissions = useRef(false);
// Register the auth interceptor synchronously during render (not in useEffect)
// so it's in place before any child effects fire API calls.
// setupAuthInterceptor is idempotent — safe for strict mode double-renders.
if (!auth.loading && auth.isAuthenticated) {
setupAuthInterceptor(client, auth.getAccessToken);
}
// Fetch permissions once when auth is ready
useEffect(() => {
if (auth.loading || hasFetchedPermissions.current) {
return;
}
hasFetchedPermissions.current = true;
const fetchPermissions = async () => {
const currentAuth = authRef.current;
if (currentAuth.provider === 'stack' && currentAuth.getSelectedTeam && currentAuth.listPermissions) {
const selectedTeam = currentAuth.getSelectedTeam();
if (selectedTeam) {
try {
const perms = await currentAuth.listPermissions(selectedTeam);
setPermissions(Array.isArray(perms) ? perms : []);
} catch {
setPermissions([]);
}
} else {
setPermissions([]);
}
} else {
setPermissions([{ id: 'admin' }]);
}
};
fetchPermissions();
}, [auth.loading, auth.provider]);
// Fetch user config once when auth is ready
useEffect(() => {
if (auth.loading || !auth.isAuthenticated || hasFetchedConfig.current) {
return;
}
hasFetchedConfig.current = true;
const fetchUserConfig = async () => {
setLoading(true);
try {
const response = await getUserConfigurationsApiV1UserConfigurationsUserGet();
if (response.data) {
setUserConfig(response.data);
if (response.data.organization_pricing) {
setOrganizationPricing({
price_per_second_usd: response.data.organization_pricing.price_per_second_usd as number | null,
currency: response.data.organization_pricing.currency as string || 'USD',
billing_enabled: response.data.organization_pricing.billing_enabled as boolean || false
});
} else {
setOrganizationPricing(null);
}
}
setError(null);
} catch (err) {
setError(err instanceof Error ? err : new Error('Failed to fetch user configuration'));
} finally {
setLoading(false);
}
};
fetchUserConfig();
}, [auth.loading, auth.isAuthenticated]);
const saveUserConfig = useCallback(async (userConfigRequest: UserConfigurationRequestResponseSchema) => {
if (!authRef.current.isAuthenticated) throw new Error('No authentication available');
const response = await updateUserConfigurationsApiV1UserConfigurationsUserPut({
body: {
...userConfig,
...userConfigRequest
} as UserConfigurationRequestResponseSchema,
});
if (response.error) {
let msg = 'Failed to save user configuration';
const detail = (response.error as unknown as { detail?: string | { errors: { model: string; message: string }[] } }).detail;
if (typeof detail === 'string') {
msg = detail;
} else if (Array.isArray(detail)) {
msg = detail
.map((e: { model: string; message: string }) => `${e.model}: ${e.message}`)
.join('\n');
}
throw new Error(msg);
}
setUserConfig(response.data!);
if (response.data?.organization_pricing) {
setOrganizationPricing({
price_per_second_usd: response.data.organization_pricing.price_per_second_usd as number | null,
currency: response.data.organization_pricing.currency as string || 'USD',
billing_enabled: response.data.organization_pricing.billing_enabled as boolean || false
});
}
}, [userConfig]);
const refreshConfig = useCallback(async () => {
const currentAuth = authRef.current;
if (!currentAuth.isAuthenticated) return;
setLoading(true);
try {
const response = await getUserConfigurationsApiV1UserConfigurationsUserGet();
if (response.data) {
setUserConfig(response.data);
if (response.data.organization_pricing) {
setOrganizationPricing({
price_per_second_usd: response.data.organization_pricing.price_per_second_usd as number | null,
currency: response.data.organization_pricing.currency as string || 'USD',
billing_enabled: response.data.organization_pricing.billing_enabled as boolean || false
});
}
}
setError(null);
} catch (err) {
setError(err instanceof Error ? err : new Error('Failed to fetch user configuration'));
} finally {
setLoading(false);
}
}, []);
return (
<UserConfigContext.Provider
value={{
userConfig,
saveUserConfig,
loading,
error,
refreshConfig,
permissions,
user: auth.user,
organizationPricing,
}}
>
{children}
</UserConfigContext.Provider>
);
}
export function useUserConfig() {
const context = useContext(UserConfigContext);
if (!context) {
throw new Error('useUserConfig must be used within a UserConfigProvider');
}
return context;
}
export { OrgConfigProvider as UserConfigProvider, useOrgConfig as useUserConfig } from './OrgConfigContext';

View file

@ -2,18 +2,33 @@
* Extract a human-readable message from a backend error response.
*
* The generated API client returns `{ error }` on failure (it does not throw),
* and FastAPI shapes that error as either `{ detail: string }` (HTTPException)
* or `{ detail: [{ msg, loc, ... }] }` (422 validation). This normalizes both
* to a single string so it can be rendered or thrown directly never pass the
* raw `detail` to React, as the 422 array crashes rendering.
* and FastAPI shapes that error as `{ detail: string }`, `{ detail:
* [{ msg, loc, ... }] }`, or backend validation arrays like `{ detail:
* [{ model, message }] }`. This normalizes those to a single string so it can
* be rendered or thrown directly.
*/
export function detailFromError(err: unknown, fallback = "Request failed"): string {
if (typeof err === "string") return err;
const e = err as { detail?: unknown };
if (typeof e?.detail === "string") return e.detail;
if (Array.isArray(e?.detail) && e.detail.length > 0) {
const first = e.detail[0] as { msg?: string };
if (first?.msg) return first.msg;
const messages = e.detail
.map((item) => {
if (typeof item === "string") return item;
if (!item || typeof item !== "object") return null;
const detail = item as { message?: unknown; msg?: unknown; model?: unknown };
const message = typeof detail.message === "string"
? detail.message
: typeof detail.msg === "string"
? detail.msg
: null;
if (!message) return null;
return typeof detail.model === "string" && detail.model
? `${detail.model}: ${message}`
: message;
})
.filter((message): message is string => Boolean(message));
if (messages.length > 0) return messages.join("\n");
}
return fallback;
}