mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-13 08:15:21 +02:00
feat: billing and credit management v2 (#429)
* feat: use mps generated correlation ID * chore: update pipecat submodule * feat: add credit purchase URL * feat: carve out billing page and show credit ledger * feat: deprecate dograh based quota tracking * fix: remove cost calculation from dograh codebase * fix: create mps account on migrate to v2 * chore: update pipecat
This commit is contained in:
parent
97d7103480
commit
1f1149f4d5
80 changed files with 3335 additions and 2057 deletions
|
|
@ -18,6 +18,9 @@ branch_labels: Union[str, Sequence[str], None] = None
|
|||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
DEPRECATED_QUOTA_COMMENT = "Deprecated. MPS owns quota and credit ledger state."
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
# 1) Create the `quota_type` enum *before* we add the column that references it.
|
||||
|
|
@ -34,7 +37,12 @@ def upgrade() -> None:
|
|||
sa.Column("organization_id", sa.Integer(), nullable=False),
|
||||
sa.Column("period_start", sa.DateTime(), nullable=False),
|
||||
sa.Column("period_end", sa.DateTime(), nullable=False),
|
||||
sa.Column("quota_dograh_tokens", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"quota_dograh_tokens",
|
||||
sa.Integer(),
|
||||
nullable=False,
|
||||
comment=DEPRECATED_QUOTA_COMMENT,
|
||||
),
|
||||
sa.Column("used_dograh_tokens", sa.Integer(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True),
|
||||
|
|
@ -63,7 +71,11 @@ def upgrade() -> None:
|
|||
op.add_column(
|
||||
"organizations",
|
||||
sa.Column(
|
||||
"quota_type", quota_type_enum, nullable=False, server_default="monthly"
|
||||
"quota_type",
|
||||
quota_type_enum,
|
||||
nullable=False,
|
||||
server_default="monthly",
|
||||
comment=DEPRECATED_QUOTA_COMMENT,
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
|
|
@ -73,6 +85,7 @@ def upgrade() -> None:
|
|||
sa.Integer(),
|
||||
nullable=False,
|
||||
server_default=sa.text("0"),
|
||||
comment=DEPRECATED_QUOTA_COMMENT,
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
|
|
@ -82,10 +95,17 @@ def upgrade() -> None:
|
|||
sa.Integer(),
|
||||
nullable=False,
|
||||
server_default=sa.text("LEAST(EXTRACT(DAY FROM CURRENT_DATE)::int, 28)"),
|
||||
comment=DEPRECATED_QUOTA_COMMENT,
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"organizations", sa.Column("quota_start_date", sa.DateTime(), nullable=True)
|
||||
"organizations",
|
||||
sa.Column(
|
||||
"quota_start_date",
|
||||
sa.DateTime(),
|
||||
nullable=True,
|
||||
comment=DEPRECATED_QUOTA_COMMENT,
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"organizations",
|
||||
|
|
@ -94,6 +114,7 @@ def upgrade() -> None:
|
|||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.text("false"),
|
||||
comment=DEPRECATED_QUOTA_COMMENT,
|
||||
),
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
|
|
|||
|
|
@ -18,6 +18,9 @@ branch_labels: Union[str, Sequence[str], None] = None
|
|||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
DEPRECATED_QUOTA_COMMENT = "Deprecated. MPS owns quota and credit ledger state."
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column(
|
||||
|
|
@ -26,7 +29,12 @@ def upgrade() -> None:
|
|||
)
|
||||
op.add_column(
|
||||
"organization_usage_cycles",
|
||||
sa.Column("quota_amount_usd", sa.Float(), nullable=True),
|
||||
sa.Column(
|
||||
"quota_amount_usd",
|
||||
sa.Float(),
|
||||
nullable=True,
|
||||
comment=DEPRECATED_QUOTA_COMMENT,
|
||||
),
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from api.db.base_client import BaseDBClient
|
|||
from api.db.filters import apply_workflow_run_filters, get_workflow_run_order_clause
|
||||
from api.db.models import CampaignModel, QueuedRunModel, WorkflowRunModel
|
||||
from api.schemas.workflow import WorkflowRunResponseSchema
|
||||
from api.services.workflow.run_usage_response import format_public_cost_info
|
||||
|
||||
|
||||
class CampaignClient(BaseDBClient):
|
||||
|
|
@ -215,26 +216,9 @@ class CampaignClient(BaseDBClient):
|
|||
"is_completed": run.is_completed,
|
||||
"recording_url": run.recording_url,
|
||||
"transcript_url": run.transcript_url,
|
||||
"cost_info": {
|
||||
"dograh_token_usage": (
|
||||
run.cost_info.get("dograh_token_usage")
|
||||
if run.cost_info
|
||||
and "dograh_token_usage" in run.cost_info
|
||||
else round(
|
||||
float(run.cost_info.get("total_cost_usd", 0)) * 100,
|
||||
2,
|
||||
)
|
||||
if run.cost_info and "total_cost_usd" in run.cost_info
|
||||
else 0
|
||||
),
|
||||
"call_duration_seconds": int(
|
||||
round(run.cost_info.get("call_duration_seconds") or 0)
|
||||
)
|
||||
if run.cost_info
|
||||
else None,
|
||||
}
|
||||
if run.cost_info
|
||||
else None,
|
||||
"cost_info": format_public_cost_info(
|
||||
run.cost_info, run.usage_info
|
||||
),
|
||||
"definition_id": run.definition_id,
|
||||
"initial_context": run.initial_context,
|
||||
"gathered_context": run.gathered_context,
|
||||
|
|
@ -662,7 +646,7 @@ class CampaignClient(BaseDBClient):
|
|||
async with self.async_session() as session:
|
||||
conditions = [
|
||||
WorkflowRunModel.is_completed.is_(True),
|
||||
WorkflowRunModel.cost_info["call_duration_seconds"]
|
||||
WorkflowRunModel.usage_info["call_duration_seconds"]
|
||||
.as_string()
|
||||
.isnot(None),
|
||||
]
|
||||
|
|
@ -685,6 +669,7 @@ class CampaignClient(BaseDBClient):
|
|||
WorkflowRunModel.initial_context,
|
||||
WorkflowRunModel.gathered_context,
|
||||
WorkflowRunModel.cost_info,
|
||||
WorkflowRunModel.usage_info,
|
||||
WorkflowRunModel.public_access_token,
|
||||
)
|
||||
.where(*conditions)
|
||||
|
|
|
|||
|
|
@ -53,7 +53,7 @@ class DBClient(
|
|||
- UserClient: handles user and user configuration operations
|
||||
- OrganizationClient: handles organization operations
|
||||
- OrganizationConfigurationClient: handles organization configuration operations
|
||||
- OrganizationUsageClient: handles organization usage and quota operations
|
||||
- OrganizationUsageClient: handles organization usage reporting aggregates
|
||||
- IntegrationClient: handles integration operations
|
||||
- WorkflowTemplateClient: handles workflow template operations
|
||||
- CampaignClient: handles campaign operations
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ def get_workflow_run_order_clause(
|
|||
"""
|
||||
# Determine sort column
|
||||
if sort_by == "duration":
|
||||
sort_column = WorkflowRunModel.cost_info.op("->>")(
|
||||
sort_column = WorkflowRunModel.usage_info.op("->>")(
|
||||
"call_duration_seconds"
|
||||
).cast(Float)
|
||||
else:
|
||||
|
|
@ -43,7 +43,7 @@ def get_workflow_run_order_clause(
|
|||
ATTRIBUTE_FIELD_MAPPING = {
|
||||
"dateRange": "created_at",
|
||||
"dispositionCode": "gathered_context.mapped_call_disposition",
|
||||
"duration": "cost_info.call_duration_seconds",
|
||||
"duration": "usage_info.call_duration_seconds",
|
||||
"status": "is_completed",
|
||||
"tokenUsage": "cost_info.total_cost_usd",
|
||||
"runId": "id",
|
||||
|
|
@ -208,7 +208,7 @@ def apply_workflow_run_filters(
|
|||
min_val = value.get("min")
|
||||
max_val = value.get("max")
|
||||
|
||||
if field == "cost_info.call_duration_seconds":
|
||||
if field == "usage_info.call_duration_seconds":
|
||||
# Use ->> operator for compatibility with all PostgreSQL versions
|
||||
# (subscript [] only works in PostgreSQL 14+)
|
||||
duration_text = cast(WorkflowRunModel.usage_info, JSONB).op("->>")(
|
||||
|
|
|
|||
|
|
@ -97,22 +97,44 @@ class OrganizationModel(Base):
|
|||
provider_id = Column(String, unique=True, index=True, nullable=False)
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
|
||||
|
||||
# Quota fields
|
||||
# Deprecated: MPS owns quota and credit ledger state.
|
||||
quota_type = Column(
|
||||
Enum("monthly", "annual", name="quota_type"),
|
||||
nullable=False,
|
||||
default="monthly",
|
||||
server_default=text("'monthly'::quota_type"),
|
||||
comment="Deprecated. MPS owns quota and credit ledger state.",
|
||||
info={"deprecated": True},
|
||||
)
|
||||
quota_dograh_tokens = Column(
|
||||
Integer, nullable=False, default=0, server_default=text("0")
|
||||
Integer,
|
||||
nullable=False,
|
||||
default=0,
|
||||
server_default=text("0"),
|
||||
comment="Deprecated. MPS owns quota and credit ledger state.",
|
||||
info={"deprecated": True},
|
||||
)
|
||||
quota_reset_day = Column(
|
||||
Integer, nullable=False, default=1, server_default=text("1")
|
||||
) # 1-28, only for monthly
|
||||
quota_start_date = Column(DateTime(timezone=True), nullable=True) # Only for annual
|
||||
Integer,
|
||||
nullable=False,
|
||||
default=1,
|
||||
server_default=text("1"),
|
||||
comment="Deprecated. MPS owns quota and credit ledger state.",
|
||||
info={"deprecated": True},
|
||||
)
|
||||
quota_start_date = Column(
|
||||
DateTime(timezone=True),
|
||||
nullable=True,
|
||||
comment="Deprecated. MPS owns quota and credit ledger state.",
|
||||
info={"deprecated": True},
|
||||
)
|
||||
quota_enabled = Column(
|
||||
Boolean, nullable=False, default=False, server_default=text("false")
|
||||
Boolean,
|
||||
nullable=False,
|
||||
default=False,
|
||||
server_default=text("false"),
|
||||
comment="Deprecated. MPS owns quota and credit ledger state.",
|
||||
info={"deprecated": True},
|
||||
)
|
||||
|
||||
price_per_second_usd = Column(Float, nullable=True)
|
||||
|
|
@ -593,8 +615,9 @@ class WorkflowRunTextSessionModel(Base):
|
|||
|
||||
class OrganizationUsageCycleModel(Base):
|
||||
"""
|
||||
This model is used to track the usage of Dograh tokens for an organization for a given usage
|
||||
cycle.
|
||||
This model is used to track reporting aggregates for an organization for a given
|
||||
usage cycle. Quota fields on this model are deprecated; MPS owns quota and
|
||||
credit ledger state.
|
||||
"""
|
||||
|
||||
__tablename__ = "organization_usage_cycles"
|
||||
|
|
@ -603,14 +626,24 @@ class OrganizationUsageCycleModel(Base):
|
|||
organization_id = Column(Integer, ForeignKey("organizations.id"), nullable=False)
|
||||
period_start = Column(DateTime(timezone=True), nullable=False)
|
||||
period_end = Column(DateTime(timezone=True), nullable=False)
|
||||
quota_dograh_tokens = Column(Integer, nullable=False)
|
||||
quota_dograh_tokens = Column(
|
||||
Integer,
|
||||
nullable=False,
|
||||
comment="Deprecated. MPS owns quota and credit ledger state.",
|
||||
info={"deprecated": True},
|
||||
)
|
||||
used_dograh_tokens = Column(Float, nullable=False, default=0)
|
||||
total_duration_seconds = Column(
|
||||
Integer, nullable=False, default=0, server_default=text("0")
|
||||
)
|
||||
# New USD tracking fields
|
||||
used_amount_usd = Column(Float, nullable=True, default=0)
|
||||
quota_amount_usd = Column(Float, nullable=True)
|
||||
quota_amount_usd = Column(
|
||||
Float,
|
||||
nullable=True,
|
||||
comment="Deprecated. MPS owns quota and credit ledger state.",
|
||||
info={"deprecated": True},
|
||||
)
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True),
|
||||
|
|
|
|||
|
|
@ -19,11 +19,11 @@ from api.db.models import (
|
|||
WorkflowRunModel,
|
||||
)
|
||||
from api.enums import OrganizationConfigurationKey
|
||||
from api.schemas.user_configuration import EffectiveAIModelConfiguration
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
|
||||
|
||||
class OrganizationUsageClient(BaseDBClient):
|
||||
"""Client for managing organization usage and quota operations."""
|
||||
"""Client for managing organization usage reporting aggregates."""
|
||||
|
||||
async def get_or_create_current_cycle(
|
||||
self, organization_id: int, session=None
|
||||
|
|
@ -49,14 +49,7 @@ class OrganizationUsageClient(BaseDBClient):
|
|||
self, organization_id: int, session, commit: bool
|
||||
) -> OrganizationUsageCycleModel:
|
||||
"""Internal implementation for get_or_create_current_cycle."""
|
||||
# Get organization to determine quota type
|
||||
org_result = await session.execute(
|
||||
select(OrganizationModel).where(OrganizationModel.id == organization_id)
|
||||
)
|
||||
org = org_result.scalar_one()
|
||||
|
||||
# Calculate current period
|
||||
period_start, period_end = self._calculate_current_period(org)
|
||||
period_start, period_end = self._calculate_current_period()
|
||||
|
||||
# Try to get existing cycle
|
||||
cycle_result = await session.execute(
|
||||
|
|
@ -78,7 +71,8 @@ class OrganizationUsageClient(BaseDBClient):
|
|||
organization_id=organization_id,
|
||||
period_start=period_start,
|
||||
period_end=period_end,
|
||||
quota_dograh_tokens=org.quota_dograh_tokens,
|
||||
# Deprecated non-null column retained for historical schema compatibility.
|
||||
quota_dograh_tokens=0,
|
||||
)
|
||||
# Handle concurrent inserts gracefully
|
||||
stmt = stmt.on_conflict_do_nothing(
|
||||
|
|
@ -102,95 +96,9 @@ class OrganizationUsageClient(BaseDBClient):
|
|||
)
|
||||
return cycle_result.scalar_one()
|
||||
|
||||
async def check_and_reserve_quota(
|
||||
self, organization_id: int, estimated_tokens: int = 0
|
||||
) -> bool:
|
||||
"""
|
||||
Check if organization has sufficient quota and optionally reserve tokens.
|
||||
Returns True if quota is available, False otherwise.
|
||||
|
||||
This method is fully atomic and safe for concurrent access from multiple processes.
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
# Get organization
|
||||
org_result = await session.execute(
|
||||
select(OrganizationModel).where(OrganizationModel.id == organization_id)
|
||||
)
|
||||
org = org_result.scalar_one_or_none()
|
||||
|
||||
if not org or not org.quota_enabled:
|
||||
# No quota enforcement if not enabled
|
||||
return True
|
||||
|
||||
# Get or create current cycle within the same session/transaction
|
||||
cycle = await self._get_or_create_current_cycle_impl(
|
||||
organization_id, session, commit=False
|
||||
)
|
||||
|
||||
# Atomic check and update with row-level lock
|
||||
result = await session.execute(
|
||||
select(OrganizationUsageCycleModel)
|
||||
.where(
|
||||
and_(
|
||||
OrganizationUsageCycleModel.id == cycle.id,
|
||||
OrganizationUsageCycleModel.used_dograh_tokens
|
||||
+ estimated_tokens
|
||||
<= OrganizationUsageCycleModel.quota_dograh_tokens,
|
||||
)
|
||||
)
|
||||
.with_for_update(skip_locked=False)
|
||||
)
|
||||
|
||||
cycle_locked = result.scalar_one_or_none()
|
||||
if cycle_locked:
|
||||
# Update the usage atomically
|
||||
cycle_locked.used_dograh_tokens += estimated_tokens
|
||||
await session.commit()
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def update_usage_after_run(
|
||||
self,
|
||||
organization_id: int,
|
||||
actual_tokens: float,
|
||||
duration_seconds: float = 0,
|
||||
charge_usd: float | None = None,
|
||||
) -> None:
|
||||
"""Update usage after a workflow run completes with actual token count and duration.
|
||||
|
||||
This method is fully atomic and safe for concurrent access from multiple processes.
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
# Get or create current cycle within the same session/transaction
|
||||
cycle = await self._get_or_create_current_cycle_impl(
|
||||
organization_id, session, commit=False
|
||||
)
|
||||
|
||||
# Acquire a row-level lock for atomic update
|
||||
result = await session.execute(
|
||||
select(OrganizationUsageCycleModel)
|
||||
.where(OrganizationUsageCycleModel.id == cycle.id)
|
||||
.with_for_update(skip_locked=False)
|
||||
)
|
||||
cycle_locked = result.scalar_one()
|
||||
|
||||
# Update usage atomically
|
||||
cycle_locked.used_dograh_tokens += actual_tokens
|
||||
cycle_locked.total_duration_seconds += int(round(duration_seconds))
|
||||
|
||||
# Update USD amount if provided
|
||||
if charge_usd is not None:
|
||||
if cycle_locked.used_amount_usd is None:
|
||||
cycle_locked.used_amount_usd = 0
|
||||
cycle_locked.used_amount_usd += charge_usd
|
||||
|
||||
await session.commit()
|
||||
|
||||
async def get_current_usage(self, organization_id: int) -> dict:
|
||||
"""Get current period usage information."""
|
||||
"""Get current reporting-period usage information."""
|
||||
async with self.async_session() as session:
|
||||
# Get organization
|
||||
org_result = await session.execute(
|
||||
select(OrganizationModel).where(OrganizationModel.id == organization_id)
|
||||
)
|
||||
|
|
@ -201,42 +109,19 @@ class OrganizationUsageClient(BaseDBClient):
|
|||
organization_id, session, commit=False
|
||||
)
|
||||
|
||||
# Calculate next refresh date
|
||||
if org.quota_type == "monthly":
|
||||
next_refresh = cycle.period_end + relativedelta(days=1)
|
||||
else: # annual
|
||||
next_refresh = cycle.period_end + relativedelta(days=1)
|
||||
|
||||
result = {
|
||||
"period_start": cycle.period_start.isoformat(),
|
||||
"period_end": cycle.period_end.isoformat(),
|
||||
"used_dograh_tokens": cycle.used_dograh_tokens,
|
||||
"quota_dograh_tokens": cycle.quota_dograh_tokens,
|
||||
"percentage_used": (
|
||||
round(
|
||||
(cycle.used_dograh_tokens / cycle.quota_dograh_tokens) * 100, 2
|
||||
)
|
||||
if cycle.quota_dograh_tokens > 0
|
||||
else 0
|
||||
),
|
||||
"next_refresh_date": next_refresh.date().isoformat(),
|
||||
"quota_enabled": org.quota_enabled,
|
||||
"total_duration_seconds": cycle.total_duration_seconds,
|
||||
}
|
||||
|
||||
# Add USD fields if organization has pricing
|
||||
if org.price_per_second_usd is not None:
|
||||
result["used_amount_usd"] = cycle.used_amount_usd or 0
|
||||
result["quota_amount_usd"] = cycle.quota_amount_usd
|
||||
result["currency"] = "USD"
|
||||
result["price_per_second_usd"] = org.price_per_second_usd
|
||||
|
||||
# Calculate percentage based on USD if available
|
||||
if cycle.quota_amount_usd and cycle.quota_amount_usd > 0:
|
||||
result["percentage_used"] = round(
|
||||
((cycle.used_amount_usd or 0) / cycle.quota_amount_usd) * 100, 2
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
async def get_usage_history(
|
||||
|
|
@ -256,7 +141,7 @@ class OrganizationUsageClient(BaseDBClient):
|
|||
.join(UserModel, WorkflowModel.user_id == UserModel.id)
|
||||
.where(
|
||||
UserModel.selected_organization_id == organization_id,
|
||||
WorkflowRunModel.cost_info.isnot(None),
|
||||
WorkflowRunModel.usage_info.isnot(None),
|
||||
)
|
||||
.order_by(WorkflowRunModel.created_at.desc())
|
||||
)
|
||||
|
|
@ -309,19 +194,8 @@ class OrganizationUsageClient(BaseDBClient):
|
|||
total_tokens = 0
|
||||
total_duration_seconds = 0
|
||||
for run in runs:
|
||||
if run.cost_info:
|
||||
# Try to get dograh_token_usage first (new format)
|
||||
dograh_tokens = run.cost_info.get("dograh_token_usage", 0)
|
||||
# If not present, calculate from total_cost_usd (old format)
|
||||
if dograh_tokens == 0 and "total_cost_usd" in run.cost_info:
|
||||
dograh_tokens = round(
|
||||
float(run.cost_info["total_cost_usd"]) * 100, 2
|
||||
)
|
||||
# Get call duration
|
||||
call_duration = run.cost_info.get("call_duration_seconds", 0)
|
||||
else:
|
||||
dograh_tokens = 0
|
||||
call_duration = 0
|
||||
dograh_tokens = 0
|
||||
call_duration = (run.usage_info or {}).get("call_duration_seconds", 0)
|
||||
total_tokens += dograh_tokens
|
||||
total_duration_seconds += int(round(call_duration))
|
||||
|
||||
|
|
@ -395,13 +269,14 @@ class OrganizationUsageClient(BaseDBClient):
|
|||
WorkflowRunModel.initial_context,
|
||||
WorkflowRunModel.gathered_context,
|
||||
WorkflowRunModel.cost_info,
|
||||
WorkflowRunModel.usage_info,
|
||||
WorkflowRunModel.public_access_token,
|
||||
)
|
||||
.join(WorkflowModel, WorkflowRunModel.workflow_id == WorkflowModel.id)
|
||||
.join(UserModel, WorkflowModel.user_id == UserModel.id)
|
||||
.where(
|
||||
UserModel.selected_organization_id == organization_id,
|
||||
WorkflowRunModel.cost_info.isnot(None),
|
||||
WorkflowRunModel.usage_info.isnot(None),
|
||||
)
|
||||
.order_by(WorkflowRunModel.created_at.desc())
|
||||
)
|
||||
|
|
@ -473,11 +348,11 @@ class OrganizationUsageClient(BaseDBClient):
|
|||
)
|
||||
config_obj = config_result.scalar_one_or_none()
|
||||
if config_obj and config_obj.configuration:
|
||||
user_config = EffectiveAIModelConfiguration.model_validate(
|
||||
effective_config = EffectiveAIModelConfiguration.model_validate(
|
||||
config_obj.configuration
|
||||
)
|
||||
if user_config.timezone and user_timezone == "UTC":
|
||||
user_timezone = user_config.timezone
|
||||
if effective_config.timezone and user_timezone == "UTC":
|
||||
user_timezone = effective_config.timezone
|
||||
|
||||
# Validate timezone string
|
||||
try:
|
||||
|
|
@ -496,7 +371,7 @@ class OrganizationUsageClient(BaseDBClient):
|
|||
select(
|
||||
date_expr.label("date"),
|
||||
func.sum(
|
||||
WorkflowRunModel.cost_info["call_duration_seconds"].as_float()
|
||||
WorkflowRunModel.usage_info["call_duration_seconds"].as_float()
|
||||
).label("total_seconds"),
|
||||
func.count(WorkflowRunModel.id).label("call_count"),
|
||||
)
|
||||
|
|
@ -545,83 +420,11 @@ class OrganizationUsageClient(BaseDBClient):
|
|||
"currency": "USD",
|
||||
}
|
||||
|
||||
async def update_organization_quota(
|
||||
self,
|
||||
organization_id: int,
|
||||
quota_type: str,
|
||||
quota_dograh_tokens: int,
|
||||
quota_reset_day: Optional[int] = None,
|
||||
quota_start_date: Optional[datetime] = None,
|
||||
) -> OrganizationModel:
|
||||
"""Update organization quota settings."""
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(OrganizationModel).where(OrganizationModel.id == organization_id)
|
||||
)
|
||||
org = result.scalar_one()
|
||||
|
||||
org.quota_type = quota_type
|
||||
org.quota_dograh_tokens = quota_dograh_tokens
|
||||
org.quota_enabled = True
|
||||
|
||||
if quota_type == "monthly" and quota_reset_day:
|
||||
org.quota_reset_day = quota_reset_day
|
||||
elif quota_type == "annual" and quota_start_date:
|
||||
org.quota_start_date = quota_start_date
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(org)
|
||||
return org
|
||||
|
||||
def _calculate_current_period(
|
||||
self, org: OrganizationModel
|
||||
) -> tuple[datetime, datetime]:
|
||||
"""Calculate the current billing period based on organization settings."""
|
||||
def _calculate_current_period(self) -> tuple[datetime, datetime]:
|
||||
"""Calculate the current calendar-month reporting period."""
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
if org.quota_type == "monthly":
|
||||
# Find the start of the current billing month
|
||||
reset_day = org.quota_reset_day
|
||||
|
||||
# Handle month boundaries
|
||||
if now.day >= reset_day:
|
||||
period_start = now.replace(
|
||||
day=reset_day, hour=0, minute=0, second=0, microsecond=0
|
||||
)
|
||||
else:
|
||||
# Previous month
|
||||
period_start = (now - relativedelta(months=1)).replace(
|
||||
day=reset_day, hour=0, minute=0, second=0, microsecond=0
|
||||
)
|
||||
|
||||
# End is one month later minus 1 second
|
||||
period_end = (
|
||||
period_start + relativedelta(months=1) - relativedelta(seconds=1)
|
||||
)
|
||||
|
||||
else: # annual
|
||||
if not org.quota_start_date:
|
||||
# Default to calendar year
|
||||
period_start = now.replace(
|
||||
month=1, day=1, hour=0, minute=0, second=0, microsecond=0
|
||||
)
|
||||
period_end = (
|
||||
period_start + relativedelta(years=1) - relativedelta(seconds=1)
|
||||
)
|
||||
else:
|
||||
# Find current annual period
|
||||
start_date = org.quota_start_date.replace(tzinfo=timezone.utc)
|
||||
years_diff = now.year - start_date.year
|
||||
|
||||
# Adjust for whether we've passed the anniversary
|
||||
if now.month < start_date.month or (
|
||||
now.month == start_date.month and now.day < start_date.day
|
||||
):
|
||||
years_diff -= 1
|
||||
|
||||
period_start = start_date + relativedelta(years=years_diff)
|
||||
period_end = (
|
||||
period_start + relativedelta(years=1) - relativedelta(seconds=1)
|
||||
)
|
||||
period_start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
|
||||
period_end = period_start + relativedelta(months=1) - relativedelta(seconds=1)
|
||||
|
||||
return period_start, period_end
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from sqlalchemy.future import select
|
|||
|
||||
from api.db.base_client import BaseDBClient
|
||||
from api.db.models import UserConfigurationModel, UserModel
|
||||
from api.schemas.user_configuration import EffectiveAIModelConfiguration
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
|
||||
|
||||
class UserClient(BaseDBClient):
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from api.db.models import (
|
|||
)
|
||||
from api.enums import CallType, StorageBackend
|
||||
from api.schemas.workflow import WorkflowRunResponseSchema
|
||||
from api.services.workflow.run_usage_response import format_public_cost_info
|
||||
|
||||
|
||||
class WorkflowRunClient(BaseDBClient):
|
||||
|
|
@ -312,26 +313,9 @@ class WorkflowRunClient(BaseDBClient):
|
|||
"is_completed": run.is_completed,
|
||||
"recording_url": run.recording_url,
|
||||
"transcript_url": run.transcript_url,
|
||||
"cost_info": {
|
||||
"dograh_token_usage": (
|
||||
run.cost_info.get("dograh_token_usage")
|
||||
if run.cost_info
|
||||
and "dograh_token_usage" in run.cost_info
|
||||
else round(
|
||||
float(run.cost_info.get("total_cost_usd", 0)) * 100,
|
||||
2,
|
||||
)
|
||||
if run.cost_info and "total_cost_usd" in run.cost_info
|
||||
else 0
|
||||
),
|
||||
"call_duration_seconds": int(
|
||||
round(run.cost_info.get("call_duration_seconds") or 0)
|
||||
)
|
||||
if run.cost_info
|
||||
else None,
|
||||
}
|
||||
if run.cost_info
|
||||
else None,
|
||||
"cost_info": format_public_cost_info(
|
||||
run.cost_info, run.usage_info
|
||||
),
|
||||
"definition_id": run.definition_id,
|
||||
"initial_context": run.initial_context,
|
||||
"gathered_context": run.gathered_context,
|
||||
|
|
|
|||
|
|
@ -384,7 +384,7 @@ async def search_chunks(
|
|||
user_id=user.id,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
user_config = resolved_config.effective
|
||||
effective_config = resolved_config.effective
|
||||
embeddings_api_key = None
|
||||
embeddings_model = None
|
||||
embeddings_provider = None
|
||||
|
|
@ -392,17 +392,17 @@ async def search_chunks(
|
|||
embeddings_endpoint = None
|
||||
embeddings_api_version = None
|
||||
|
||||
if user_config.embeddings:
|
||||
embeddings_api_key = user_config.embeddings.api_key
|
||||
embeddings_model = user_config.embeddings.model
|
||||
embeddings_provider = getattr(user_config.embeddings, "provider", None)
|
||||
embeddings_endpoint = getattr(user_config.embeddings, "endpoint", None)
|
||||
if effective_config.embeddings:
|
||||
embeddings_api_key = effective_config.embeddings.api_key
|
||||
embeddings_model = effective_config.embeddings.model
|
||||
embeddings_provider = getattr(effective_config.embeddings, "provider", None)
|
||||
embeddings_endpoint = getattr(effective_config.embeddings, "endpoint", None)
|
||||
embeddings_base_url = apply_managed_embeddings_base_url(
|
||||
provider=embeddings_provider,
|
||||
base_url=getattr(user_config.embeddings, "base_url", None),
|
||||
base_url=getattr(effective_config.embeddings, "base_url", None),
|
||||
)
|
||||
embeddings_api_version = getattr(
|
||||
user_config.embeddings, "api_version", None
|
||||
effective_config.embeddings, "api_version", None
|
||||
)
|
||||
|
||||
# Initialize embedding service based on provider
|
||||
|
|
|
|||
|
|
@ -5,7 +5,11 @@ from loguru import logger
|
|||
from pydantic import BaseModel
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from api.constants import DEFAULT_CAMPAIGN_RETRY_CONFIG, DEFAULT_ORG_CONCURRENCY_LIMIT
|
||||
from api.constants import (
|
||||
DEFAULT_CAMPAIGN_RETRY_CONFIG,
|
||||
DEFAULT_ORG_CONCURRENCY_LIMIT,
|
||||
DEPLOYMENT_MODE,
|
||||
)
|
||||
from api.db import db_client
|
||||
from api.db.models import UserModel
|
||||
from api.db.telephony_configuration_client import TelephonyConfigurationInUseError
|
||||
|
|
@ -55,6 +59,11 @@ from api.services.configuration.registry import (
|
|||
ServiceProviders,
|
||||
ServiceType,
|
||||
)
|
||||
from api.services.mps_billing import ensure_hosted_mps_billing_account_v2
|
||||
from api.services.organization_context import (
|
||||
OrganizationContextResponse,
|
||||
get_organization_context,
|
||||
)
|
||||
from api.services.organization_preferences import (
|
||||
get_organization_preferences,
|
||||
upsert_organization_preferences,
|
||||
|
|
@ -129,6 +138,12 @@ class TelephonyConfigWarningsResponse(BaseModel):
|
|||
telnyx_missing_webhook_public_key_count: int
|
||||
|
||||
|
||||
@router.get("/context", response_model=OrganizationContextResponse)
|
||||
async def get_current_organization_context(user: UserModel = Depends(get_user)):
|
||||
"""Return organization-scoped configuration signals owned by Dograh."""
|
||||
return await get_organization_context(user)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/telephony-providers/metadata",
|
||||
response_model=TelephonyProvidersMetadataResponse,
|
||||
|
|
@ -349,6 +364,23 @@ async def migrate_model_configuration_v2(
|
|||
except ValueError as exc:
|
||||
raise HTTPException(status_code=422, detail=exc.args[0])
|
||||
|
||||
if DEPLOYMENT_MODE != "oss":
|
||||
try:
|
||||
await ensure_hosted_mps_billing_account_v2(
|
||||
organization_id,
|
||||
created_by=str(user.provider_id),
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"Failed to initialize MPS billing v2 account for organization {}: {}",
|
||||
organization_id,
|
||||
exc,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="Failed to initialize MPS billing v2 account",
|
||||
)
|
||||
|
||||
await upsert_organization_ai_model_configuration_v2(
|
||||
organization_id,
|
||||
configuration,
|
||||
|
|
|
|||
|
|
@ -1,16 +1,16 @@
|
|||
import json
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from api.constants import DEPLOYMENT_MODE
|
||||
from api.constants import DEPLOYMENT_MODE, UI_APP_URL
|
||||
from api.db import db_client
|
||||
from api.db.models import UserModel
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.auth.depends import get_user, get_user_with_selected_organization
|
||||
from api.services.mps_service_key_client import mps_service_key_client
|
||||
from api.services.reports import generate_usage_runs_report_csv
|
||||
from api.utils.artifacts import artifact_url
|
||||
|
|
@ -22,14 +22,8 @@ class CurrentUsageResponse(BaseModel):
|
|||
period_start: str
|
||||
period_end: str
|
||||
used_dograh_tokens: float
|
||||
quota_dograh_tokens: int
|
||||
percentage_used: float
|
||||
next_refresh_date: str
|
||||
quota_enabled: bool
|
||||
total_duration_seconds: int
|
||||
# New USD fields
|
||||
used_amount_usd: Optional[float] = None
|
||||
quota_amount_usd: Optional[float] = None
|
||||
currency: Optional[str] = None
|
||||
price_per_second_usd: Optional[float] = None
|
||||
|
||||
|
|
@ -40,6 +34,61 @@ class MPSCreditsResponse(BaseModel):
|
|||
total_quota: float
|
||||
|
||||
|
||||
class MPSCreditPurchaseUrlResponse(BaseModel):
|
||||
checkout_url: str
|
||||
|
||||
|
||||
class MPSBillingAccountResponse(BaseModel):
|
||||
id: int
|
||||
organization_id: int
|
||||
billing_mode: str
|
||||
cached_balance_credits: float
|
||||
currency: str
|
||||
|
||||
|
||||
class MPSCreditLedgerEntryResponse(BaseModel):
|
||||
id: int
|
||||
entry_type: str
|
||||
origin: Optional[str] = None
|
||||
credits_delta: float
|
||||
balance_after: float
|
||||
amount_minor: Optional[int] = None
|
||||
amount_currency: Optional[str] = None
|
||||
payment_order_id: Optional[int] = None
|
||||
metric_code: Optional[str] = None
|
||||
correlation_id: Optional[str] = None
|
||||
aggregation_key: Optional[str] = None
|
||||
usage_event_id: Optional[int] = None
|
||||
workflow_run_id: Optional[int] = None
|
||||
workflow_id: Optional[int] = None
|
||||
billable_quantity: Optional[float] = None
|
||||
quantity_unit: Optional[str] = None
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
created_at: str
|
||||
|
||||
|
||||
class MPSBillingCreditsResponse(BaseModel):
|
||||
billing_version: Literal["legacy", "v2"]
|
||||
total_credits_used: float = 0.0
|
||||
remaining_credits: float = 0.0
|
||||
total_quota: float = 0.0
|
||||
account: Optional[MPSBillingAccountResponse] = None
|
||||
ledger_entries: List[MPSCreditLedgerEntryResponse] = Field(default_factory=list)
|
||||
total_count: int = 0
|
||||
page: int = 1
|
||||
limit: int = 50
|
||||
total_pages: int = 0
|
||||
|
||||
|
||||
def _optional_int(value: Any) -> Optional[int]:
|
||||
if value is None:
|
||||
return None
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
|
||||
class WorkflowRunUsageResponse(BaseModel):
|
||||
id: int
|
||||
workflow_id: int
|
||||
|
|
@ -97,7 +146,7 @@ class DailyUsageBreakdownResponse(BaseModel):
|
|||
|
||||
@router.get("/usage/current-period", response_model=CurrentUsageResponse)
|
||||
async def get_current_period_usage(user: UserModel = Depends(get_user)):
|
||||
"""Get current billing period usage for the user's organization."""
|
||||
"""Get current reporting-period usage for the user's organization."""
|
||||
if not user.selected_organization_id:
|
||||
raise HTTPException(status_code=400, detail="No organization selected")
|
||||
|
||||
|
|
@ -142,6 +191,202 @@ async def get_mps_credits(user: UserModel = Depends(get_user)):
|
|||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
async def _get_mps_billing_account_status(
|
||||
user: UserModel, organization_id: int
|
||||
) -> Optional[dict]:
|
||||
return await mps_service_key_client.get_billing_account_status(
|
||||
organization_id=organization_id,
|
||||
created_by=str(user.provider_id),
|
||||
)
|
||||
|
||||
|
||||
def _is_mps_billing_v2(account: Optional[dict]) -> bool:
|
||||
return bool(account and account.get("billing_mode") == "v2")
|
||||
|
||||
|
||||
async def _legacy_mps_credits_response(user: UserModel) -> MPSBillingCreditsResponse:
|
||||
if DEPLOYMENT_MODE == "oss":
|
||||
usage = await mps_service_key_client.get_usage_by_created_by(
|
||||
str(user.provider_id)
|
||||
)
|
||||
else:
|
||||
if not user.selected_organization_id:
|
||||
raise HTTPException(status_code=400, detail="No organization selected")
|
||||
usage = await mps_service_key_client.get_usage_by_organization(
|
||||
user.selected_organization_id
|
||||
)
|
||||
|
||||
total_used = float(usage.get("total_credits_used", 0.0))
|
||||
total_remaining = float(usage.get("remaining_credits", 0.0))
|
||||
return MPSBillingCreditsResponse(
|
||||
billing_version="legacy",
|
||||
total_credits_used=total_used,
|
||||
remaining_credits=total_remaining,
|
||||
total_quota=total_used + total_remaining,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/billing/credits", response_model=MPSBillingCreditsResponse)
|
||||
async def get_billing_credits(
|
||||
page: int = Query(1, ge=1),
|
||||
limit: int = Query(50, ge=1, le=100),
|
||||
user: UserModel = Depends(get_user),
|
||||
):
|
||||
"""Return legacy MPS credits or paginated v2 billing ledger details for the org."""
|
||||
try:
|
||||
if DEPLOYMENT_MODE == "oss" or not user.selected_organization_id:
|
||||
return await _legacy_mps_credits_response(user)
|
||||
|
||||
organization_id = user.selected_organization_id
|
||||
account_status = await _get_mps_billing_account_status(user, organization_id)
|
||||
if not _is_mps_billing_v2(account_status):
|
||||
return await _legacy_mps_credits_response(user)
|
||||
|
||||
ledger = await mps_service_key_client.get_credit_ledger(
|
||||
organization_id=organization_id,
|
||||
page=page,
|
||||
limit=limit,
|
||||
created_by=str(user.provider_id),
|
||||
)
|
||||
account = ledger.get("account") or {}
|
||||
ledger_entries = ledger.get("ledger_entries") or []
|
||||
total_count = int(ledger.get("total_count") or len(ledger_entries))
|
||||
response_limit = int(ledger.get("limit") or limit)
|
||||
total_pages = int(
|
||||
ledger.get("total_pages")
|
||||
or ((total_count + response_limit - 1) // response_limit)
|
||||
)
|
||||
workflow_ids_by_run_id: dict[int, int] = {}
|
||||
workflow_run_ids = {
|
||||
workflow_run_id
|
||||
for entry in ledger_entries
|
||||
if (workflow_run_id := _optional_int(entry.get("workflow_run_id")))
|
||||
is not None
|
||||
}
|
||||
for workflow_run_id in workflow_run_ids:
|
||||
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
|
||||
if (
|
||||
workflow_run
|
||||
and workflow_run.workflow
|
||||
and workflow_run.workflow.organization_id == organization_id
|
||||
):
|
||||
workflow_ids_by_run_id[workflow_run_id] = workflow_run.workflow_id
|
||||
|
||||
balance = float(account.get("cached_balance_credits") or 0.0)
|
||||
total_debits = sum(
|
||||
abs(float(entry.get("credits_delta") or 0.0))
|
||||
for entry in ledger_entries
|
||||
if float(entry.get("credits_delta") or 0.0) < 0
|
||||
)
|
||||
if ledger.get("total_debits_credits") is not None:
|
||||
total_debits = float(ledger["total_debits_credits"])
|
||||
|
||||
return MPSBillingCreditsResponse(
|
||||
billing_version="v2",
|
||||
total_credits_used=total_debits,
|
||||
remaining_credits=balance,
|
||||
total_quota=balance + total_debits,
|
||||
account=MPSBillingAccountResponse(
|
||||
id=int(account["id"]),
|
||||
organization_id=int(account["organization_id"]),
|
||||
billing_mode=str(account["billing_mode"]),
|
||||
cached_balance_credits=balance,
|
||||
currency=str(account.get("currency") or "USD"),
|
||||
),
|
||||
ledger_entries=[
|
||||
MPSCreditLedgerEntryResponse(
|
||||
id=int(entry["id"]),
|
||||
entry_type=str(entry["entry_type"]),
|
||||
origin=entry.get("origin"),
|
||||
credits_delta=float(entry.get("credits_delta") or 0.0),
|
||||
balance_after=float(entry.get("balance_after") or 0.0),
|
||||
amount_minor=entry.get("amount_minor"),
|
||||
amount_currency=entry.get("amount_currency"),
|
||||
payment_order_id=entry.get("payment_order_id"),
|
||||
metric_code=entry.get("metric_code"),
|
||||
correlation_id=entry.get("correlation_id"),
|
||||
aggregation_key=entry.get("aggregation_key"),
|
||||
usage_event_id=_optional_int(entry.get("usage_event_id")),
|
||||
workflow_run_id=_optional_int(entry.get("workflow_run_id")),
|
||||
workflow_id=workflow_ids_by_run_id.get(
|
||||
_optional_int(entry.get("workflow_run_id"))
|
||||
)
|
||||
if entry.get("workflow_run_id") is not None
|
||||
else None,
|
||||
billable_quantity=float(entry["billable_quantity"])
|
||||
if entry.get("billable_quantity") is not None
|
||||
else None,
|
||||
quantity_unit=entry.get("quantity_unit"),
|
||||
metadata=entry.get("metadata") or {},
|
||||
created_at=str(entry["created_at"]),
|
||||
)
|
||||
for entry in ledger_entries
|
||||
],
|
||||
total_count=total_count,
|
||||
page=int(ledger.get("page") or page),
|
||||
limit=response_limit,
|
||||
total_pages=total_pages,
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"Failed to fetch billing credits: {exc}")
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.post(
|
||||
"/usage/mps-credits/purchase-url",
|
||||
response_model=MPSCreditPurchaseUrlResponse,
|
||||
)
|
||||
async def create_mps_credit_purchase_url(
|
||||
user: UserModel = Depends(get_user_with_selected_organization),
|
||||
):
|
||||
"""Create a checkout URL for organizations using Dograh-managed MPS v2."""
|
||||
if DEPLOYMENT_MODE == "oss":
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Credit purchases are not available in OSS mode",
|
||||
)
|
||||
|
||||
organization_id = user.selected_organization_id
|
||||
assert organization_id is not None
|
||||
account_status = await _get_mps_billing_account_status(user, organization_id)
|
||||
if not _is_mps_billing_v2(account_status):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=(
|
||||
"Credit purchases are available only for organizations using billing v2"
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
session = await mps_service_key_client.create_credit_purchase_url(
|
||||
organization_id=organization_id,
|
||||
created_by=str(user.provider_id),
|
||||
return_url=f"{UI_APP_URL.rstrip('/')}/billing",
|
||||
billing_details={
|
||||
"source": "dograh_billing",
|
||||
"dograh_user_id": str(user.id),
|
||||
"dograh_provider_id": str(user.provider_id),
|
||||
},
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"Failed to create MPS credit purchase URL: {exc}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="Failed to create credit purchase URL",
|
||||
)
|
||||
|
||||
checkout_url = session.get("checkout_url")
|
||||
if not checkout_url:
|
||||
logger.error(f"MPS checkout session response missing checkout_url: {session}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="MPS checkout session response missing checkout_url",
|
||||
)
|
||||
return MPSCreditPurchaseUrlResponse(checkout_url=checkout_url)
|
||||
|
||||
|
||||
FILTERS_DESCRIPTION = """\
|
||||
JSON-encoded array of filter objects. Each object has the shape:
|
||||
|
||||
|
|
|
|||
|
|
@ -41,12 +41,15 @@ from api.services.configuration.resolve import (
|
|||
)
|
||||
from api.services.mps_service_key_client import mps_service_key_client
|
||||
from api.services.posthog_client import capture_event
|
||||
from api.services.pricing.run_usage_response import format_public_usage_info
|
||||
from api.services.reports import generate_workflow_report_csv
|
||||
from api.services.storage import storage_fs
|
||||
from api.services.workflow.dto import ReactFlowDTO, sanitize_workflow_definition
|
||||
from api.services.workflow.duplicate import duplicate_workflow
|
||||
from api.services.workflow.errors import ItemKind, WorkflowError
|
||||
from api.services.workflow.run_usage_response import (
|
||||
format_public_cost_info,
|
||||
format_public_usage_info,
|
||||
)
|
||||
from api.services.workflow.trigger_paths import (
|
||||
TriggerPathIssue,
|
||||
ensure_trigger_paths,
|
||||
|
|
@ -1053,13 +1056,15 @@ async def update_workflow(
|
|||
user_id=user.id,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
user_config = resolved_config.effective
|
||||
effective_config = resolved_config.effective
|
||||
try:
|
||||
enriched_overrides = enrich_overrides_with_api_keys(
|
||||
workflow_configurations["model_overrides"],
|
||||
user_config,
|
||||
effective_config,
|
||||
)
|
||||
effective = resolve_effective_config(
|
||||
effective_config, enriched_overrides
|
||||
)
|
||||
effective = resolve_effective_config(user_config, enriched_overrides)
|
||||
if resolved_config.source == "organization_v2":
|
||||
v2_override = convert_legacy_ai_model_configuration_to_v2(effective)
|
||||
await UserConfigurationValidator().validate(
|
||||
|
|
@ -1264,22 +1269,7 @@ async def get_workflow_run(
|
|||
"transcript_public_url": artifact_url(public_access_token, "transcript"),
|
||||
"recording_public_url": artifact_url(public_access_token, "recording"),
|
||||
"public_access_token": public_access_token,
|
||||
"cost_info": {
|
||||
"dograh_token_usage": (
|
||||
run.cost_info.get("dograh_token_usage")
|
||||
if run.cost_info and "dograh_token_usage" in run.cost_info
|
||||
else round(float(run.cost_info.get("total_cost_usd", 0)) * 100, 2)
|
||||
if run.cost_info and "total_cost_usd" in run.cost_info
|
||||
else 0
|
||||
),
|
||||
"call_duration_seconds": int(
|
||||
round(run.cost_info.get("call_duration_seconds"))
|
||||
)
|
||||
if run.cost_info and run.cost_info.get("call_duration_seconds") is not None
|
||||
else None,
|
||||
}
|
||||
if run.cost_info
|
||||
else None,
|
||||
"cost_info": format_public_cost_info(run.cost_info, run.usage_info),
|
||||
"usage_info": format_public_usage_info(run.usage_info),
|
||||
"created_at": run.created_at,
|
||||
"definition_id": run.definition_id,
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from api.schemas.user_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.configuration.registry import (
|
||||
DograhEmbeddingsConfiguration,
|
||||
DograhLLMService,
|
||||
|
|
@ -23,6 +23,29 @@ DOGRAH_DEFAULT_VOICE = "default"
|
|||
DOGRAH_DEFAULT_LANGUAGE = "multi"
|
||||
|
||||
|
||||
class EffectiveAIModelConfiguration(BaseModel):
|
||||
llm: LLMConfig | None = None
|
||||
stt: STTConfig | None = None
|
||||
tts: TTSConfig | None = None
|
||||
embeddings: EmbeddingsConfig | None = None
|
||||
realtime: RealtimeConfig | None = None
|
||||
is_realtime: bool = False
|
||||
managed_service_version: int | None = None
|
||||
test_phone_number: str | None = None
|
||||
timezone: str | None = None
|
||||
last_validated_at: datetime | None = None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def strip_incomplete_realtime_when_disabled(cls, data):
|
||||
"""Skip realtime validation when is_realtime is False and api_key is missing."""
|
||||
if isinstance(data, dict) and not data.get("is_realtime", False):
|
||||
realtime = data.get("realtime")
|
||||
if isinstance(realtime, dict) and not realtime.get("api_key"):
|
||||
data.pop("realtime", None)
|
||||
return data
|
||||
|
||||
|
||||
class DograhManagedAIModelConfiguration(BaseModel):
|
||||
api_key: str
|
||||
voice: str = DOGRAH_DEFAULT_VOICE
|
||||
|
|
@ -160,6 +183,7 @@ def _compile_dograh_configuration(
|
|||
model="default",
|
||||
),
|
||||
is_realtime=False,
|
||||
managed_service_version=2,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,33 +0,0 @@
|
|||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from api.services.configuration.registry import (
|
||||
EmbeddingsConfig,
|
||||
LLMConfig,
|
||||
RealtimeConfig,
|
||||
STTConfig,
|
||||
TTSConfig,
|
||||
)
|
||||
|
||||
|
||||
class EffectiveAIModelConfiguration(BaseModel):
|
||||
llm: LLMConfig | None = None
|
||||
stt: STTConfig | None = None
|
||||
tts: TTSConfig | None = None
|
||||
embeddings: EmbeddingsConfig | None = None
|
||||
realtime: RealtimeConfig | None = None
|
||||
is_realtime: bool = False
|
||||
test_phone_number: str | None = None
|
||||
timezone: str | None = None
|
||||
last_validated_at: datetime | None = None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def strip_incomplete_realtime_when_disabled(cls, data):
|
||||
"""Skip realtime validation when is_realtime is False and api_key is missing."""
|
||||
if isinstance(data, dict) and not data.get("is_realtime", False):
|
||||
realtime = data.get("realtime")
|
||||
if isinstance(realtime, dict) and not realtime.get("api_key"):
|
||||
data.pop("realtime", None)
|
||||
return data
|
||||
|
|
@ -9,9 +9,10 @@ from api.constants import AUTH_PROVIDER, DOGRAH_MPS_SECRET_KEY, MPS_API_URL
|
|||
from api.db import db_client
|
||||
from api.db.models import UserModel
|
||||
from api.enums import PostHogEvent
|
||||
from api.schemas.user_configuration import EffectiveAIModelConfiguration
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.auth.stack_auth import stackauth
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from api.services.mps_billing import ensure_hosted_mps_billing_account_v2
|
||||
from api.services.posthog_client import capture_event
|
||||
from api.utils.auth import decode_jwt_token
|
||||
|
||||
|
|
@ -110,6 +111,19 @@ async def get_user(
|
|||
# This prevents race conditions where multiple concurrent requests
|
||||
# might try to create configurations
|
||||
if org_was_created:
|
||||
try:
|
||||
await ensure_hosted_mps_billing_account_v2(
|
||||
organization.id,
|
||||
created_by=str(stack_user["id"]),
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to initialize hosted MPS billing account for "
|
||||
"organization {}",
|
||||
organization.id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
existing_cfg = await db_client.get_user_configurations(user_model.id)
|
||||
if not (existing_cfg.llm or existing_cfg.tts or existing_cfg.stt):
|
||||
mps_config = await create_user_configuration_with_mps_key(
|
||||
|
|
@ -232,7 +246,7 @@ async def create_user_configuration_with_mps_key(
|
|||
response = await client.post(
|
||||
f"{MPS_API_URL}/api/v1/service-keys/",
|
||||
json={
|
||||
"name": f"Default Dograh Model Service Key",
|
||||
"name": "Default Dograh Model Service Key",
|
||||
"description": "Auto-generated key for OSS user",
|
||||
"expires_in_days": 7, # Short-lived for OSS
|
||||
"created_by": user_provider_id,
|
||||
|
|
@ -250,7 +264,7 @@ async def create_user_configuration_with_mps_key(
|
|||
response = await client.post(
|
||||
f"{MPS_API_URL}/api/v1/service-keys/",
|
||||
json={
|
||||
"name": f"Default Dograh Model Service Key",
|
||||
"name": "Default Dograh Model Service Key",
|
||||
"description": f"Auto-generated key for organization {organization_id}",
|
||||
"organization_id": organization_id,
|
||||
"expires_in_days": 90, # Longer-lived for authenticated users
|
||||
|
|
@ -285,8 +299,8 @@ async def create_user_configuration_with_mps_key(
|
|||
"model": "default",
|
||||
},
|
||||
}
|
||||
user_config = EffectiveAIModelConfiguration(**configuration)
|
||||
return user_config
|
||||
effective_config = EffectiveAIModelConfiguration(**configuration)
|
||||
return effective_config
|
||||
else:
|
||||
logger.warning(
|
||||
f"Failed to get MPS service key: {response.status_code} - {response.text}"
|
||||
|
|
|
|||
|
|
@ -21,10 +21,10 @@ from api.schemas.ai_model_configuration import (
|
|||
BYOKPipelineAIModelConfiguration,
|
||||
BYOKRealtimeAIModelConfiguration,
|
||||
DograhManagedAIModelConfiguration,
|
||||
EffectiveAIModelConfiguration,
|
||||
OrganizationAIModelConfigurationV2,
|
||||
compile_ai_model_configuration_v2,
|
||||
)
|
||||
from api.schemas.user_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.configuration.masking import (
|
||||
SERVICE_SECRET_FIELDS,
|
||||
contains_masked_key,
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from groq import Groq
|
|||
# from pyneuphonic import Neuphonic
|
||||
# except ImportError:
|
||||
# Neuphonic = None
|
||||
from api.schemas.user_configuration import (
|
||||
from api.schemas.ai_model_configuration import (
|
||||
EffectiveAIModelConfiguration,
|
||||
)
|
||||
from api.services.configuration.registry import ServiceConfig, ServiceProviders
|
||||
|
|
@ -75,21 +75,21 @@ class UserConfigurationValidator:
|
|||
status_list = []
|
||||
|
||||
status_list.extend(self._validate_service(configuration.llm, "llm"))
|
||||
status_list.extend(self._validate_service(configuration.stt, "stt"))
|
||||
status_list.extend(self._validate_service(configuration.tts, "tts"))
|
||||
# Embeddings is optional - only validate if configured
|
||||
status_list.extend(
|
||||
self._validate_service(
|
||||
configuration.embeddings, "embeddings", required=False
|
||||
)
|
||||
)
|
||||
# Realtime is optional - only validate if is_realtime is enabled
|
||||
if configuration.is_realtime:
|
||||
status_list.extend(
|
||||
self._validate_service(
|
||||
configuration.realtime, "realtime", required=True
|
||||
)
|
||||
)
|
||||
else:
|
||||
status_list.extend(self._validate_service(configuration.stt, "stt"))
|
||||
status_list.extend(self._validate_service(configuration.tts, "tts"))
|
||||
# Embeddings is optional - only validate if configured
|
||||
status_list.extend(
|
||||
self._validate_service(
|
||||
configuration.embeddings, "embeddings", required=False
|
||||
)
|
||||
)
|
||||
|
||||
if status_list:
|
||||
raise ValueError(status_list)
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ The rules are simple:
|
|||
import copy
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from api.schemas.user_configuration import EffectiveAIModelConfiguration
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.configuration.registry import ServiceConfig
|
||||
from api.services.integrations import get_node_secret_fields
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ stored, while honouring masked API keys.
|
|||
import copy
|
||||
from typing import Dict
|
||||
|
||||
from api.schemas.user_configuration import EffectiveAIModelConfiguration
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.configuration.masking import (
|
||||
MODEL_OVERRIDE_FIELDS,
|
||||
SERVICE_SECRET_FIELDS,
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from __future__ import annotations
|
|||
|
||||
import copy
|
||||
|
||||
from api.schemas.user_configuration import EffectiveAIModelConfiguration
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.configuration.registry import (
|
||||
REGISTRY,
|
||||
ServiceType,
|
||||
|
|
|
|||
|
|
@ -38,6 +38,7 @@ class OpenAIEmbeddingService(BaseEmbeddingService):
|
|||
api_key: Optional[str] = None,
|
||||
model_id: str = DEFAULT_MODEL_ID,
|
||||
base_url: Optional[str] = None,
|
||||
default_headers: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
"""Initialize the OpenAI embedding service.
|
||||
|
||||
|
|
@ -60,6 +61,8 @@ class OpenAIEmbeddingService(BaseEmbeddingService):
|
|||
field_name="base_url",
|
||||
)
|
||||
client_kwargs["base_url"] = base_url
|
||||
if default_headers:
|
||||
client_kwargs["default_headers"] = default_headers
|
||||
self.client = AsyncOpenAI(**client_kwargs)
|
||||
logger.info(f"OpenAI embedding service initialized with model: {model_id}")
|
||||
else:
|
||||
|
|
|
|||
98
api/services/managed_model_services.py
Normal file
98
api/services/managed_model_services.py
Normal file
|
|
@ -0,0 +1,98 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from api.services.mps_service_key_client import mps_service_key_client
|
||||
|
||||
MPS_CORRELATION_ID_CONTEXT_KEY = "mps_correlation_id"
|
||||
|
||||
|
||||
def uses_managed_model_services_v2(
|
||||
ai_model_config: EffectiveAIModelConfiguration | None,
|
||||
) -> bool:
|
||||
if (
|
||||
ai_model_config is None
|
||||
or getattr(ai_model_config, "managed_service_version", None) != 2
|
||||
):
|
||||
return False
|
||||
|
||||
return any(
|
||||
_is_dograh_service(getattr(ai_model_config, section_name, None))
|
||||
for section_name in ("llm", "tts", "stt", "embeddings")
|
||||
)
|
||||
|
||||
|
||||
def get_mps_correlation_id(initial_context: dict[str, Any] | None) -> str | None:
|
||||
if not initial_context:
|
||||
return None
|
||||
correlation_id = initial_context.get(MPS_CORRELATION_ID_CONTEXT_KEY)
|
||||
if correlation_id is None:
|
||||
return None
|
||||
return str(correlation_id)
|
||||
|
||||
|
||||
async def ensure_mps_correlation_id(
|
||||
*,
|
||||
ai_model_config: EffectiveAIModelConfiguration,
|
||||
workflow_run_id: int,
|
||||
initial_context: dict[str, Any] | None,
|
||||
) -> str | None:
|
||||
existing = get_mps_correlation_id(initial_context)
|
||||
if existing:
|
||||
return existing
|
||||
|
||||
if not uses_managed_model_services_v2(ai_model_config):
|
||||
return None
|
||||
|
||||
service_key = _get_dograh_service_api_key(ai_model_config)
|
||||
if not service_key:
|
||||
raise ValueError(
|
||||
"Managed model services v2 requires a Dograh service key before the run starts."
|
||||
)
|
||||
|
||||
response = await mps_service_key_client.create_correlation_id(
|
||||
service_key=service_key,
|
||||
workflow_run_id=workflow_run_id,
|
||||
)
|
||||
correlation_id = response.get("correlation_id")
|
||||
if not correlation_id:
|
||||
raise ValueError("MPS correlation-id response did not include correlation_id")
|
||||
|
||||
correlation_id = str(correlation_id)
|
||||
logger.info(
|
||||
"Minted MPS correlation id {} for workflow run {}",
|
||||
correlation_id,
|
||||
workflow_run_id,
|
||||
)
|
||||
return correlation_id
|
||||
|
||||
|
||||
def _is_dograh_service(service: Any) -> bool:
|
||||
provider = getattr(service, "provider", None)
|
||||
return (
|
||||
provider == ServiceProviders.DOGRAH or provider == ServiceProviders.DOGRAH.value
|
||||
)
|
||||
|
||||
|
||||
def _get_dograh_service_api_key(
|
||||
ai_model_config: EffectiveAIModelConfiguration,
|
||||
) -> str | None:
|
||||
for section_name in ("llm", "tts", "stt", "embeddings"):
|
||||
service = getattr(ai_model_config, section_name, None)
|
||||
if not _is_dograh_service(service):
|
||||
continue
|
||||
|
||||
if hasattr(service, "get_all_api_keys"):
|
||||
keys = service.get_all_api_keys()
|
||||
if keys:
|
||||
return keys[0]
|
||||
|
||||
api_key = getattr(service, "api_key", None)
|
||||
if isinstance(api_key, str) and api_key:
|
||||
return api_key
|
||||
|
||||
return None
|
||||
23
api/services/mps_billing.py
Normal file
23
api/services/mps_billing.py
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
from typing import Optional
|
||||
|
||||
from api.constants import DEPLOYMENT_MODE
|
||||
from api.services.mps_service_key_client import mps_service_key_client
|
||||
|
||||
|
||||
async def ensure_hosted_mps_billing_account_v2(
|
||||
organization_id: int,
|
||||
*,
|
||||
created_by: Optional[str] = None,
|
||||
) -> Optional[dict]:
|
||||
"""Ensure hosted orgs have an MPS billing v2 account.
|
||||
|
||||
OSS deployments use legacy per-key quota accounting and do not create MPS
|
||||
billing accounts.
|
||||
"""
|
||||
if DEPLOYMENT_MODE == "oss":
|
||||
return None
|
||||
|
||||
return await mps_service_key_client.ensure_billing_account_v2(
|
||||
organization_id=organization_id,
|
||||
created_by=created_by,
|
||||
)
|
||||
|
|
@ -4,6 +4,7 @@ This client communicates with the Model Proxy Service (MPS) for service key mana
|
|||
Service keys are stored and managed entirely in MPS, not in the local database.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import List, Optional
|
||||
|
||||
import httpx
|
||||
|
|
@ -353,6 +354,234 @@ class MPSServiceKeyClient:
|
|||
response=response,
|
||||
)
|
||||
|
||||
async def create_credit_purchase_url(
|
||||
self,
|
||||
organization_id: int,
|
||||
created_by: Optional[str] = None,
|
||||
return_url: Optional[str] = None,
|
||||
billing_details: Optional[dict] = None,
|
||||
) -> dict:
|
||||
"""Create a short-lived MPS checkout URL for adding organization credits."""
|
||||
payload = {
|
||||
"created_by": created_by,
|
||||
"return_url": return_url,
|
||||
"billing_details": billing_details or {},
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
response = await client.post(
|
||||
f"{self.base_url}/api/v1/billing/accounts/{organization_id}/checkout-sessions",
|
||||
json=payload,
|
||||
headers=self._get_headers(
|
||||
organization_id=organization_id,
|
||||
created_by=created_by,
|
||||
),
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
|
||||
logger.error(
|
||||
"Failed to create MPS credit purchase URL: "
|
||||
f"{response.status_code} - {response.text}"
|
||||
)
|
||||
raise httpx.HTTPStatusError(
|
||||
f"Failed to create MPS credit purchase URL: {response.text}",
|
||||
request=response.request,
|
||||
response=response,
|
||||
)
|
||||
|
||||
async def get_credit_ledger(
|
||||
self,
|
||||
organization_id: int,
|
||||
page: int = 1,
|
||||
limit: int = 50,
|
||||
created_by: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""Get the MPS v2 billing account balance and recent credit ledger."""
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
response = await client.get(
|
||||
f"{self.base_url}/api/v1/billing/accounts/{organization_id}/ledger",
|
||||
params={"page": page, "limit": limit},
|
||||
headers=self._get_headers(
|
||||
organization_id=organization_id,
|
||||
created_by=created_by,
|
||||
),
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
|
||||
logger.error(
|
||||
"Failed to get MPS credit ledger: "
|
||||
f"{response.status_code} - {response.text}"
|
||||
)
|
||||
raise httpx.HTTPStatusError(
|
||||
f"Failed to get MPS credit ledger: {response.text}",
|
||||
request=response.request,
|
||||
response=response,
|
||||
)
|
||||
|
||||
async def get_billing_account_status(
|
||||
self,
|
||||
organization_id: int,
|
||||
created_by: Optional[str] = None,
|
||||
) -> Optional[dict]:
|
||||
"""Get an existing MPS v2 billing account without creating one."""
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
response = await client.get(
|
||||
f"{self.base_url}/api/v1/billing/accounts/{organization_id}/status",
|
||||
headers=self._get_headers(
|
||||
organization_id=organization_id,
|
||||
created_by=created_by,
|
||||
),
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
|
||||
logger.error(
|
||||
"Failed to get MPS billing account status: "
|
||||
f"{response.status_code} - {response.text}"
|
||||
)
|
||||
raise httpx.HTTPStatusError(
|
||||
f"Failed to get MPS billing account status: {response.text}",
|
||||
request=response.request,
|
||||
response=response,
|
||||
)
|
||||
|
||||
async def ensure_billing_account_v2(
|
||||
self,
|
||||
organization_id: int,
|
||||
created_by: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""Create or return the MPS v2 billing account for an organization."""
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
response = await client.get(
|
||||
f"{self.base_url}/api/v1/billing/accounts/{organization_id}/balance",
|
||||
headers=self._get_headers(
|
||||
organization_id=organization_id,
|
||||
created_by=created_by,
|
||||
),
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
|
||||
logger.error(
|
||||
"Failed to ensure MPS billing account v2: "
|
||||
f"{response.status_code} - {response.text}"
|
||||
)
|
||||
raise httpx.HTTPStatusError(
|
||||
f"Failed to ensure MPS billing account v2: {response.text}",
|
||||
request=response.request,
|
||||
response=response,
|
||||
)
|
||||
|
||||
async def create_correlation_id(
|
||||
self,
|
||||
*,
|
||||
service_key: str,
|
||||
workflow_run_id: int | None = None,
|
||||
) -> dict:
|
||||
"""Mint a server-generated correlation ID for managed model services."""
|
||||
payload: dict[str, int] = {}
|
||||
if workflow_run_id is not None:
|
||||
payload["workflow_run_id"] = workflow_run_id
|
||||
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
response = await client.post(
|
||||
f"{self.base_url}/api/v1/service-keys/correlation-id/self",
|
||||
json=payload,
|
||||
headers={
|
||||
"Authorization": f"Bearer {service_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
|
||||
logger.error(
|
||||
"Failed to create correlation ID: "
|
||||
f"{response.status_code} - {response.text}"
|
||||
)
|
||||
raise httpx.HTTPStatusError(
|
||||
f"Failed to create correlation ID: {response.text}",
|
||||
request=response.request,
|
||||
response=response,
|
||||
)
|
||||
|
||||
async def report_platform_usage(
|
||||
self,
|
||||
*,
|
||||
organization_id: int,
|
||||
correlation_id: Optional[str] = None,
|
||||
duration_seconds: Optional[float] = None,
|
||||
workflow_run_id: int | None = None,
|
||||
metadata: Optional[dict] = None,
|
||||
max_attempts: int = 3,
|
||||
) -> dict:
|
||||
"""Report hosted Dograh platform usage for a completed workflow run."""
|
||||
if DEPLOYMENT_MODE == "oss":
|
||||
raise ValueError("OSS deployments must not report platform usage to MPS")
|
||||
if not correlation_id and duration_seconds is None:
|
||||
raise ValueError(
|
||||
"Platform usage reports require correlation_id or duration_seconds"
|
||||
)
|
||||
|
||||
payload: dict = {
|
||||
"metadata": metadata or {},
|
||||
}
|
||||
if correlation_id:
|
||||
payload["correlation_id"] = correlation_id
|
||||
if duration_seconds is not None:
|
||||
payload["duration_seconds"] = duration_seconds
|
||||
if workflow_run_id is not None:
|
||||
payload["workflow_run_id"] = workflow_run_id
|
||||
|
||||
max_attempts = max(1, max_attempts)
|
||||
last_response: httpx.Response | None = None
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
for attempt in range(1, max_attempts + 1):
|
||||
response = await client.post(
|
||||
(
|
||||
f"{self.base_url}/api/v1/billing/accounts/"
|
||||
f"{organization_id}/platform-usage"
|
||||
),
|
||||
json=payload,
|
||||
headers=self._get_headers(organization_id=organization_id),
|
||||
)
|
||||
last_response = response
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
|
||||
should_retry = (
|
||||
response.status_code == 409
|
||||
and "usage_not_ready" in response.text
|
||||
and attempt < max_attempts
|
||||
)
|
||||
if should_retry:
|
||||
await asyncio.sleep(attempt)
|
||||
continue
|
||||
|
||||
logger.error(
|
||||
"Failed to report platform usage: "
|
||||
f"{response.status_code} - {response.text}"
|
||||
)
|
||||
raise httpx.HTTPStatusError(
|
||||
f"Failed to report platform usage: {response.text}",
|
||||
request=response.request,
|
||||
response=response,
|
||||
)
|
||||
|
||||
raise httpx.HTTPStatusError(
|
||||
"Failed to report platform usage",
|
||||
request=last_response.request,
|
||||
response=last_response,
|
||||
)
|
||||
|
||||
async def transcribe_audio(
|
||||
self,
|
||||
audio_data: bytes,
|
||||
|
|
|
|||
50
api/services/organization_context.py
Normal file
50
api/services/organization_context.py
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from api.db import db_client
|
||||
from api.db.models import UserModel
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
get_resolved_ai_model_configuration,
|
||||
)
|
||||
|
||||
|
||||
class OrganizationModelServicesContext(BaseModel):
|
||||
config_source: Literal["organization_v2", "legacy_user_v1", "empty"]
|
||||
has_model_configuration_v2: bool
|
||||
managed_service_version: Optional[int] = None
|
||||
uses_managed_service_v2: bool
|
||||
|
||||
|
||||
class OrganizationContextResponse(BaseModel):
|
||||
organization_id: Optional[int] = None
|
||||
organization_provider_id: Optional[str] = None
|
||||
model_services: OrganizationModelServicesContext
|
||||
|
||||
|
||||
async def get_organization_context(user: UserModel) -> OrganizationContextResponse:
|
||||
organization_id = user.selected_organization_id
|
||||
organization = (
|
||||
await db_client.get_organization_by_id(organization_id)
|
||||
if organization_id
|
||||
else None
|
||||
)
|
||||
|
||||
resolved = await get_resolved_ai_model_configuration(
|
||||
user_id=user.id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
managed_service_version = resolved.effective.managed_service_version
|
||||
|
||||
return OrganizationContextResponse(
|
||||
organization_id=organization_id,
|
||||
organization_provider_id=organization.provider_id if organization else None,
|
||||
model_services=OrganizationModelServicesContext(
|
||||
config_source=resolved.source,
|
||||
has_model_configuration_v2=resolved.source == "organization_v2",
|
||||
managed_service_version=managed_service_version,
|
||||
uses_managed_service_v2=(
|
||||
resolved.source == "organization_v2" and managed_service_version == 2
|
||||
),
|
||||
),
|
||||
)
|
||||
|
|
@ -162,15 +162,13 @@ async def run_pipeline_telephony(
|
|||
workflow_id: Workflow being executed.
|
||||
workflow_run_id: Workflow run row.
|
||||
user_id: Owner of the workflow.
|
||||
call_id: Provider call identifier (stored in cost_info for billing).
|
||||
call_id: Provider call identifier.
|
||||
transport_kwargs: Provider-specific kwargs forwarded to the transport
|
||||
factory (e.g. stream_sid + call_sid for Twilio).
|
||||
"""
|
||||
logger.debug(f"Running {provider_name} pipeline for workflow_run {workflow_run_id}")
|
||||
set_current_run_id(workflow_run_id)
|
||||
|
||||
await db_client.update_workflow_run(workflow_run_id, cost_info={"call_id": call_id})
|
||||
|
||||
workflow = await db_client.get_workflow(workflow_id, user_id)
|
||||
if workflow:
|
||||
set_current_org_id(workflow.organization_id)
|
||||
|
|
@ -340,7 +338,7 @@ async def _run_pipeline(
|
|||
if workflow_run.is_completed:
|
||||
raise HTTPException(status_code=400, detail="Workflow run already completed")
|
||||
|
||||
merged_call_context_vars = workflow_run.initial_context
|
||||
merged_call_context_vars = dict(workflow_run.initial_context or {})
|
||||
# If there is some extra call_context_vars, fold them in. Persistence
|
||||
# happens once below, after runtime_configuration is also resolved.
|
||||
if call_context_vars:
|
||||
|
|
@ -398,6 +396,19 @@ async def _run_pipeline(
|
|||
else:
|
||||
user_config = resolved_user_config
|
||||
|
||||
from api.services.managed_model_services import (
|
||||
MPS_CORRELATION_ID_CONTEXT_KEY,
|
||||
ensure_mps_correlation_id,
|
||||
)
|
||||
|
||||
mps_correlation_id = await ensure_mps_correlation_id(
|
||||
ai_model_config=user_config,
|
||||
workflow_run_id=workflow_run_id,
|
||||
initial_context=merged_call_context_vars,
|
||||
)
|
||||
if mps_correlation_id:
|
||||
merged_call_context_vars[MPS_CORRELATION_ID_CONTEXT_KEY] = mps_correlation_id
|
||||
|
||||
# Detect realtime mode (speech-to-speech services like OpenAI Realtime, Gemini Live)
|
||||
is_realtime = user_config.is_realtime and user_config.realtime is not None
|
||||
|
||||
|
|
@ -409,11 +420,23 @@ async def _run_pipeline(
|
|||
# Realtime services don't implement run_inference, so create a
|
||||
# separate text LLM for variable extraction and other out-of-band
|
||||
# inference calls.
|
||||
inference_llm = create_llm_service(user_config)
|
||||
inference_llm = create_llm_service(
|
||||
user_config,
|
||||
correlation_id=mps_correlation_id,
|
||||
)
|
||||
else:
|
||||
stt = create_stt_service(user_config, audio_config, keyterms=keyterms)
|
||||
tts = create_tts_service(user_config, audio_config)
|
||||
llm = create_llm_service(user_config)
|
||||
stt = create_stt_service(
|
||||
user_config,
|
||||
audio_config,
|
||||
keyterms=keyterms,
|
||||
correlation_id=mps_correlation_id,
|
||||
)
|
||||
tts = create_tts_service(
|
||||
user_config,
|
||||
audio_config,
|
||||
correlation_id=mps_correlation_id,
|
||||
)
|
||||
llm = create_llm_service(user_config, correlation_id=mps_correlation_id)
|
||||
inference_llm = None
|
||||
|
||||
# Stamp the providers/models actually resolved for this run onto
|
||||
|
|
@ -695,7 +718,10 @@ async def _run_pipeline(
|
|||
# Create a separate LLM instance for the voicemail sub-pipeline
|
||||
# (can't share with main pipeline as it would mess up frame linking)
|
||||
if voicemail_config.get("use_workflow_llm", True):
|
||||
voicemail_llm = create_llm_service(user_config)
|
||||
voicemail_llm = create_llm_service(
|
||||
user_config,
|
||||
correlation_id=mps_correlation_id,
|
||||
)
|
||||
else:
|
||||
voicemail_llm = create_llm_service_from_provider(
|
||||
provider=voicemail_config.get("provider", "openai"),
|
||||
|
|
|
|||
|
|
@ -78,7 +78,10 @@ def _validate_runtime_service_url(url: str, field_name: str) -> None:
|
|||
|
||||
|
||||
def create_stt_service(
|
||||
user_config, audio_config: "AudioConfig", keyterms: list[str] | None = None
|
||||
user_config,
|
||||
audio_config: "AudioConfig",
|
||||
keyterms: list[str] | None = None,
|
||||
correlation_id: str | None = None,
|
||||
):
|
||||
"""Create and return appropriate STT service based on user configuration
|
||||
|
||||
|
|
@ -160,6 +163,7 @@ def create_stt_service(
|
|||
return DograhSTTService(
|
||||
base_url=base_url,
|
||||
api_key=user_config.stt.api_key,
|
||||
correlation_id=correlation_id,
|
||||
settings=DograhSTTSettings(
|
||||
model=user_config.stt.model,
|
||||
language=language,
|
||||
|
|
@ -286,7 +290,9 @@ def create_stt_service(
|
|||
)
|
||||
|
||||
|
||||
def create_tts_service(user_config, audio_config: "AudioConfig"):
|
||||
def create_tts_service(
|
||||
user_config, audio_config: "AudioConfig", correlation_id: str | None = None
|
||||
):
|
||||
"""Create and return appropriate TTS service based on user configuration
|
||||
|
||||
Args:
|
||||
|
|
@ -404,6 +410,7 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
|
|||
return DograhTTSService(
|
||||
base_url=base_url,
|
||||
api_key=user_config.tts.api_key,
|
||||
correlation_id=correlation_id,
|
||||
settings=DograhTTSSettings(
|
||||
model=user_config.tts.model,
|
||||
voice=user_config.tts.voice,
|
||||
|
|
@ -564,6 +571,7 @@ def create_llm_service_from_provider(
|
|||
model: str,
|
||||
api_key: str | None,
|
||||
*,
|
||||
correlation_id: str | None = None,
|
||||
base_url: str | None = None,
|
||||
endpoint: str | None = None,
|
||||
aws_access_key: str | None = None,
|
||||
|
|
@ -637,6 +645,7 @@ def create_llm_service_from_provider(
|
|||
return DograhLLMService(
|
||||
base_url=f"{MPS_API_URL}/api/v1/llm",
|
||||
api_key=api_key,
|
||||
correlation_id=correlation_id,
|
||||
settings=OpenAILLMSettings(model=model),
|
||||
)
|
||||
elif provider == ServiceProviders.AWS_BEDROCK.value:
|
||||
|
|
@ -851,7 +860,7 @@ def create_realtime_llm_service(user_config, audio_config: "AudioConfig"):
|
|||
)
|
||||
|
||||
|
||||
def create_llm_service(user_config):
|
||||
def create_llm_service(user_config, correlation_id: str | None = None):
|
||||
"""Create and return appropriate LLM service based on user configuration."""
|
||||
provider = user_config.llm.provider
|
||||
model = user_config.llm.model
|
||||
|
|
@ -880,4 +889,10 @@ def create_llm_service(user_config):
|
|||
elif provider == ServiceProviders.SARVAM.value:
|
||||
kwargs["temperature"] = user_config.llm.temperature
|
||||
|
||||
return create_llm_service_from_provider(provider, model, api_key, **kwargs)
|
||||
return create_llm_service_from_provider(
|
||||
provider,
|
||||
model,
|
||||
api_key,
|
||||
correlation_id=correlation_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,76 +0,0 @@
|
|||
# Pricing Module
|
||||
|
||||
This module contains pricing models and registries for different AI services used in workflow cost calculations.
|
||||
|
||||
## Structure
|
||||
|
||||
```
|
||||
pricing/
|
||||
├── __init__.py # Main module exports
|
||||
├── models.py # Base pricing model classes
|
||||
├── llm.py # LLM pricing configurations
|
||||
├── tts.py # TTS pricing configurations
|
||||
├── stt.py # STT pricing configurations
|
||||
├── registry.py # Combined pricing registry
|
||||
└── README.md # This file
|
||||
```
|
||||
|
||||
## Pricing Models
|
||||
|
||||
### TokenPricingModel
|
||||
Used for LLM services that charge based on tokens:
|
||||
- `prompt_token_price`: Cost per prompt token
|
||||
- `completion_token_price`: Cost per completion token
|
||||
- `cache_read_discount`: Discount for cache read tokens (default 50%)
|
||||
- `cache_creation_multiplier`: Premium for cache creation tokens (default 25%)
|
||||
|
||||
### CharacterPricingModel
|
||||
Used for TTS services that charge based on character count:
|
||||
- `character_price`: Cost per character
|
||||
|
||||
### TimePricingModel
|
||||
Used for STT services that charge based on time:
|
||||
- `second_price`: Cost per second
|
||||
|
||||
## Adding New Pricing
|
||||
|
||||
### Adding a New LLM Model
|
||||
Edit `llm.py` and add the model to the appropriate provider:
|
||||
|
||||
```python
|
||||
ServiceProviders.OPENAI: {
|
||||
"new-model": TokenPricingModel(
|
||||
prompt_token_price=Decimal("2.00") / 1000000,
|
||||
completion_token_price=Decimal("8.00") / 1000000,
|
||||
),
|
||||
# ... existing models
|
||||
}
|
||||
```
|
||||
|
||||
### Adding a New Provider
|
||||
1. Add pricing configurations to the appropriate service file (llm.py, tts.py, stt.py)
|
||||
2. The registry will automatically include them
|
||||
|
||||
### Adding a New Service Type
|
||||
1. Create a new pricing file (e.g., `image.py`)
|
||||
2. Define the pricing models
|
||||
3. Import and add to `registry.py`
|
||||
|
||||
## Usage
|
||||
|
||||
The pricing registry is automatically imported and used by the cost calculator:
|
||||
|
||||
```python
|
||||
from api.services.pricing import PRICING_REGISTRY
|
||||
from api.services.workflow.cost_calculator import cost_calculator
|
||||
|
||||
# The cost calculator uses the pricing registry automatically
|
||||
result = cost_calculator.calculate_total_cost(usage_info)
|
||||
```
|
||||
|
||||
## Maintenance
|
||||
|
||||
- Update pricing when providers change their rates
|
||||
- All prices should use `Decimal` for precision
|
||||
- Include comments with current pricing from provider documentation
|
||||
- Test changes with existing test suite
|
||||
|
|
@ -1,9 +0,0 @@
|
|||
"""
|
||||
Pricing module for workflow cost calculation.
|
||||
|
||||
This module contains pricing models and registries for different AI services.
|
||||
"""
|
||||
|
||||
from .registry import PRICING_REGISTRY
|
||||
|
||||
__all__ = ["PRICING_REGISTRY"]
|
||||
|
|
@ -1,228 +0,0 @@
|
|||
"""
|
||||
Cost Calculator for Workflow Runs
|
||||
|
||||
This module provides a comprehensive cost calculation system for workflow runs based on usage metrics
|
||||
from different AI service providers (OpenAI, Groq, Deepgram, etc.).
|
||||
|
||||
Features:
|
||||
- Token-based pricing for LLM services with cache optimization support
|
||||
- Character-based pricing for TTS services
|
||||
- Time-based pricing for STT services
|
||||
- Configurable pricing models that can be updated
|
||||
- Support for multiple providers and models
|
||||
- Automatic provider inference from model names
|
||||
- JSON serialization support for database storage
|
||||
|
||||
Usage:
|
||||
from api.tasks.cost_calculator import cost_calculator
|
||||
|
||||
usage_info = {
|
||||
"llm": {
|
||||
"processor_name|||gpt-4o": {
|
||||
"prompt_tokens": 1000,
|
||||
"completion_tokens": 500,
|
||||
"total_tokens": 1500,
|
||||
"cache_read_input_tokens": 0,
|
||||
"cache_creation_input_tokens": 0
|
||||
}
|
||||
},
|
||||
"tts": {
|
||||
"processor_name|||aura-2-helena-en": 2000 # character count
|
||||
}
|
||||
}
|
||||
|
||||
cost_breakdown = cost_calculator.calculate_total_cost(usage_info)
|
||||
print(f"Total cost: ${cost_breakdown['total']:.6f}")
|
||||
"""
|
||||
|
||||
from decimal import Decimal
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from api.services.pricing import PRICING_REGISTRY
|
||||
from api.services.pricing.models import (
|
||||
PricingModel,
|
||||
)
|
||||
|
||||
|
||||
class CostCalculator:
|
||||
"""Main cost calculator class"""
|
||||
|
||||
def __init__(self, pricing_registry: Dict = None):
|
||||
self.pricing_registry = pricing_registry or PRICING_REGISTRY
|
||||
|
||||
def get_pricing_model(
|
||||
self, service_type: str, provider: str, model: str
|
||||
) -> Optional[PricingModel]:
|
||||
"""Get pricing model for a specific service, provider, and model"""
|
||||
try:
|
||||
service_pricing = self.pricing_registry.get(service_type, {})
|
||||
|
||||
# Try to get pricing for the specific provider
|
||||
provider_pricing = service_pricing.get(provider, {})
|
||||
pricing_model = provider_pricing.get(model) or provider_pricing.get(
|
||||
"default"
|
||||
)
|
||||
|
||||
if pricing_model:
|
||||
return pricing_model
|
||||
|
||||
# If not found, try the "default" provider for this service type
|
||||
default_provider_pricing = service_pricing.get("default", {})
|
||||
return default_provider_pricing.get(model) or default_provider_pricing.get(
|
||||
"default"
|
||||
)
|
||||
|
||||
except (KeyError, AttributeError):
|
||||
return None
|
||||
|
||||
def calculate_llm_cost(
|
||||
self, provider: str, model: str, usage: Dict[str, int]
|
||||
) -> Decimal:
|
||||
"""Calculate cost for LLM usage"""
|
||||
pricing_model = self.get_pricing_model("llm", provider, model)
|
||||
if not pricing_model:
|
||||
return Decimal("0")
|
||||
return pricing_model.calculate_cost(usage)
|
||||
|
||||
def calculate_tts_cost(
|
||||
self, provider: str, model: str, character_count: int
|
||||
) -> Decimal:
|
||||
"""Calculate cost for TTS usage"""
|
||||
pricing_model = self.get_pricing_model("tts", provider, model)
|
||||
if not pricing_model:
|
||||
return Decimal("0")
|
||||
return pricing_model.calculate_cost(character_count)
|
||||
|
||||
def calculate_stt_cost(self, provider: str, model: str, seconds: float) -> Decimal:
|
||||
"""Calculate cost for STT usage"""
|
||||
pricing_model = self.get_pricing_model("stt", provider, model)
|
||||
if not pricing_model:
|
||||
return Decimal("0")
|
||||
return pricing_model.calculate_cost(seconds)
|
||||
|
||||
def calculate_total_cost(self, usage_info: Dict) -> Dict[str, Any]:
|
||||
llm_cost_total = Decimal("0")
|
||||
tts_cost_total = Decimal("0")
|
||||
stt_cost_total = Decimal("0")
|
||||
|
||||
# Calculate LLM costs
|
||||
llm_usage = usage_info.get("llm", {})
|
||||
for key, usage in llm_usage.items():
|
||||
processor, model = self._parse_key(key)
|
||||
# Try to determine provider from processor name or model
|
||||
provider = self._infer_provider_from_model(model, "llm")
|
||||
cost = self.calculate_llm_cost(provider, model, usage)
|
||||
llm_cost_total += cost
|
||||
|
||||
# Calculate TTS costs
|
||||
tts_usage = usage_info.get("tts", {})
|
||||
for key, character_count in tts_usage.items():
|
||||
processor, model = self._parse_key(key)
|
||||
# Handle the case where model is "None" - infer from processor
|
||||
if model.lower() in ["none", "null", ""]:
|
||||
provider = self._infer_provider_from_processor(processor, "tts")
|
||||
model = "default" # Use default model for the provider
|
||||
else:
|
||||
provider = self._infer_provider_from_model(model, "tts")
|
||||
cost = self.calculate_tts_cost(provider, model, character_count)
|
||||
tts_cost_total += cost
|
||||
|
||||
# Calculate STT costs from explicit stt usage
|
||||
stt_usage = usage_info.get("stt", {})
|
||||
for key, seconds in stt_usage.items():
|
||||
processor, model = self._parse_key(key)
|
||||
provider = self._infer_provider_from_model(model, "stt")
|
||||
cost = self.calculate_stt_cost(provider, model, seconds)
|
||||
stt_cost_total += cost
|
||||
|
||||
total_cost = llm_cost_total + tts_cost_total + stt_cost_total
|
||||
|
||||
return {
|
||||
"llm_cost": float(llm_cost_total),
|
||||
"tts_cost": float(tts_cost_total),
|
||||
"stt_cost": float(stt_cost_total),
|
||||
"total": float(total_cost),
|
||||
}
|
||||
|
||||
def _parse_key(self, key) -> Tuple[str, str]:
|
||||
"""Parse key which is in format 'processor|||model'"""
|
||||
if isinstance(key, str) and "|||" in key:
|
||||
parts = key.split("|||", 1)
|
||||
return parts[0], parts[1]
|
||||
else:
|
||||
# Fallback for backwards compatibility or malformed keys
|
||||
return str(key), "unknown"
|
||||
|
||||
def _infer_provider_from_model(self, model: str, service_type: str) -> str:
|
||||
"""Infer provider from model name"""
|
||||
if not model:
|
||||
return "unknown"
|
||||
|
||||
model_lower = model.lower()
|
||||
|
||||
# OpenAI models
|
||||
if any(keyword in model_lower for keyword in ["gpt", "whisper", "openai"]):
|
||||
return ServiceProviders.OPENAI
|
||||
|
||||
# Groq models
|
||||
if any(keyword in model_lower for keyword in ["groq"]):
|
||||
return ServiceProviders.GROQ
|
||||
|
||||
# Elevenlabs models
|
||||
if any(keyword in model_lower for keyword in ["eleven"]):
|
||||
return ServiceProviders.ELEVENLABS
|
||||
|
||||
# Deepgram models
|
||||
if any(
|
||||
keyword in model_lower
|
||||
for keyword in ["deepgram", "nova", "phonecall", "general"]
|
||||
):
|
||||
return ServiceProviders.DEEPGRAM
|
||||
|
||||
# Default to first available provider for the service type
|
||||
service_providers = self.pricing_registry.get(service_type, {})
|
||||
if service_providers:
|
||||
return list(service_providers.keys())[0]
|
||||
|
||||
return "unknown"
|
||||
|
||||
def _infer_provider_from_processor(self, processor: str, service_type: str) -> str:
|
||||
"""Infer provider from processor name"""
|
||||
if not processor:
|
||||
return "unknown"
|
||||
|
||||
processor_lower = processor.lower()
|
||||
|
||||
# OpenAI processors
|
||||
if any(keyword in processor_lower for keyword in ["openai", "gpt"]):
|
||||
return ServiceProviders.OPENAI
|
||||
|
||||
# Groq processors
|
||||
if any(keyword in processor_lower for keyword in ["groq"]):
|
||||
return ServiceProviders.GROQ
|
||||
|
||||
# Deepgram processors
|
||||
if any(keyword in processor_lower for keyword in ["deepgram"]):
|
||||
return ServiceProviders.DEEPGRAM
|
||||
|
||||
# Default to first available provider for the service type
|
||||
service_providers = self.pricing_registry.get(service_type, {})
|
||||
if service_providers:
|
||||
return list(service_providers.keys())[0]
|
||||
|
||||
return "unknown"
|
||||
|
||||
def update_pricing(
|
||||
self, service_type: str, provider: str, model: str, pricing_model: PricingModel
|
||||
):
|
||||
"""Update pricing for a specific service/provider/model combination"""
|
||||
if service_type not in self.pricing_registry:
|
||||
self.pricing_registry[service_type] = {}
|
||||
if provider not in self.pricing_registry[service_type]:
|
||||
self.pricing_registry[service_type][provider] = {}
|
||||
self.pricing_registry[service_type][provider][model] = pricing_model
|
||||
|
||||
|
||||
# Global cost calculator instance
|
||||
cost_calculator = CostCalculator()
|
||||
|
|
@ -1,44 +0,0 @@
|
|||
"""
|
||||
Embeddings pricing models for different providers.
|
||||
|
||||
Prices are per token for embedding models.
|
||||
"""
|
||||
|
||||
from decimal import Decimal
|
||||
from typing import Dict
|
||||
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
|
||||
from .models import PricingModel
|
||||
|
||||
|
||||
class EmbeddingPricingModel(PricingModel):
|
||||
"""Pricing model for token-based embedding services."""
|
||||
|
||||
def __init__(self, token_price: Decimal):
|
||||
"""Initialize with price per token.
|
||||
|
||||
Args:
|
||||
token_price: Cost per token for embedding
|
||||
"""
|
||||
self.token_price = token_price
|
||||
|
||||
def calculate_cost(self, token_count: int) -> Decimal:
|
||||
"""Calculate cost for embedding token usage."""
|
||||
return Decimal(token_count) * self.token_price
|
||||
|
||||
|
||||
# Embeddings pricing registry
|
||||
EMBEDDINGS_PRICING: Dict[str, Dict[str, EmbeddingPricingModel]] = {
|
||||
ServiceProviders.OPENAI: {
|
||||
"text-embedding-3-small": EmbeddingPricingModel(
|
||||
token_price=Decimal("0.02") / 1_000_000, # $0.02 per 1M tokens
|
||||
),
|
||||
"text-embedding-3-large": EmbeddingPricingModel(
|
||||
token_price=Decimal("0.13") / 1_000_000, # $0.13 per 1M tokens
|
||||
),
|
||||
"text-embedding-ada-002": EmbeddingPricingModel(
|
||||
token_price=Decimal("0.10") / 1_000_000, # $0.10 per 1M tokens (legacy)
|
||||
),
|
||||
},
|
||||
}
|
||||
|
|
@ -1,143 +0,0 @@
|
|||
"""
|
||||
LLM pricing models for different providers.
|
||||
|
||||
Prices are per 1000 tokens for most models, with some newer models priced per million tokens.
|
||||
"""
|
||||
|
||||
from decimal import Decimal
|
||||
from typing import Dict
|
||||
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
|
||||
from .models import TokenPricingModel
|
||||
|
||||
# LLM pricing registry
|
||||
LLM_PRICING: Dict[str, Dict[str, TokenPricingModel]] = {
|
||||
ServiceProviders.OPENAI: {
|
||||
"gpt-3.5-turbo": TokenPricingModel(
|
||||
prompt_token_price=Decimal("0.0015") / 1000, # $0.0015 per 1K tokens
|
||||
completion_token_price=Decimal("0.002") / 1000, # $0.002 per 1K tokens
|
||||
),
|
||||
"gpt-4": TokenPricingModel(
|
||||
prompt_token_price=Decimal("0.03") / 1000, # $0.03 per 1K tokens
|
||||
completion_token_price=Decimal("0.06") / 1000, # $0.06 per 1K tokens
|
||||
),
|
||||
"gpt-4.1": TokenPricingModel(
|
||||
prompt_token_price=Decimal("2.00") / 1000000, # $2.00 per 1M tokens
|
||||
completion_token_price=Decimal("8.00") / 1000000, # $8.00 per 1M tokens
|
||||
),
|
||||
"gpt-4.1-mini": TokenPricingModel(
|
||||
prompt_token_price=Decimal("0.40") / 1000000, # $0.40 per 1M tokens
|
||||
completion_token_price=Decimal("1.60") / 1000000, # $1.60 per 1M tokens
|
||||
),
|
||||
"gpt-4.1-nano": TokenPricingModel(
|
||||
prompt_token_price=Decimal("0.10") / 1000000, # $0.10 per 1M tokens
|
||||
completion_token_price=Decimal("0.40") / 1000000, # $0.40 per 1M tokens
|
||||
),
|
||||
"gpt-4.5-preview": TokenPricingModel(
|
||||
prompt_token_price=Decimal("75.00") / 1000000, # $75.00 per 1M tokens
|
||||
completion_token_price=Decimal("150.00") / 1000000, # $150.00 per 1M tokens
|
||||
),
|
||||
"gpt-4o": TokenPricingModel(
|
||||
prompt_token_price=Decimal("2.50") / 1000000, # $2.50 per 1M tokens - FIXED
|
||||
completion_token_price=Decimal("10.00")
|
||||
/ 1000000, # $10.00 per 1M tokens - FIXED
|
||||
),
|
||||
"gpt-4o-audio-preview": TokenPricingModel(
|
||||
prompt_token_price=Decimal("2.50") / 1000000, # $2.50 per 1M tokens
|
||||
completion_token_price=Decimal("10.00") / 1000000, # $10.00 per 1M tokens
|
||||
),
|
||||
"gpt-4o-realtime-preview": TokenPricingModel(
|
||||
prompt_token_price=Decimal("5.00") / 1000000, # $5.00 per 1M tokens
|
||||
completion_token_price=Decimal("20.00") / 1000000, # $20.00 per 1M tokens
|
||||
),
|
||||
"gpt-4o-mini": TokenPricingModel(
|
||||
prompt_token_price=Decimal("0.15") / 1000000, # $0.15 per 1M tokens
|
||||
completion_token_price=Decimal("0.60") / 1000000, # $0.60 per 1M tokens
|
||||
),
|
||||
"gpt-4o-mini-audio-preview": TokenPricingModel(
|
||||
prompt_token_price=Decimal("0.15") / 1000000, # $0.15 per 1M tokens
|
||||
completion_token_price=Decimal("0.60") / 1000000, # $0.60 per 1M tokens
|
||||
),
|
||||
"gpt-4o-mini-realtime-preview": TokenPricingModel(
|
||||
prompt_token_price=Decimal("0.60") / 1000000, # $0.60 per 1M tokens
|
||||
completion_token_price=Decimal("2.40") / 1000000, # $2.40 per 1M tokens
|
||||
),
|
||||
"gpt-4o-search-preview": TokenPricingModel(
|
||||
prompt_token_price=Decimal("2.50") / 1000000, # $2.50 per 1M tokens
|
||||
completion_token_price=Decimal("10.00") / 1000000, # $10.00 per 1M tokens
|
||||
),
|
||||
"gpt-4o-mini-search-preview": TokenPricingModel(
|
||||
prompt_token_price=Decimal("0.15") / 1000000, # $0.15 per 1M tokens
|
||||
completion_token_price=Decimal("0.60") / 1000000, # $0.60 per 1M tokens
|
||||
),
|
||||
"o1": TokenPricingModel(
|
||||
prompt_token_price=Decimal("15.00") / 1000000, # $15.00 per 1M tokens
|
||||
completion_token_price=Decimal("60.00") / 1000000, # $60.00 per 1M tokens
|
||||
),
|
||||
"o1-pro": TokenPricingModel(
|
||||
prompt_token_price=Decimal("150.00") / 1000000, # $150.00 per 1M tokens
|
||||
completion_token_price=Decimal("600.00") / 1000000, # $600.00 per 1M tokens
|
||||
),
|
||||
"o1-mini": TokenPricingModel(
|
||||
prompt_token_price=Decimal("1.10") / 1000000, # $1.10 per 1M tokens
|
||||
completion_token_price=Decimal("4.40") / 1000000, # $4.40 per 1M tokens
|
||||
),
|
||||
"o3": TokenPricingModel(
|
||||
prompt_token_price=Decimal("10.00") / 1000000, # $10.00 per 1M tokens
|
||||
completion_token_price=Decimal("40.00") / 1000000, # $40.00 per 1M tokens
|
||||
),
|
||||
"o3-mini": TokenPricingModel(
|
||||
prompt_token_price=Decimal("1.10") / 1000000, # $1.10 per 1M tokens
|
||||
completion_token_price=Decimal("4.40") / 1000000, # $4.40 per 1M tokens
|
||||
),
|
||||
"o4-mini": TokenPricingModel(
|
||||
prompt_token_price=Decimal("1.10") / 1000000, # $1.10 per 1M tokens
|
||||
completion_token_price=Decimal("4.40") / 1000000, # $4.40 per 1M tokens
|
||||
),
|
||||
"computer-use-preview": TokenPricingModel(
|
||||
prompt_token_price=Decimal("3.00") / 1000000, # $3.00 per 1M tokens
|
||||
completion_token_price=Decimal("12.00") / 1000000, # $12.00 per 1M tokens
|
||||
),
|
||||
"gpt-image-1": TokenPricingModel(
|
||||
prompt_token_price=Decimal("5.00") / 1000000, # $5.00 per 1M tokens
|
||||
completion_token_price=Decimal("0") / 1000000, # No output pricing shown
|
||||
),
|
||||
"codex-mini-latest": TokenPricingModel(
|
||||
prompt_token_price=Decimal("1.50") / 1000000, # $1.50 per 1M tokens
|
||||
completion_token_price=Decimal("6.00") / 1000000, # $6.00 per 1M tokens
|
||||
),
|
||||
# Transcription models
|
||||
"gpt-4o-transcribe": TokenPricingModel(
|
||||
prompt_token_price=Decimal("2.50") / 1000000, # $2.50 per 1M tokens
|
||||
completion_token_price=Decimal("10.00") / 1000000, # $10.00 per 1M tokens
|
||||
),
|
||||
"gpt-4o-mini-transcribe": TokenPricingModel(
|
||||
prompt_token_price=Decimal("1.25") / 1000000, # $1.25 per 1M tokens
|
||||
completion_token_price=Decimal("5.00") / 1000000, # $5.00 per 1M tokens
|
||||
),
|
||||
# TTS models with token-based pricing
|
||||
"gpt-4o-mini-tts": TokenPricingModel(
|
||||
prompt_token_price=Decimal("0.60") / 1000000, # $0.60 per 1M tokens
|
||||
completion_token_price=Decimal("0")
|
||||
/ 1000000, # No completion tokens for TTS
|
||||
),
|
||||
},
|
||||
ServiceProviders.GROQ: {
|
||||
"llama-3.3-70b-versatile": TokenPricingModel(
|
||||
prompt_token_price=Decimal("0.00059") / 1000, # $0.00059 per 1K tokens
|
||||
completion_token_price=Decimal("0.00079") / 1000, # $0.00079 per 1K tokens
|
||||
),
|
||||
"deepseek-r1-distill-llama-70b": TokenPricingModel(
|
||||
prompt_token_price=Decimal("0.00059") / 1000, # Assuming similar pricing
|
||||
completion_token_price=Decimal("0.00079") / 1000,
|
||||
),
|
||||
},
|
||||
ServiceProviders.AZURE: {
|
||||
"gpt-4.1-mini": TokenPricingModel(
|
||||
prompt_token_price=Decimal("0.44") / 1000000, # $0.40 per 1M tokens
|
||||
completion_token_price=Decimal("8.80")
|
||||
/ 1000000, # $1.60 per 1M tokens if using data zone
|
||||
)
|
||||
},
|
||||
}
|
||||
|
|
@ -1,89 +0,0 @@
|
|||
"""
|
||||
Base pricing models for different service types.
|
||||
"""
|
||||
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
class CostType(Enum):
|
||||
LLM_TOKENS = "llm_tokens"
|
||||
TTS_CHARACTERS = "tts_characters"
|
||||
STT_SECONDS = "stt_seconds"
|
||||
|
||||
|
||||
class PricingModel:
|
||||
"""Base class for pricing models"""
|
||||
|
||||
def calculate_cost(self, usage: Any) -> Decimal:
|
||||
"""Calculate cost based on usage"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class TokenPricingModel(PricingModel):
|
||||
"""Pricing model for token-based services (LLM)"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prompt_token_price: Decimal,
|
||||
completion_token_price: Decimal,
|
||||
cache_read_discount: Decimal = Decimal("0.5"), # 50% discount for cache reads
|
||||
cache_creation_multiplier: Decimal = Decimal(
|
||||
"1.25"
|
||||
), # 25% premium for cache creation
|
||||
):
|
||||
self.prompt_token_price = prompt_token_price
|
||||
self.completion_token_price = completion_token_price
|
||||
self.cache_read_discount = cache_read_discount
|
||||
self.cache_creation_multiplier = cache_creation_multiplier
|
||||
|
||||
def calculate_cost(self, usage: Dict[str, int]) -> Decimal:
|
||||
"""Calculate cost for LLM token usage"""
|
||||
prompt_tokens = usage.get("prompt_tokens", 0)
|
||||
completion_tokens = usage.get("completion_tokens", 0)
|
||||
cache_read_tokens = usage.get("cache_read_input_tokens") or 0
|
||||
cache_creation_tokens = usage.get("cache_creation_input_tokens") or 0
|
||||
|
||||
# Base cost
|
||||
prompt_cost = Decimal(prompt_tokens) * self.prompt_token_price
|
||||
completion_cost = Decimal(completion_tokens) * self.completion_token_price
|
||||
|
||||
# Cache adjustments
|
||||
cache_read_savings = (
|
||||
Decimal(cache_read_tokens)
|
||||
* self.prompt_token_price
|
||||
* self.cache_read_discount
|
||||
)
|
||||
cache_creation_premium = (
|
||||
Decimal(cache_creation_tokens)
|
||||
* self.prompt_token_price
|
||||
* (self.cache_creation_multiplier - 1)
|
||||
)
|
||||
|
||||
total_cost = (
|
||||
prompt_cost + completion_cost - cache_read_savings + cache_creation_premium
|
||||
)
|
||||
return max(total_cost, Decimal("0")) # Ensure non-negative
|
||||
|
||||
|
||||
class CharacterPricingModel(PricingModel):
|
||||
"""Pricing model for character-based services (TTS)"""
|
||||
|
||||
def __init__(self, character_price: Decimal):
|
||||
self.character_price = character_price
|
||||
|
||||
def calculate_cost(self, character_count: int) -> Decimal:
|
||||
"""Calculate cost for TTS character usage"""
|
||||
return Decimal(character_count) * self.character_price
|
||||
|
||||
|
||||
class TimePricingModel(PricingModel):
|
||||
"""Pricing model for time-based services (STT)"""
|
||||
|
||||
def __init__(self, second_price: Decimal):
|
||||
self.second_price = second_price
|
||||
|
||||
def calculate_cost(self, seconds: float) -> Decimal:
|
||||
"""Calculate cost for STT time usage"""
|
||||
return Decimal(str(seconds)) * self.second_price
|
||||
|
|
@ -1,18 +0,0 @@
|
|||
"""
|
||||
Main pricing registry that combines all service type pricing models.
|
||||
"""
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from .embeddings import EMBEDDINGS_PRICING
|
||||
from .llm import LLM_PRICING
|
||||
from .stt import STT_PRICING
|
||||
from .tts import TTS_PRICING
|
||||
|
||||
# Combined pricing registry for all service types
|
||||
PRICING_REGISTRY: Dict = {
|
||||
"llm": LLM_PRICING,
|
||||
"tts": TTS_PRICING,
|
||||
"stt": STT_PRICING,
|
||||
"embeddings": EMBEDDINGS_PRICING,
|
||||
}
|
||||
|
|
@ -1,13 +0,0 @@
|
|||
"""Format workflow run usage for public API responses."""
|
||||
|
||||
|
||||
def format_public_usage_info(usage_info: dict | None) -> dict | None:
|
||||
if not usage_info:
|
||||
return None
|
||||
|
||||
return {
|
||||
"llm": usage_info.get("llm") or {},
|
||||
"tts": usage_info.get("tts") or {},
|
||||
"stt": usage_info.get("stt") or {},
|
||||
"call_duration_seconds": usage_info.get("call_duration_seconds"),
|
||||
}
|
||||
|
|
@ -1,26 +0,0 @@
|
|||
"""
|
||||
STT (Speech-to-Text) pricing models for different providers.
|
||||
|
||||
Prices are per second for STT services.
|
||||
"""
|
||||
|
||||
from decimal import Decimal
|
||||
from typing import Dict
|
||||
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
|
||||
from .models import TimePricingModel
|
||||
|
||||
# STT pricing registry
|
||||
STT_PRICING: Dict[str, Dict[str, TimePricingModel]] = {
|
||||
ServiceProviders.DEEPGRAM: {
|
||||
"nova-3-general": TimePricingModel(Decimal("0.0077") / 60),
|
||||
"nova-2": TimePricingModel(Decimal("0.0058") / 60),
|
||||
"default": TimePricingModel(Decimal("0.0077") / 60),
|
||||
},
|
||||
ServiceProviders.OPENAI: {
|
||||
"gpt-4o-transcribe": TimePricingModel(Decimal("0.015") / 60),
|
||||
"default": TimePricingModel(Decimal("0.015") / 60),
|
||||
},
|
||||
"default": {"default": TimePricingModel(Decimal("0.0077") / 60)},
|
||||
}
|
||||
|
|
@ -1,30 +0,0 @@
|
|||
"""
|
||||
TTS (Text-to-Speech) pricing models for different providers.
|
||||
|
||||
Prices are per character for TTS services.
|
||||
"""
|
||||
|
||||
from decimal import Decimal
|
||||
from typing import Dict
|
||||
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
|
||||
from .models import CharacterPricingModel
|
||||
|
||||
# TTS pricing registry
|
||||
TTS_PRICING: Dict[str, Dict[str, CharacterPricingModel]] = {
|
||||
ServiceProviders.OPENAI: {
|
||||
"gpt-4o-mini-tts": CharacterPricingModel(Decimal("0.6") / 1_00_00_000),
|
||||
"default": CharacterPricingModel(Decimal("0.6") / 1_00_00_000),
|
||||
},
|
||||
ServiceProviders.DEEPGRAM: {
|
||||
"aura-2": CharacterPricingModel(Decimal("0.030") / 1_000),
|
||||
"aura-1": CharacterPricingModel(Decimal("0.015") / 1_000),
|
||||
"default": CharacterPricingModel(Decimal("0.030") / 1_000),
|
||||
},
|
||||
ServiceProviders.ELEVENLABS: {
|
||||
# 6400 usd per 250*1e6 characters
|
||||
"default": CharacterPricingModel(Decimal("0.0256") / 1_000)
|
||||
},
|
||||
"default": {"default": CharacterPricingModel(Decimal("0.030") / 1_000)},
|
||||
}
|
||||
|
|
@ -1,230 +0,0 @@
|
|||
from decimal import Decimal
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from api.db import db_client
|
||||
from api.enums import WorkflowRunMode
|
||||
from api.services.pricing.cost_calculator import cost_calculator
|
||||
from api.services.telephony.factory import get_telephony_provider_for_run
|
||||
|
||||
|
||||
async def _fetch_telephony_cost(workflow_run) -> dict | None:
|
||||
"""Fetch telephony call cost. Returns a dict with cost_usd and provider_name, or None."""
|
||||
if (
|
||||
workflow_run.mode
|
||||
not in [WorkflowRunMode.TWILIO.value, WorkflowRunMode.VONAGE.value]
|
||||
or not workflow_run.cost_info
|
||||
):
|
||||
return None
|
||||
|
||||
call_id = workflow_run.cost_info.get("call_id")
|
||||
if not call_id:
|
||||
logger.warning(f"call_id not found in cost_info")
|
||||
return None
|
||||
|
||||
provider_name = workflow_run.mode.lower() if workflow_run.mode else ""
|
||||
|
||||
workflow = await db_client.get_workflow_by_id(workflow_run.workflow_id)
|
||||
if not workflow:
|
||||
logger.warning("Workflow not found for workflow run")
|
||||
raise Exception("Workflow not found")
|
||||
|
||||
provider = await get_telephony_provider_for_run(
|
||||
workflow_run, workflow.organization_id
|
||||
)
|
||||
call_cost_info = await provider.get_call_cost(call_id)
|
||||
|
||||
if call_cost_info.get("status") == "error":
|
||||
logger.error(
|
||||
f"Failed to fetch {provider_name} call cost: {call_cost_info.get('error')}"
|
||||
)
|
||||
return None
|
||||
|
||||
cost_usd = call_cost_info.get("cost_usd", 0.0)
|
||||
logger.info(
|
||||
f"{provider_name.title()} call cost: ${cost_usd:.6f} USD for call {call_id}"
|
||||
)
|
||||
return {"cost_usd": cost_usd, "provider_name": provider_name}
|
||||
|
||||
|
||||
async def _update_organization_usage(
|
||||
org, dograh_tokens: float, duration_seconds: float, charge_usd: float | None
|
||||
) -> None:
|
||||
"""Update organization usage after a workflow run."""
|
||||
org_id = org.id
|
||||
await db_client.update_usage_after_run(
|
||||
org_id, dograh_tokens, duration_seconds, charge_usd
|
||||
)
|
||||
if charge_usd is not None:
|
||||
logger.info(
|
||||
f"Updated organization usage with ${charge_usd:.2f} USD ({dograh_tokens} Dograh Tokens) and {duration_seconds}s duration for org {org_id}"
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"Updated organization usage with {dograh_tokens} Dograh Tokens and {duration_seconds}s duration for org {org_id}"
|
||||
)
|
||||
|
||||
|
||||
async def _get_pricing_organization(workflow_run):
|
||||
workflow = getattr(workflow_run, "workflow", None)
|
||||
organization_id = getattr(workflow, "organization_id", None)
|
||||
if organization_id is None and workflow and workflow.user:
|
||||
organization_id = workflow.user.selected_organization_id
|
||||
if organization_id is None:
|
||||
return None
|
||||
return await db_client.get_organization_by_id(organization_id)
|
||||
|
||||
|
||||
async def _build_usage_cost_snapshot(
|
||||
usage_info: dict | None,
|
||||
*,
|
||||
workflow_run=None,
|
||||
include_telephony_cost: bool = False,
|
||||
organization=None,
|
||||
calculated_at: str | None = None,
|
||||
) -> dict | None:
|
||||
if not usage_info:
|
||||
logger.warning("No usage info available for workflow run")
|
||||
return None
|
||||
|
||||
cost_breakdown = cost_calculator.calculate_total_cost(usage_info)
|
||||
|
||||
if include_telephony_cost and workflow_run is not None:
|
||||
try:
|
||||
telephony_cost = await _fetch_telephony_cost(workflow_run)
|
||||
if telephony_cost:
|
||||
telephony_cost_usd = telephony_cost["cost_usd"]
|
||||
provider_name = telephony_cost["provider_name"]
|
||||
cost_breakdown["telephony_call"] = telephony_cost_usd
|
||||
cost_breakdown[f"{provider_name}_call"] = telephony_cost_usd
|
||||
cost_breakdown["total"] = (
|
||||
float(cost_breakdown["total"]) + telephony_cost_usd
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to fetch telephony call cost: {e}")
|
||||
# Don't fail the whole cost calculation if telephony API fails
|
||||
|
||||
total_cost_usd = Decimal(str(cost_breakdown["total"]))
|
||||
dograh_tokens = float(total_cost_usd * Decimal("100"))
|
||||
|
||||
if organization is None and workflow_run is not None:
|
||||
organization = await _get_pricing_organization(workflow_run)
|
||||
|
||||
charge_usd = None
|
||||
if organization and organization.price_per_second_usd:
|
||||
duration_seconds = usage_info.get("call_duration_seconds", 0)
|
||||
charge_usd = float(
|
||||
Decimal(str(duration_seconds))
|
||||
* Decimal(str(organization.price_per_second_usd))
|
||||
)
|
||||
|
||||
cost_info = {
|
||||
"cost_breakdown": cost_breakdown,
|
||||
"total_cost_usd": float(total_cost_usd),
|
||||
"dograh_token_usage": dograh_tokens,
|
||||
"calculated_at": calculated_at
|
||||
or (workflow_run.created_at.isoformat() if workflow_run is not None else None),
|
||||
"call_duration_seconds": usage_info.get("call_duration_seconds", 0),
|
||||
}
|
||||
|
||||
if charge_usd is not None:
|
||||
cost_info["charge_usd"] = charge_usd
|
||||
cost_info["price_per_second_usd"] = organization.price_per_second_usd
|
||||
|
||||
return cost_info
|
||||
|
||||
|
||||
async def build_workflow_run_cost_info(workflow_run) -> dict | None:
|
||||
cost_info = await _build_usage_cost_snapshot(
|
||||
workflow_run.usage_info,
|
||||
workflow_run=workflow_run,
|
||||
include_telephony_cost=True,
|
||||
calculated_at=workflow_run.created_at.isoformat(),
|
||||
)
|
||||
if cost_info is None:
|
||||
return None
|
||||
return {
|
||||
**(workflow_run.cost_info or {}),
|
||||
**cost_info,
|
||||
}
|
||||
|
||||
|
||||
async def save_workflow_run_cost_info(
|
||||
workflow_run_id: int, cost_info: dict | None
|
||||
) -> None:
|
||||
if cost_info is None:
|
||||
return
|
||||
await db_client.update_workflow_run(run_id=workflow_run_id, cost_info=cost_info)
|
||||
|
||||
|
||||
async def apply_workflow_run_usage_to_organization(
|
||||
workflow_run, cost_info: dict | None
|
||||
) -> None:
|
||||
if cost_info is None:
|
||||
return
|
||||
|
||||
org = await _get_pricing_organization(workflow_run)
|
||||
if not org:
|
||||
return
|
||||
|
||||
await _update_organization_usage(
|
||||
org,
|
||||
float(cost_info.get("dograh_token_usage") or 0),
|
||||
float(cost_info.get("call_duration_seconds") or 0),
|
||||
cost_info.get("charge_usd"),
|
||||
)
|
||||
|
||||
|
||||
async def apply_usage_delta_to_organization(
|
||||
workflow_run, usage_info: dict | None
|
||||
) -> dict | None:
|
||||
org = await _get_pricing_organization(workflow_run)
|
||||
if not org:
|
||||
return None
|
||||
|
||||
cost_info = await _build_usage_cost_snapshot(usage_info, organization=org)
|
||||
if cost_info is None:
|
||||
return None
|
||||
|
||||
await _update_organization_usage(
|
||||
org,
|
||||
float(cost_info.get("dograh_token_usage") or 0),
|
||||
float(cost_info.get("call_duration_seconds") or 0),
|
||||
cost_info.get("charge_usd"),
|
||||
)
|
||||
return cost_info
|
||||
|
||||
|
||||
async def calculate_workflow_run_cost(workflow_run_id: int):
|
||||
logger.debug("Calculating cost for workflow run")
|
||||
|
||||
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
|
||||
if not workflow_run:
|
||||
logger.warning("Workflow run not found")
|
||||
return
|
||||
|
||||
try:
|
||||
cost_info = await build_workflow_run_cost_info(workflow_run)
|
||||
if cost_info is None:
|
||||
return
|
||||
|
||||
await save_workflow_run_cost_info(workflow_run_id, cost_info)
|
||||
|
||||
try:
|
||||
await apply_workflow_run_usage_to_organization(workflow_run, cost_info)
|
||||
except Exception as e:
|
||||
org = await _get_pricing_organization(workflow_run)
|
||||
if org:
|
||||
logger.error(
|
||||
f"Failed to update organization usage for org {org.id}: {e}"
|
||||
)
|
||||
else:
|
||||
logger.error(f"Failed to update organization usage: {e}")
|
||||
# Don't fail the whole cost calculation if usage update fails
|
||||
|
||||
logger.info(
|
||||
f"Calculated cost for workflow run: ${cost_info['total_cost_usd']:.6f} USD ({cost_info['dograh_token_usage']} Dograh Tokens)"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating cost for workflow run: {e}")
|
||||
raise
|
||||
|
|
@ -53,7 +53,7 @@ def build_run_report_csv(runs: List[Any]) -> io.StringIO:
|
|||
for run in runs:
|
||||
initial = run.initial_context or {}
|
||||
gathered = run.gathered_context or {}
|
||||
cost = run.cost_info or {}
|
||||
usage = run.usage_info or {}
|
||||
|
||||
call_tags = gathered.get("call_tags", [])
|
||||
if isinstance(call_tags, list):
|
||||
|
|
@ -67,7 +67,7 @@ def build_run_report_csv(runs: List[Any]) -> io.StringIO:
|
|||
run.created_at.isoformat() if run.created_at else "",
|
||||
initial.get("phone_number", ""),
|
||||
gathered.get("mapped_call_disposition", ""),
|
||||
cost.get("call_duration_seconds", ""),
|
||||
usage.get("call_duration_seconds", ""),
|
||||
]
|
||||
|
||||
extracted = gathered.get("extracted_variables", {})
|
||||
|
|
|
|||
|
|
@ -66,34 +66,6 @@ async def handle_vonage_events(
|
|||
logger.error(f"[run {workflow_run_id}] Workflow run not found")
|
||||
return {"status": "error", "message": "Workflow run not found"}
|
||||
|
||||
# For a completed call that includes cost info, capture it immediately
|
||||
if event_data.get("status") == "completed":
|
||||
# Vonage sometimes includes price info in the webhook
|
||||
if "price" in event_data or "rate" in event_data:
|
||||
try:
|
||||
if workflow_run.cost_info:
|
||||
# Store immediate cost info if available
|
||||
cost_info = workflow_run.cost_info.copy()
|
||||
if "price" in event_data:
|
||||
cost_info["vonage_webhook_price"] = float(event_data["price"])
|
||||
if "rate" in event_data:
|
||||
cost_info["vonage_webhook_rate"] = float(event_data["rate"])
|
||||
if "duration" in event_data:
|
||||
cost_info["vonage_webhook_duration"] = int(
|
||||
event_data["duration"]
|
||||
)
|
||||
|
||||
await db_client.update_workflow_run(
|
||||
run_id=workflow_run_id, cost_info=cost_info
|
||||
)
|
||||
logger.info(
|
||||
f"[run {workflow_run_id}] Captured Vonage cost info from webhook"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[run {workflow_run_id}] Failed to capture Vonage cost from webhook: {e}"
|
||||
)
|
||||
|
||||
# Get workflow and provider
|
||||
workflow = await db_client.get_workflow_by_id(workflow_run.workflow_id)
|
||||
if not workflow:
|
||||
|
|
|
|||
|
|
@ -35,6 +35,7 @@ import asyncio
|
|||
|
||||
from loguru import logger
|
||||
|
||||
from api.services.managed_model_services import MPS_CORRELATION_ID_CONTEXT_KEY
|
||||
from api.services.workflow import pipecat_engine_callbacks as engine_callbacks
|
||||
from api.services.workflow.mcp_tool_session import McpToolSession
|
||||
from api.services.workflow.pipecat_engine_context_composer import (
|
||||
|
|
@ -382,6 +383,9 @@ class PipecatEngine:
|
|||
embeddings_provider=self._embeddings_provider,
|
||||
embeddings_endpoint=self._embeddings_endpoint,
|
||||
embeddings_api_version=self._embeddings_api_version,
|
||||
correlation_id=self._call_context_vars.get(
|
||||
MPS_CORRELATION_ID_CONTEXT_KEY
|
||||
),
|
||||
tracing_context=self._get_otel_context(),
|
||||
)
|
||||
|
||||
|
|
|
|||
41
api/services/workflow/run_usage_response.py
Normal file
41
api/services/workflow/run_usage_response.py
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
"""Format workflow run usage for public API responses."""
|
||||
|
||||
|
||||
def format_public_usage_info(usage_info: dict | None) -> dict | None:
|
||||
if not usage_info:
|
||||
return None
|
||||
|
||||
return {
|
||||
"llm": usage_info.get("llm") or {},
|
||||
"tts": usage_info.get("tts") or {},
|
||||
"stt": usage_info.get("stt") or {},
|
||||
"call_duration_seconds": usage_info.get("call_duration_seconds"),
|
||||
}
|
||||
|
||||
|
||||
def format_public_cost_info(
|
||||
cost_info: dict | None, usage_info: dict | None
|
||||
) -> dict | None:
|
||||
"""Return the legacy response shape without doing local cost accounting."""
|
||||
duration = None
|
||||
if usage_info and usage_info.get("call_duration_seconds") is not None:
|
||||
duration = int(round(usage_info.get("call_duration_seconds") or 0))
|
||||
elif cost_info and cost_info.get("call_duration_seconds") is not None:
|
||||
duration = int(round(cost_info.get("call_duration_seconds") or 0))
|
||||
|
||||
dograh_token_usage = 0
|
||||
if cost_info:
|
||||
if "dograh_token_usage" in cost_info:
|
||||
dograh_token_usage = cost_info.get("dograh_token_usage") or 0
|
||||
elif "total_cost_usd" in cost_info:
|
||||
dograh_token_usage = round(
|
||||
float(cost_info.get("total_cost_usd", 0)) * 100, 2
|
||||
)
|
||||
|
||||
if duration is None and dograh_token_usage == 0:
|
||||
return None
|
||||
|
||||
return {
|
||||
"dograh_token_usage": dograh_token_usage,
|
||||
"call_duration_seconds": duration,
|
||||
}
|
||||
|
|
@ -421,7 +421,19 @@ async def execute_text_chat_pending_turn(
|
|||
if user_config.llm is None:
|
||||
raise ValueError("Text chat requires an LLM configuration")
|
||||
|
||||
llm = create_llm_service(user_config)
|
||||
from api.services.managed_model_services import (
|
||||
MPS_CORRELATION_ID_CONTEXT_KEY,
|
||||
ensure_mps_correlation_id,
|
||||
)
|
||||
|
||||
base_initial_context = dict(workflow_run.initial_context or {})
|
||||
mps_correlation_id = await ensure_mps_correlation_id(
|
||||
ai_model_config=user_config,
|
||||
workflow_run_id=workflow_run_id,
|
||||
initial_context=base_initial_context,
|
||||
)
|
||||
|
||||
llm = create_llm_service(user_config, correlation_id=mps_correlation_id)
|
||||
inference_llm = llm
|
||||
|
||||
runtime_configuration = {
|
||||
|
|
@ -429,9 +441,15 @@ async def execute_text_chat_pending_turn(
|
|||
"llm_model": user_config.llm.model,
|
||||
}
|
||||
initial_context = {
|
||||
**(workflow_run.initial_context or {}),
|
||||
**base_initial_context,
|
||||
"runtime_configuration": runtime_configuration,
|
||||
}
|
||||
if mps_correlation_id:
|
||||
initial_context[MPS_CORRELATION_ID_CONTEXT_KEY] = mps_correlation_id
|
||||
await db_client.update_workflow_run(
|
||||
workflow_run_id,
|
||||
initial_context=initial_context,
|
||||
)
|
||||
|
||||
workflow_graph = WorkflowGraph(
|
||||
ReactFlowDTO.model_validate(run_definition.workflow_json)
|
||||
|
|
|
|||
|
|
@ -4,17 +4,11 @@ from datetime import UTC, datetime
|
|||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from api.db import db_client
|
||||
from api.db.models import WorkflowRunTextSessionModel
|
||||
from api.db.workflow_run_text_session_client import (
|
||||
WorkflowRunTextSessionRevisionConflictError,
|
||||
)
|
||||
from api.services.pricing.workflow_run_cost import (
|
||||
apply_usage_delta_to_organization,
|
||||
build_workflow_run_cost_info,
|
||||
)
|
||||
from api.services.workflow.text_chat_logs import (
|
||||
build_text_chat_realtime_feedback_events,
|
||||
)
|
||||
|
|
@ -261,20 +255,6 @@ async def execute_pending_text_chat_turn(
|
|||
state=execution.state,
|
||||
is_completed=execution.is_completed,
|
||||
)
|
||||
workflow_run = await db_client.get_workflow_run_by_id(run_id)
|
||||
if workflow_run:
|
||||
try:
|
||||
# Apply the per-turn delta so org usage tracks cumulative run cost
|
||||
# without replaying the full session totals on every turn.
|
||||
await apply_usage_delta_to_organization(workflow_run, execution.usage)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to update organization usage for text chat run {run_id}: {e}"
|
||||
)
|
||||
|
||||
cost_info = await build_workflow_run_cost_info(workflow_run)
|
||||
if cost_info is not None:
|
||||
await db_client.update_workflow_run(run_id, cost_info=cost_info)
|
||||
|
||||
return await _reload_text_chat_session(run_id)
|
||||
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ async def retrieve_from_knowledge_base(
|
|||
embeddings_provider: Optional[str] = None,
|
||||
embeddings_endpoint: Optional[str] = None,
|
||||
embeddings_api_version: Optional[str] = None,
|
||||
correlation_id: Optional[str] = None,
|
||||
tracing_context=None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Retrieve relevant information from the knowledge base using vector similarity search.
|
||||
|
|
@ -75,6 +76,7 @@ async def retrieve_from_knowledge_base(
|
|||
embeddings_provider,
|
||||
embeddings_endpoint,
|
||||
embeddings_api_version,
|
||||
correlation_id,
|
||||
)
|
||||
|
||||
# Create span with parent context
|
||||
|
|
@ -115,6 +117,7 @@ async def retrieve_from_knowledge_base(
|
|||
embeddings_provider,
|
||||
embeddings_endpoint,
|
||||
embeddings_api_version,
|
||||
correlation_id,
|
||||
)
|
||||
|
||||
# Add result metadata to span
|
||||
|
|
@ -192,6 +195,7 @@ async def retrieve_from_knowledge_base(
|
|||
embeddings_provider,
|
||||
embeddings_endpoint,
|
||||
embeddings_api_version,
|
||||
correlation_id,
|
||||
)
|
||||
else:
|
||||
# Tracing is disabled - perform retrieval without tracing
|
||||
|
|
@ -206,6 +210,7 @@ async def retrieve_from_knowledge_base(
|
|||
embeddings_provider,
|
||||
embeddings_endpoint,
|
||||
embeddings_api_version,
|
||||
correlation_id,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -220,6 +225,7 @@ async def _perform_retrieval(
|
|||
embeddings_provider: Optional[str] = None,
|
||||
embeddings_endpoint: Optional[str] = None,
|
||||
embeddings_api_version: Optional[str] = None,
|
||||
correlation_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Internal function to perform the actual retrieval operation.
|
||||
|
||||
|
|
@ -272,11 +278,20 @@ async def _perform_retrieval(
|
|||
api_version=embeddings_api_version or "2024-02-15-preview",
|
||||
)
|
||||
else:
|
||||
default_headers = None
|
||||
if (
|
||||
embeddings_provider == ServiceProviders.DOGRAH.value
|
||||
and correlation_id
|
||||
):
|
||||
default_headers = {
|
||||
"X-Dograh-Correlation-Id": correlation_id,
|
||||
}
|
||||
embedding_service = OpenAIEmbeddingService(
|
||||
db_client=db_client,
|
||||
api_key=embeddings_api_key,
|
||||
model_id=embeddings_model or "text-embedding-3-small",
|
||||
base_url=embeddings_base_url,
|
||||
default_headers=default_headers,
|
||||
)
|
||||
|
||||
results = await embedding_service.search_similar_chunks(
|
||||
|
|
|
|||
111
api/services/workflow_run_billing.py
Normal file
111
api/services/workflow_run_billing.py
Normal file
|
|
@ -0,0 +1,111 @@
|
|||
"""Workflow-run billing hooks.
|
||||
|
||||
Dograh does not rate or deduct credits locally. MPS owns credit accounting.
|
||||
For hosted deployments, Dograh reports completed platform usage to MPS.
|
||||
When a server-minted MPS correlation id exists, MPS uses model-service usage
|
||||
as the canonical duration. Otherwise Dograh reports the completed run duration.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from api.constants import DEPLOYMENT_MODE
|
||||
from api.db import db_client
|
||||
from api.services.managed_model_services import get_mps_correlation_id
|
||||
from api.services.mps_service_key_client import mps_service_key_client
|
||||
|
||||
|
||||
def _workflow_run_organization_id(workflow_run) -> int | None:
|
||||
workflow = getattr(workflow_run, "workflow", None)
|
||||
return getattr(workflow, "organization_id", None)
|
||||
|
||||
|
||||
def _duration_seconds_from_usage_info(workflow_run) -> float | None:
|
||||
usage_info: dict[str, Any] = getattr(workflow_run, "usage_info", None) or {}
|
||||
duration = usage_info.get("call_duration_seconds")
|
||||
try:
|
||||
duration_seconds = float(duration)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
return duration_seconds if duration_seconds > 0 else None
|
||||
|
||||
|
||||
async def _organization_uses_mps_billing_v2(organization_id: int) -> bool:
|
||||
account = await mps_service_key_client.get_billing_account_status(
|
||||
organization_id=organization_id
|
||||
)
|
||||
return bool(account and account.get("billing_mode") == "v2")
|
||||
|
||||
|
||||
async def report_workflow_run_platform_usage(workflow_run) -> None:
|
||||
"""Report hosted platform usage for a completed workflow run to MPS."""
|
||||
if DEPLOYMENT_MODE == "oss":
|
||||
return
|
||||
|
||||
if not getattr(workflow_run, "is_completed", False):
|
||||
return
|
||||
|
||||
organization_id = _workflow_run_organization_id(workflow_run)
|
||||
if organization_id is None:
|
||||
logger.warning(
|
||||
"Skipping platform usage report for workflow run {}: no organization_id",
|
||||
workflow_run.id,
|
||||
)
|
||||
return
|
||||
|
||||
correlation_id = get_mps_correlation_id(
|
||||
getattr(workflow_run, "initial_context", None)
|
||||
)
|
||||
duration_seconds = (
|
||||
None if correlation_id else _duration_seconds_from_usage_info(workflow_run)
|
||||
)
|
||||
if not correlation_id and duration_seconds is None:
|
||||
logger.warning(
|
||||
"Skipping platform usage report for workflow run {}: no billable duration",
|
||||
workflow_run.id,
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
if not await _organization_uses_mps_billing_v2(organization_id):
|
||||
return
|
||||
|
||||
result = await mps_service_key_client.report_platform_usage(
|
||||
organization_id=organization_id,
|
||||
correlation_id=correlation_id,
|
||||
duration_seconds=duration_seconds,
|
||||
workflow_run_id=workflow_run.id,
|
||||
metadata={
|
||||
"source": "workflow_run_completion",
|
||||
"workflow_id": getattr(workflow_run, "workflow_id", None),
|
||||
"duration_source": (
|
||||
"mps_correlation" if correlation_id else "dograh_usage_info"
|
||||
),
|
||||
},
|
||||
)
|
||||
logger.info(
|
||||
"Reported platform usage for workflow run {} to MPS: {}",
|
||||
workflow_run.id,
|
||||
result,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to report platform usage for workflow run {}: {}",
|
||||
workflow_run.id,
|
||||
e,
|
||||
)
|
||||
|
||||
|
||||
async def report_completed_workflow_run_platform_usage(workflow_run_id: int) -> None:
|
||||
"""Load a completed workflow run and report platform usage to MPS."""
|
||||
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
|
||||
if not workflow_run:
|
||||
logger.warning(
|
||||
"Skipping platform usage report: workflow run {} not found",
|
||||
workflow_run_id,
|
||||
)
|
||||
return
|
||||
|
||||
await report_workflow_run_platform_usage(workflow_run)
|
||||
|
|
@ -45,10 +45,8 @@ from api.tasks.campaign_tasks import (
|
|||
)
|
||||
from api.tasks.knowledge_base_processing import process_knowledge_base_document
|
||||
from api.tasks.run_integrations import run_integrations_post_workflow_run
|
||||
from api.tasks.s3_upload import (
|
||||
process_workflow_completion,
|
||||
upload_voicemail_audio_to_s3,
|
||||
)
|
||||
from api.tasks.s3_upload import upload_voicemail_audio_to_s3
|
||||
from api.tasks.workflow_completion import process_workflow_completion
|
||||
|
||||
|
||||
class WorkerSettings:
|
||||
|
|
|
|||
|
|
@ -166,18 +166,22 @@ async def process_knowledge_base_document(
|
|||
user_id=document.created_by,
|
||||
organization_id=document.organization_id,
|
||||
)
|
||||
user_config = resolved_config.effective
|
||||
if user_config.embeddings:
|
||||
embeddings_provider = getattr(user_config.embeddings, "provider", None)
|
||||
embeddings_api_key = user_config.embeddings.api_key
|
||||
embeddings_model = user_config.embeddings.model
|
||||
effective_config = resolved_config.effective
|
||||
if effective_config.embeddings:
|
||||
embeddings_provider = getattr(
|
||||
effective_config.embeddings, "provider", None
|
||||
)
|
||||
embeddings_api_key = effective_config.embeddings.api_key
|
||||
embeddings_model = effective_config.embeddings.model
|
||||
embeddings_base_url = apply_managed_embeddings_base_url(
|
||||
provider=embeddings_provider,
|
||||
base_url=getattr(user_config.embeddings, "base_url", None),
|
||||
base_url=getattr(effective_config.embeddings, "base_url", None),
|
||||
)
|
||||
embeddings_endpoint = getattr(
|
||||
effective_config.embeddings, "endpoint", None
|
||||
)
|
||||
embeddings_endpoint = getattr(user_config.embeddings, "endpoint", None)
|
||||
embeddings_api_version = getattr(
|
||||
user_config.embeddings, "api_version", None
|
||||
effective_config.embeddings, "api_version", None
|
||||
)
|
||||
logger.info(
|
||||
f"Using user embeddings config: provider={embeddings_provider}, "
|
||||
|
|
|
|||
|
|
@ -1,13 +1,9 @@
|
|||
import os
|
||||
from typing import Optional
|
||||
|
||||
from loguru import logger
|
||||
from pipecat.utils.run_context import set_current_run_id
|
||||
|
||||
from api.db import db_client
|
||||
from api.services.pricing.workflow_run_cost import calculate_workflow_run_cost
|
||||
from api.services.storage import get_current_storage_backend, storage_fs
|
||||
from api.tasks.run_integrations import run_integrations_post_workflow_run
|
||||
from api.services.storage import storage_fs
|
||||
|
||||
|
||||
async def upload_voicemail_audio_to_s3(
|
||||
|
|
@ -69,110 +65,3 @@ async def upload_voicemail_audio_to_s3(
|
|||
logger.warning(
|
||||
f"Failed to clean up temp voicemail audio file {temp_file_path}: {e}"
|
||||
)
|
||||
|
||||
|
||||
async def process_workflow_completion(
|
||||
_ctx,
|
||||
workflow_run_id: int,
|
||||
audio_temp_path: Optional[str] = None,
|
||||
transcript_temp_path: Optional[str] = None,
|
||||
):
|
||||
"""Process workflow completion: upload artifacts and run integrations.
|
||||
|
||||
This task combines audio upload, transcript upload, and webhook integrations
|
||||
into a single sequential task to ensure integrations run after uploads complete.
|
||||
|
||||
Args:
|
||||
_ctx: ARQ context (unused)
|
||||
workflow_run_id: The workflow run ID
|
||||
audio_temp_path: Optional path to temp audio file
|
||||
transcript_temp_path: Optional path to temp transcript file
|
||||
"""
|
||||
run_id = str(workflow_run_id)
|
||||
set_current_run_id(run_id)
|
||||
|
||||
logger.info(f"Processing workflow completion for run {workflow_run_id}")
|
||||
|
||||
storage_backend = get_current_storage_backend()
|
||||
|
||||
# Step 1: Upload audio if provided
|
||||
if audio_temp_path:
|
||||
try:
|
||||
if os.path.exists(audio_temp_path):
|
||||
file_size = os.path.getsize(audio_temp_path)
|
||||
logger.debug(f"Audio file size: {file_size} bytes")
|
||||
|
||||
recording_url = f"recordings/{workflow_run_id}.wav"
|
||||
logger.info(
|
||||
f"Uploading audio to {storage_backend.name} - workflow_run_id: {workflow_run_id}"
|
||||
)
|
||||
|
||||
await storage_fs.aupload_file(audio_temp_path, recording_url)
|
||||
await db_client.update_workflow_run(
|
||||
run_id=workflow_run_id,
|
||||
recording_url=recording_url,
|
||||
storage_backend=storage_backend.value,
|
||||
)
|
||||
logger.info(f"Successfully uploaded audio: {recording_url}")
|
||||
else:
|
||||
logger.warning(f"Audio temp file not found: {audio_temp_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error uploading audio for workflow {workflow_run_id}: {e}")
|
||||
finally:
|
||||
if audio_temp_path and os.path.exists(audio_temp_path):
|
||||
try:
|
||||
os.remove(audio_temp_path)
|
||||
logger.debug(f"Cleaned up temp audio file: {audio_temp_path}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to clean up temp audio file: {e}")
|
||||
|
||||
# Step 2: Upload transcript if provided
|
||||
if transcript_temp_path:
|
||||
try:
|
||||
if os.path.exists(transcript_temp_path):
|
||||
file_size = os.path.getsize(transcript_temp_path)
|
||||
logger.debug(f"Transcript file size: {file_size} bytes")
|
||||
|
||||
transcript_url = f"transcripts/{workflow_run_id}.txt"
|
||||
logger.info(
|
||||
f"Uploading transcript to {storage_backend.name} - workflow_run_id: {workflow_run_id}"
|
||||
)
|
||||
|
||||
await storage_fs.aupload_file(transcript_temp_path, transcript_url)
|
||||
await db_client.update_workflow_run(
|
||||
run_id=workflow_run_id,
|
||||
transcript_url=transcript_url,
|
||||
storage_backend=storage_backend.value,
|
||||
)
|
||||
logger.info(f"Successfully uploaded transcript: {transcript_url}")
|
||||
else:
|
||||
logger.warning(
|
||||
f"Transcript temp file not found: {transcript_temp_path}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error uploading transcript for workflow {workflow_run_id}: {e}"
|
||||
)
|
||||
finally:
|
||||
if transcript_temp_path and os.path.exists(transcript_temp_path):
|
||||
try:
|
||||
os.remove(transcript_temp_path)
|
||||
logger.debug(
|
||||
f"Cleaned up temp transcript file: {transcript_temp_path}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to clean up temp transcript file: {e}")
|
||||
|
||||
# Step 3: Run integrations including QA analysis (after uploads are complete)
|
||||
try:
|
||||
await run_integrations_post_workflow_run(_ctx, workflow_run_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Error running integrations for workflow {workflow_run_id}: {e}")
|
||||
|
||||
# Step 4: Calculate cost after integrations (so QA token usage is included)
|
||||
try:
|
||||
await calculate_workflow_run_cost(workflow_run_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating cost for workflow {workflow_run_id}: {e}")
|
||||
|
||||
logger.info(f"Completed workflow completion processing for run {workflow_run_id}")
|
||||
|
|
|
|||
121
api/tasks/workflow_completion.py
Normal file
121
api/tasks/workflow_completion.py
Normal file
|
|
@ -0,0 +1,121 @@
|
|||
import os
|
||||
from typing import Optional
|
||||
|
||||
from loguru import logger
|
||||
from pipecat.utils.run_context import set_current_run_id
|
||||
|
||||
from api.db import db_client
|
||||
from api.services.storage import get_current_storage_backend, storage_fs
|
||||
from api.services.workflow_run_billing import (
|
||||
report_completed_workflow_run_platform_usage,
|
||||
)
|
||||
from api.tasks.run_integrations import run_integrations_post_workflow_run
|
||||
|
||||
|
||||
async def process_workflow_completion(
|
||||
_ctx,
|
||||
workflow_run_id: int,
|
||||
audio_temp_path: Optional[str] = None,
|
||||
transcript_temp_path: Optional[str] = None,
|
||||
):
|
||||
"""Process workflow completion: upload artifacts and run integrations.
|
||||
|
||||
This task combines audio upload, transcript upload, and webhook integrations
|
||||
into a single sequential task to ensure integrations run after uploads complete.
|
||||
|
||||
Args:
|
||||
_ctx: ARQ context (unused)
|
||||
workflow_run_id: The workflow run ID
|
||||
audio_temp_path: Optional path to temp audio file
|
||||
transcript_temp_path: Optional path to temp transcript file
|
||||
"""
|
||||
run_id = str(workflow_run_id)
|
||||
set_current_run_id(run_id)
|
||||
|
||||
logger.info(f"Processing workflow completion for run {workflow_run_id}")
|
||||
|
||||
storage_backend = get_current_storage_backend()
|
||||
|
||||
# Step 1: Upload audio if provided
|
||||
if audio_temp_path:
|
||||
try:
|
||||
if os.path.exists(audio_temp_path):
|
||||
file_size = os.path.getsize(audio_temp_path)
|
||||
logger.debug(f"Audio file size: {file_size} bytes")
|
||||
|
||||
recording_url = f"recordings/{workflow_run_id}.wav"
|
||||
logger.info(
|
||||
f"Uploading audio to {storage_backend.name} - workflow_run_id: {workflow_run_id}"
|
||||
)
|
||||
|
||||
await storage_fs.aupload_file(audio_temp_path, recording_url)
|
||||
await db_client.update_workflow_run(
|
||||
run_id=workflow_run_id,
|
||||
recording_url=recording_url,
|
||||
storage_backend=storage_backend.value,
|
||||
)
|
||||
logger.info(f"Successfully uploaded audio: {recording_url}")
|
||||
else:
|
||||
logger.warning(f"Audio temp file not found: {audio_temp_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error uploading audio for workflow {workflow_run_id}: {e}")
|
||||
finally:
|
||||
if audio_temp_path and os.path.exists(audio_temp_path):
|
||||
try:
|
||||
os.remove(audio_temp_path)
|
||||
logger.debug(f"Cleaned up temp audio file: {audio_temp_path}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to clean up temp audio file: {e}")
|
||||
|
||||
# Step 2: Upload transcript if provided
|
||||
if transcript_temp_path:
|
||||
try:
|
||||
if os.path.exists(transcript_temp_path):
|
||||
file_size = os.path.getsize(transcript_temp_path)
|
||||
logger.debug(f"Transcript file size: {file_size} bytes")
|
||||
|
||||
transcript_url = f"transcripts/{workflow_run_id}.txt"
|
||||
logger.info(
|
||||
f"Uploading transcript to {storage_backend.name} - workflow_run_id: {workflow_run_id}"
|
||||
)
|
||||
|
||||
await storage_fs.aupload_file(transcript_temp_path, transcript_url)
|
||||
await db_client.update_workflow_run(
|
||||
run_id=workflow_run_id,
|
||||
transcript_url=transcript_url,
|
||||
storage_backend=storage_backend.value,
|
||||
)
|
||||
logger.info(f"Successfully uploaded transcript: {transcript_url}")
|
||||
else:
|
||||
logger.warning(
|
||||
f"Transcript temp file not found: {transcript_temp_path}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error uploading transcript for workflow {workflow_run_id}: {e}"
|
||||
)
|
||||
finally:
|
||||
if transcript_temp_path and os.path.exists(transcript_temp_path):
|
||||
try:
|
||||
os.remove(transcript_temp_path)
|
||||
logger.debug(
|
||||
f"Cleaned up temp transcript file: {transcript_temp_path}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to clean up temp transcript file: {e}")
|
||||
|
||||
# Step 3: Run integrations including QA analysis (after uploads are complete)
|
||||
try:
|
||||
await run_integrations_post_workflow_run(_ctx, workflow_run_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Error running integrations for workflow {workflow_run_id}: {e}")
|
||||
|
||||
# Step 4: Notify MPS after completion. MPS owns credit accounting.
|
||||
try:
|
||||
await report_completed_workflow_run_platform_usage(workflow_run_id)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error reporting platform usage for workflow {workflow_run_id}: {e}"
|
||||
)
|
||||
|
||||
logger.info(f"Completed workflow completion processing for run {workflow_run_id}")
|
||||
|
|
@ -203,7 +203,7 @@ async def create_workflow_run_rows(
|
|||
Returns:
|
||||
Tuple of (workflow_run, user, workflow).
|
||||
"""
|
||||
from api.schemas.user_configuration import EffectiveAIModelConfiguration
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
|
||||
org = OrganizationModel(provider_id=f"test-org-{provider_id_suffix}")
|
||||
async_session.add(org)
|
||||
|
|
|
|||
|
|
@ -1,12 +1,16 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from api.schemas.ai_model_configuration import (
|
||||
DograhManagedAIModelConfiguration,
|
||||
EffectiveAIModelConfiguration,
|
||||
OrganizationAIModelConfigurationResponse,
|
||||
OrganizationAIModelConfigurationV2,
|
||||
compile_ai_model_configuration_v2,
|
||||
)
|
||||
from api.schemas.user_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY,
|
||||
check_for_masked_keys_in_ai_model_configuration_v2,
|
||||
|
|
@ -15,6 +19,7 @@ from api.services.configuration.ai_model_configuration import (
|
|||
merge_ai_model_configuration_v2_secrets,
|
||||
migrate_workflow_configuration_model_override_to_v2,
|
||||
)
|
||||
from api.services.configuration.check_validity import UserConfigurationValidator
|
||||
from api.services.configuration.masking import mask_key
|
||||
from api.services.configuration.registry import (
|
||||
DeepgramSTTConfiguration,
|
||||
|
|
@ -22,6 +27,8 @@ from api.services.configuration.registry import (
|
|||
DograhSTTService,
|
||||
DograhTTSService,
|
||||
ElevenlabsTTSConfiguration,
|
||||
GoogleLLMService,
|
||||
GoogleRealtimeLLMConfiguration,
|
||||
OpenAIEmbeddingsConfiguration,
|
||||
OpenAILLMService,
|
||||
)
|
||||
|
|
@ -49,6 +56,7 @@ def test_dograh_v2_compiles_to_effective_managed_pipeline_with_embeddings():
|
|||
assert effective.stt.language == "multi"
|
||||
assert effective.embeddings.provider == "dograh"
|
||||
assert effective.embeddings.model == "default"
|
||||
assert effective.managed_service_version == 2
|
||||
|
||||
|
||||
def test_dograh_v2_rejects_non_predefined_speed():
|
||||
|
|
@ -92,6 +100,67 @@ def test_byok_v2_rejects_dograh_provider():
|
|||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_byok_realtime_validator_does_not_require_stt_or_tts():
|
||||
config = OrganizationAIModelConfigurationV2.model_validate(
|
||||
{
|
||||
"mode": "byok",
|
||||
"byok": {
|
||||
"mode": "realtime",
|
||||
"realtime": {
|
||||
"realtime": {
|
||||
"provider": "google_realtime",
|
||||
"api_key": "google-realtime-key",
|
||||
"model": "gemini-3.1-flash-live-preview",
|
||||
"voice": "Puck",
|
||||
"language": "en",
|
||||
},
|
||||
"llm": {
|
||||
"provider": "google",
|
||||
"api_key": "google-llm-key",
|
||||
"model": "gemini-2.0-flash",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
effective = compile_ai_model_configuration_v2(config)
|
||||
|
||||
assert effective.is_realtime is True
|
||||
assert effective.stt is None
|
||||
assert effective.tts is None
|
||||
assert await UserConfigurationValidator().validate(effective) == {
|
||||
"status": [{"model": "all", "message": "ok"}]
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_validator_requires_stt_and_tts_when_not_realtime():
|
||||
effective = EffectiveAIModelConfiguration(
|
||||
llm=GoogleLLMService(
|
||||
provider="google",
|
||||
api_key="google-llm-key",
|
||||
model="gemini-2.0-flash",
|
||||
),
|
||||
realtime=GoogleRealtimeLLMConfiguration(
|
||||
provider="google_realtime",
|
||||
api_key="google-realtime-key",
|
||||
model="gemini-3.1-flash-live-preview",
|
||||
voice="Puck",
|
||||
language="en",
|
||||
),
|
||||
is_realtime=False,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await UserConfigurationValidator().validate(effective)
|
||||
|
||||
assert exc_info.value.args[0] == [
|
||||
{"model": "stt", "message": "API key is missing"},
|
||||
{"model": "tts", "message": "API key is missing"},
|
||||
]
|
||||
|
||||
|
||||
def test_masked_dograh_key_is_preserved_when_saving_same_mode():
|
||||
existing = OrganizationAIModelConfigurationV2(
|
||||
mode="dograh",
|
||||
|
|
@ -293,3 +362,98 @@ def test_workflow_model_override_migration_removes_invalid_v1_override_marker():
|
|||
assert changed is True
|
||||
assert "model_overrides" not in migrated
|
||||
assert migrated["ambient_noise_configuration"] == {"enabled": False}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_migrate_model_configuration_v2_initializes_hosted_mps_billing(
|
||||
monkeypatch,
|
||||
):
|
||||
from api.routes import organization as organization_routes
|
||||
|
||||
legacy = EffectiveAIModelConfiguration(
|
||||
llm=DograhLLMService(
|
||||
provider="dograh",
|
||||
api_key=["mps-secret"],
|
||||
model="default",
|
||||
),
|
||||
tts=DograhTTSService(
|
||||
provider="dograh",
|
||||
api_key=["mps-secret"],
|
||||
model="default",
|
||||
voice="default",
|
||||
),
|
||||
stt=DograhSTTService(
|
||||
provider="dograh",
|
||||
api_key=["mps-secret"],
|
||||
model="default",
|
||||
),
|
||||
)
|
||||
expected_response = OrganizationAIModelConfigurationResponse(
|
||||
configuration={"version": 2, "mode": "dograh"},
|
||||
effective_configuration={},
|
||||
source="organization_v2",
|
||||
)
|
||||
|
||||
class FakeValidator:
|
||||
async def validate(self, *args, **kwargs):
|
||||
return {"status": [{"model": "all", "message": "ok"}]}
|
||||
|
||||
ensure_billing = AsyncMock(return_value={"billing_mode": "v2"})
|
||||
upsert = AsyncMock()
|
||||
migrate_workflows = AsyncMock()
|
||||
|
||||
monkeypatch.setattr(organization_routes, "DEPLOYMENT_MODE", "saas")
|
||||
monkeypatch.setattr(
|
||||
organization_routes,
|
||||
"get_organization_ai_model_configuration_v2",
|
||||
AsyncMock(return_value=None),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
organization_routes.db_client,
|
||||
"get_user_configurations",
|
||||
AsyncMock(return_value=legacy),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
organization_routes,
|
||||
"UserConfigurationValidator",
|
||||
lambda: FakeValidator(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
organization_routes,
|
||||
"ensure_hosted_mps_billing_account_v2",
|
||||
ensure_billing,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
organization_routes,
|
||||
"upsert_organization_ai_model_configuration_v2",
|
||||
upsert,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
organization_routes,
|
||||
"migrate_workflow_model_configurations_to_v2",
|
||||
migrate_workflows,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
organization_routes,
|
||||
"_model_configuration_v2_response",
|
||||
AsyncMock(return_value=expected_response),
|
||||
)
|
||||
|
||||
user = SimpleNamespace(
|
||||
id=7,
|
||||
provider_id="provider-123",
|
||||
selected_organization_id=42,
|
||||
)
|
||||
|
||||
response = await organization_routes.migrate_model_configuration_v2(
|
||||
force=False,
|
||||
user=user,
|
||||
)
|
||||
|
||||
ensure_billing.assert_awaited_once_with(42, created_by="provider-123")
|
||||
upsert.assert_awaited_once()
|
||||
migrate_workflows.assert_awaited_once_with(
|
||||
organization_id=42,
|
||||
fallback_user_config=legacy,
|
||||
)
|
||||
assert response == expected_response
|
||||
|
|
|
|||
68
api/tests/test_auth_depends.py
Normal file
68
api/tests/test_auth_depends.py
Normal file
|
|
@ -0,0 +1,68 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from api.services.auth import depends as auth_depends
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_initializes_hosted_mps_billing_for_new_org(monkeypatch):
|
||||
stack_user = {
|
||||
"id": "stack-user-1",
|
||||
"selected_team_id": "team-1",
|
||||
"primary_email_verified": False,
|
||||
}
|
||||
user = SimpleNamespace(
|
||||
id=7,
|
||||
email=None,
|
||||
provider_id="stack-user-1",
|
||||
selected_organization_id=None,
|
||||
)
|
||||
organization = SimpleNamespace(id=42)
|
||||
existing_config = SimpleNamespace(llm=object(), tts=None, stt=None)
|
||||
|
||||
ensure_billing = AsyncMock(return_value={"billing_mode": "v2"})
|
||||
|
||||
monkeypatch.setattr(auth_depends, "AUTH_PROVIDER", "stack")
|
||||
monkeypatch.setattr(
|
||||
auth_depends.stackauth,
|
||||
"get_user",
|
||||
AsyncMock(return_value=stack_user),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
auth_depends.db_client,
|
||||
"get_or_create_user_by_provider_id",
|
||||
AsyncMock(return_value=(user, False)),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
auth_depends.db_client,
|
||||
"get_or_create_organization_by_provider_id",
|
||||
AsyncMock(return_value=(organization, True)),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
auth_depends.db_client,
|
||||
"add_user_to_organization",
|
||||
AsyncMock(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
auth_depends.db_client,
|
||||
"update_user_selected_organization",
|
||||
AsyncMock(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
auth_depends.db_client,
|
||||
"get_user_configurations",
|
||||
AsyncMock(return_value=existing_config),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
auth_depends,
|
||||
"ensure_hosted_mps_billing_account_v2",
|
||||
ensure_billing,
|
||||
)
|
||||
|
||||
result = await auth_depends.get_user(authorization="Bearer token")
|
||||
|
||||
assert result is user
|
||||
assert result.selected_organization_id == 42
|
||||
ensure_billing.assert_awaited_once_with(42, created_by="stack-user-1")
|
||||
|
|
@ -1,31 +0,0 @@
|
|||
from api.services.pricing.cost_calculator import cost_calculator
|
||||
|
||||
|
||||
def test_cost_calculator():
|
||||
"""Test function to verify cost calculation works"""
|
||||
sample_usage = {
|
||||
"llm": {
|
||||
"OpenAILLMService#0|||gpt-4.1-mini": {
|
||||
"prompt_tokens": 45380,
|
||||
"completion_tokens": 496,
|
||||
"total_tokens": 45876,
|
||||
"cache_read_input_tokens": 0,
|
||||
"cache_creation_input_tokens": 0,
|
||||
}
|
||||
},
|
||||
"tts": {"ElevenLabsTTSService#0|||eleven_flash_v2_5": 2399},
|
||||
"stt": {"DeepgramSTTService#0|||nova-3-general": 177.21536946296692},
|
||||
"call_duration_seconds": 179,
|
||||
}
|
||||
|
||||
result = cost_calculator.calculate_total_cost(sample_usage)
|
||||
assert result["llm_cost"] == 45380 * 0.40 / 1_000_000 + 496 * 1.60 / 1_000_000
|
||||
assert result["tts_cost"] == 2399 * 0.0256 / 1_000
|
||||
assert result["stt_cost"] == 177.21536946296692 / 60 * 0.0077
|
||||
assert (
|
||||
abs(
|
||||
result["total"]
|
||||
- (result["llm_cost"] + result["tts_cost"] + result["stt_cost"])
|
||||
)
|
||||
< 1e-10
|
||||
)
|
||||
110
api/tests/test_dograh_managed_correlation.py
Normal file
110
api/tests/test_dograh_managed_correlation.py
Normal file
|
|
@ -0,0 +1,110 @@
|
|||
import json
|
||||
|
||||
import pytest
|
||||
from openai._types import NOT_GIVEN as OPENAI_NOT_GIVEN
|
||||
from pipecat.frames.frames import TTSStartedFrame
|
||||
from pipecat.services.dograh.llm import DograhLLMService
|
||||
from pipecat.services.dograh.stt import DograhSTTService
|
||||
from pipecat.services.dograh.tts import DograhTTSService
|
||||
from pipecat.services.openai.base_llm import OpenAILLMSettings
|
||||
from websockets.protocol import State
|
||||
|
||||
|
||||
class _FakeWebSocket:
|
||||
def __init__(self):
|
||||
self.state = State.OPEN
|
||||
self.messages: list[dict] = []
|
||||
|
||||
async def send(self, message: str) -> None:
|
||||
self.messages.append(json.loads(message))
|
||||
|
||||
async def close(self, *args, **kwargs) -> None:
|
||||
self.state = State.CLOSED
|
||||
|
||||
|
||||
def test_dograh_llm_uses_explicit_mps_correlation_id():
|
||||
service = DograhLLMService(
|
||||
api_key="mps-secret",
|
||||
correlation_id="mps-corr-123",
|
||||
settings=OpenAILLMSettings(model="default"),
|
||||
)
|
||||
service._start_metadata = {"workflow_run_id": 99}
|
||||
|
||||
params = service.build_chat_completion_params(
|
||||
{
|
||||
"messages": [],
|
||||
"tools": OPENAI_NOT_GIVEN,
|
||||
"tool_choice": OPENAI_NOT_GIVEN,
|
||||
}
|
||||
)
|
||||
|
||||
assert params["metadata"]["correlation_id"] == "mps-corr-123"
|
||||
assert params["metadata"]["mps_billing_version"] == "2"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dograh_stt_config_uses_explicit_mps_correlation_id(monkeypatch):
|
||||
fake_ws = _FakeWebSocket()
|
||||
|
||||
async def fake_connect(url, additional_headers):
|
||||
return fake_ws
|
||||
|
||||
monkeypatch.setattr(
|
||||
"pipecat.services.dograh.stt.websocket_connect",
|
||||
fake_connect,
|
||||
)
|
||||
|
||||
service = DograhSTTService(
|
||||
api_key="mps-secret",
|
||||
correlation_id="mps-corr-123",
|
||||
sample_rate=16000,
|
||||
)
|
||||
service._start_metadata = {"workflow_run_id": 99}
|
||||
|
||||
await service._connect_websocket()
|
||||
|
||||
assert fake_ws.messages[0]["type"] == "config"
|
||||
assert fake_ws.messages[0]["correlation_id"] == "mps-corr-123"
|
||||
assert fake_ws.messages[0]["mps_billing_version"] == "2"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dograh_tts_messages_use_explicit_mps_correlation_id(monkeypatch):
|
||||
fake_ws = _FakeWebSocket()
|
||||
|
||||
async def fake_connect(url, additional_headers):
|
||||
return fake_ws
|
||||
|
||||
monkeypatch.setattr(
|
||||
"pipecat.services.dograh.tts.websocket_connect",
|
||||
fake_connect,
|
||||
)
|
||||
|
||||
service = DograhTTSService(
|
||||
api_key="mps-secret",
|
||||
correlation_id="mps-corr-123",
|
||||
sample_rate=24000,
|
||||
)
|
||||
service._start_metadata = {"workflow_run_id": 99}
|
||||
|
||||
await service._connect_websocket()
|
||||
assert fake_ws.messages[0]["type"] == "config"
|
||||
assert fake_ws.messages[0]["correlation_id"] == "mps-corr-123"
|
||||
assert fake_ws.messages[0]["mps_billing_version"] == "2"
|
||||
|
||||
async def _noop(*args, **kwargs):
|
||||
return None
|
||||
|
||||
service.audio_context_available = lambda context_id: False
|
||||
service.create_audio_context = _noop
|
||||
service.start_ttfb_metrics = _noop
|
||||
service.start_tts_usage_metrics = _noop
|
||||
|
||||
frames = []
|
||||
async for frame in service.run_tts("hello", "ctx-1"):
|
||||
frames.append(frame)
|
||||
|
||||
assert isinstance(frames[0], TTSStartedFrame)
|
||||
assert fake_ws.messages[1]["type"] == "create_context"
|
||||
assert fake_ws.messages[1]["correlation_id"] == "mps-corr-123"
|
||||
assert fake_ws.messages[1]["mps_billing_version"] == "2"
|
||||
|
|
@ -7,7 +7,7 @@ from pipecat.processors.aggregators.llm_context import LLMContext
|
|||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.xai.realtime import events
|
||||
|
||||
from api.schemas.user_configuration import EffectiveAIModelConfiguration
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.configuration.registry import GrokRealtimeLLMConfiguration
|
||||
from api.services.pipecat.realtime.grok_realtime import (
|
||||
DograhGrokRealtimeLLMService,
|
||||
|
|
@ -120,7 +120,7 @@ async def test_completed_input_transcription_is_broadcast_as_finalized():
|
|||
|
||||
|
||||
def test_factory_creates_dograh_grok_realtime_service():
|
||||
user_config = EffectiveAIModelConfiguration(
|
||||
effective_config = EffectiveAIModelConfiguration(
|
||||
is_realtime=True,
|
||||
realtime=GrokRealtimeLLMConfiguration(
|
||||
provider="grok_realtime",
|
||||
|
|
@ -131,7 +131,7 @@ def test_factory_creates_dograh_grok_realtime_service():
|
|||
)
|
||||
|
||||
service = create_realtime_llm_service(
|
||||
user_config,
|
||||
effective_config,
|
||||
audio_config=SimpleNamespace(),
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from fastapi import FastAPI
|
|||
from fastapi.testclient import TestClient
|
||||
|
||||
from api.routes.user import router
|
||||
from api.schemas.user_configuration import EffectiveAIModelConfiguration
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.configuration.masking import mask_key
|
||||
from api.services.configuration.registry import (
|
||||
|
|
|
|||
|
|
@ -87,3 +87,317 @@ async def test_check_service_key_usage_uses_bearer_self_usage(monkeypatch):
|
|||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_correlation_id_uses_bearer_auth(monkeypatch):
|
||||
calls = []
|
||||
|
||||
class FakeAsyncClient:
|
||||
def __init__(self, timeout):
|
||||
self.timeout = timeout
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return None
|
||||
|
||||
async def post(self, url, json, headers):
|
||||
calls.append(("POST", url, json, headers))
|
||||
return _Response(200, {"correlation_id": "mps-corr-123"})
|
||||
|
||||
monkeypatch.setattr(
|
||||
"api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient
|
||||
)
|
||||
|
||||
client = MPSServiceKeyClient()
|
||||
|
||||
assert await client.create_correlation_id(
|
||||
service_key="mps_sk_paid",
|
||||
workflow_run_id=42,
|
||||
) == {"correlation_id": "mps-corr-123"}
|
||||
assert calls == [
|
||||
(
|
||||
"POST",
|
||||
f"{client.base_url}/api/v1/service-keys/correlation-id/self",
|
||||
{"workflow_run_id": 42},
|
||||
{
|
||||
"Authorization": "Bearer mps_sk_paid",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_billing_account_status_uses_hosted_org_auth(monkeypatch):
|
||||
calls = []
|
||||
|
||||
class FakeAsyncClient:
|
||||
def __init__(self, timeout):
|
||||
self.timeout = timeout
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return None
|
||||
|
||||
async def get(self, url, headers):
|
||||
calls.append(("GET", url, headers))
|
||||
return _Response(200, {"organization_id": 42, "billing_mode": "v2"})
|
||||
|
||||
monkeypatch.setattr(
|
||||
"api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient
|
||||
)
|
||||
monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas")
|
||||
monkeypatch.setattr(
|
||||
"api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret"
|
||||
)
|
||||
|
||||
client = MPSServiceKeyClient()
|
||||
|
||||
assert await client.get_billing_account_status(organization_id=42) == {
|
||||
"organization_id": 42,
|
||||
"billing_mode": "v2",
|
||||
}
|
||||
assert calls == [
|
||||
(
|
||||
"GET",
|
||||
f"{client.base_url}/api/v1/billing/accounts/42/status",
|
||||
{
|
||||
"Content-Type": "application/json",
|
||||
"X-Secret-Key": "mps-secret",
|
||||
"X-Organization-Id": "42",
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ensure_billing_account_v2_uses_balance_endpoint(monkeypatch):
|
||||
calls = []
|
||||
|
||||
class FakeAsyncClient:
|
||||
def __init__(self, timeout):
|
||||
self.timeout = timeout
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return None
|
||||
|
||||
async def get(self, url, headers):
|
||||
calls.append(("GET", url, headers))
|
||||
return _Response(
|
||||
200,
|
||||
{
|
||||
"id": 7,
|
||||
"organization_id": 42,
|
||||
"billing_mode": "v2",
|
||||
"cached_balance_credits": "0.0000",
|
||||
"currency": "USD",
|
||||
},
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient
|
||||
)
|
||||
monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas")
|
||||
monkeypatch.setattr(
|
||||
"api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret"
|
||||
)
|
||||
|
||||
client = MPSServiceKeyClient()
|
||||
|
||||
assert await client.ensure_billing_account_v2(
|
||||
organization_id=42,
|
||||
created_by="provider-123",
|
||||
) == {
|
||||
"id": 7,
|
||||
"organization_id": 42,
|
||||
"billing_mode": "v2",
|
||||
"cached_balance_credits": "0.0000",
|
||||
"currency": "USD",
|
||||
}
|
||||
assert calls == [
|
||||
(
|
||||
"GET",
|
||||
f"{client.base_url}/api/v1/billing/accounts/42/balance",
|
||||
{
|
||||
"Content-Type": "application/json",
|
||||
"X-Secret-Key": "mps-secret",
|
||||
"X-Organization-Id": "42",
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_credit_ledger_sends_page_and_limit(monkeypatch):
|
||||
calls = []
|
||||
|
||||
class FakeAsyncClient:
|
||||
def __init__(self, timeout):
|
||||
self.timeout = timeout
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return None
|
||||
|
||||
async def get(self, url, params, headers):
|
||||
calls.append(("GET", url, params, headers))
|
||||
return _Response(
|
||||
200,
|
||||
{
|
||||
"account": {"organization_id": 42},
|
||||
"ledger_entries": [],
|
||||
"total_count": 0,
|
||||
"page": 3,
|
||||
"limit": 25,
|
||||
"total_pages": 0,
|
||||
},
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient
|
||||
)
|
||||
monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas")
|
||||
monkeypatch.setattr(
|
||||
"api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret"
|
||||
)
|
||||
|
||||
client = MPSServiceKeyClient()
|
||||
|
||||
assert await client.get_credit_ledger(
|
||||
organization_id=42,
|
||||
page=3,
|
||||
limit=25,
|
||||
) == {
|
||||
"account": {"organization_id": 42},
|
||||
"ledger_entries": [],
|
||||
"total_count": 0,
|
||||
"page": 3,
|
||||
"limit": 25,
|
||||
"total_pages": 0,
|
||||
}
|
||||
assert calls == [
|
||||
(
|
||||
"GET",
|
||||
f"{client.base_url}/api/v1/billing/accounts/42/ledger",
|
||||
{"page": 3, "limit": 25},
|
||||
{
|
||||
"Content-Type": "application/json",
|
||||
"X-Secret-Key": "mps-secret",
|
||||
"X-Organization-Id": "42",
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_report_platform_usage_uses_hosted_secret_auth(monkeypatch):
|
||||
calls = []
|
||||
|
||||
class FakeAsyncClient:
|
||||
def __init__(self, timeout):
|
||||
self.timeout = timeout
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return None
|
||||
|
||||
async def post(self, url, json, headers):
|
||||
calls.append(("POST", url, json, headers))
|
||||
return _Response(200, {"metered": True})
|
||||
|
||||
monkeypatch.setattr(
|
||||
"api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient
|
||||
)
|
||||
monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas")
|
||||
monkeypatch.setattr(
|
||||
"api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret"
|
||||
)
|
||||
|
||||
client = MPSServiceKeyClient()
|
||||
|
||||
assert await client.report_platform_usage(
|
||||
organization_id=42,
|
||||
correlation_id="mps-corr-123",
|
||||
workflow_run_id=123,
|
||||
metadata={"source": "workflow_run_completion"},
|
||||
) == {"metered": True}
|
||||
assert calls == [
|
||||
(
|
||||
"POST",
|
||||
f"{client.base_url}/api/v1/billing/accounts/42/platform-usage",
|
||||
{
|
||||
"correlation_id": "mps-corr-123",
|
||||
"workflow_run_id": 123,
|
||||
"metadata": {"source": "workflow_run_completion"},
|
||||
},
|
||||
{
|
||||
"Content-Type": "application/json",
|
||||
"X-Secret-Key": "mps-secret",
|
||||
"X-Organization-Id": "42",
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_report_platform_usage_sends_duration_without_correlation(monkeypatch):
|
||||
calls = []
|
||||
|
||||
class FakeAsyncClient:
|
||||
def __init__(self, timeout):
|
||||
self.timeout = timeout
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return None
|
||||
|
||||
async def post(self, url, json, headers):
|
||||
calls.append(("POST", url, json, headers))
|
||||
return _Response(200, {"metered": True})
|
||||
|
||||
monkeypatch.setattr(
|
||||
"api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient
|
||||
)
|
||||
monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas")
|
||||
monkeypatch.setattr(
|
||||
"api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret"
|
||||
)
|
||||
|
||||
client = MPSServiceKeyClient()
|
||||
|
||||
assert await client.report_platform_usage(
|
||||
organization_id=42,
|
||||
duration_seconds=87.0,
|
||||
workflow_run_id=123,
|
||||
metadata={"source": "workflow_run_completion"},
|
||||
) == {"metered": True}
|
||||
assert calls == [
|
||||
(
|
||||
"POST",
|
||||
f"{client.base_url}/api/v1/billing/accounts/42/platform-usage",
|
||||
{
|
||||
"duration_seconds": 87.0,
|
||||
"workflow_run_id": 123,
|
||||
"metadata": {"source": "workflow_run_completion"},
|
||||
},
|
||||
{
|
||||
"Content-Type": "application/json",
|
||||
"X-Secret-Key": "mps-secret",
|
||||
"X-Organization-Id": "42",
|
||||
},
|
||||
)
|
||||
]
|
||||
|
|
|
|||
99
api/tests/test_organization_usage_billing.py
Normal file
99
api/tests/test_organization_usage_billing.py
Normal file
|
|
@ -0,0 +1,99 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from api.routes import organization_usage
|
||||
|
||||
|
||||
def test_is_mps_billing_v2_depends_only_on_account_mode():
|
||||
assert organization_usage._is_mps_billing_v2({"billing_mode": "v2"}) is True
|
||||
assert organization_usage._is_mps_billing_v2({"billing_mode": "v1"}) is False
|
||||
assert organization_usage._is_mps_billing_v2({"billing_mode": "shadow"}) is False
|
||||
assert organization_usage._is_mps_billing_v2(None) is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_mps_billing_account_status_uses_user_provider_id(monkeypatch):
|
||||
get_status = AsyncMock(return_value={"billing_mode": "v2"})
|
||||
monkeypatch.setattr(
|
||||
organization_usage.mps_service_key_client,
|
||||
"get_billing_account_status",
|
||||
get_status,
|
||||
)
|
||||
|
||||
user = SimpleNamespace(provider_id="provider-123")
|
||||
|
||||
assert await organization_usage._get_mps_billing_account_status(user, 42) == {
|
||||
"billing_mode": "v2"
|
||||
}
|
||||
get_status.assert_awaited_once_with(
|
||||
organization_id=42,
|
||||
created_by="provider-123",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_billing_credits_pages_v2_ledger(monkeypatch):
|
||||
monkeypatch.setattr(organization_usage, "DEPLOYMENT_MODE", "saas")
|
||||
monkeypatch.setattr(
|
||||
organization_usage,
|
||||
"_get_mps_billing_account_status",
|
||||
AsyncMock(return_value={"billing_mode": "v2"}),
|
||||
)
|
||||
get_ledger = AsyncMock(
|
||||
return_value={
|
||||
"account": {
|
||||
"id": 7,
|
||||
"organization_id": 42,
|
||||
"billing_mode": "v2",
|
||||
"cached_balance_credits": 250,
|
||||
"currency": "USD",
|
||||
},
|
||||
"ledger_entries": [
|
||||
{
|
||||
"id": 99,
|
||||
"entry_type": "grant",
|
||||
"origin": "account_creation",
|
||||
"credits_delta": 250,
|
||||
"balance_after": 250,
|
||||
"created_at": "2026-06-12T00:00:00Z",
|
||||
}
|
||||
],
|
||||
"total_debits_credits": 75,
|
||||
"total_count": 101,
|
||||
"page": 3,
|
||||
"limit": 25,
|
||||
"total_pages": 5,
|
||||
}
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
organization_usage.mps_service_key_client,
|
||||
"get_credit_ledger",
|
||||
get_ledger,
|
||||
)
|
||||
|
||||
user = SimpleNamespace(
|
||||
provider_id="provider-123",
|
||||
selected_organization_id=42,
|
||||
)
|
||||
|
||||
response = await organization_usage.get_billing_credits(
|
||||
page=3,
|
||||
limit=25,
|
||||
user=user,
|
||||
)
|
||||
|
||||
get_ledger.assert_awaited_once_with(
|
||||
organization_id=42,
|
||||
page=3,
|
||||
limit=25,
|
||||
created_by="provider-123",
|
||||
)
|
||||
assert response.billing_version == "v2"
|
||||
assert response.total_credits_used == 75
|
||||
assert response.total_count == 101
|
||||
assert response.page == 3
|
||||
assert response.limit == 25
|
||||
assert response.total_pages == 5
|
||||
assert response.ledger_entries[0].id == 99
|
||||
|
|
@ -9,7 +9,7 @@ Module under test: api.services.configuration.resolve
|
|||
|
||||
import pytest
|
||||
|
||||
from api.schemas.user_configuration import EffectiveAIModelConfiguration
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.configuration.masking import (
|
||||
contains_masked_key,
|
||||
mask_workflow_configurations,
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from api.services.pricing.run_usage_response import format_public_usage_info
|
||||
from api.services.workflow.run_usage_response import format_public_usage_info
|
||||
|
||||
|
||||
def test_format_public_usage_info():
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from pipecat.processors.frame_processor import FrameDirection
|
|||
from websockets.exceptions import ConnectionClosedError
|
||||
from websockets.frames import Close
|
||||
|
||||
from api.schemas.user_configuration import EffectiveAIModelConfiguration
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.configuration.registry import UltravoxRealtimeLLMConfiguration
|
||||
from api.services.pipecat.realtime.ultravox_realtime import (
|
||||
_RESUMPTION_USER_MESSAGE,
|
||||
|
|
@ -430,7 +430,7 @@ async def test_receive_messages_reports_unexpected_websocket_close():
|
|||
|
||||
|
||||
def test_factory_creates_dograh_ultravox_realtime_service():
|
||||
user_config = EffectiveAIModelConfiguration(
|
||||
effective_config = EffectiveAIModelConfiguration(
|
||||
is_realtime=True,
|
||||
realtime=UltravoxRealtimeLLMConfiguration(
|
||||
provider="ultravox_realtime",
|
||||
|
|
@ -441,7 +441,7 @@ def test_factory_creates_dograh_ultravox_realtime_service():
|
|||
)
|
||||
|
||||
service = create_realtime_llm_service(
|
||||
user_config,
|
||||
effective_config,
|
||||
audio_config=SimpleNamespace(),
|
||||
)
|
||||
|
||||
|
|
|
|||
212
api/tests/test_workflow_run_billing.py
Normal file
212
api/tests/test_workflow_run_billing.py
Normal file
|
|
@ -0,0 +1,212 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from api.services import workflow_run_billing as workflow_run_billing_mod
|
||||
from api.services.workflow_run_billing import (
|
||||
report_completed_workflow_run_platform_usage,
|
||||
report_workflow_run_platform_usage,
|
||||
)
|
||||
|
||||
|
||||
def _make_workflow_run():
|
||||
return SimpleNamespace(
|
||||
id=123,
|
||||
workflow_id=456,
|
||||
is_completed=True,
|
||||
initial_context={"mps_correlation_id": "mps-corr-123"},
|
||||
usage_info={"call_duration_seconds": 87},
|
||||
workflow=SimpleNamespace(
|
||||
organization_id=42,
|
||||
user=SimpleNamespace(selected_organization_id=42),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_report_workflow_run_platform_usage_reports_hosted_completion(
|
||||
monkeypatch,
|
||||
):
|
||||
workflow_run = _make_workflow_run()
|
||||
get_status = AsyncMock(return_value={"billing_mode": "v2"})
|
||||
report_usage = AsyncMock(return_value={"metered": True})
|
||||
|
||||
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
|
||||
monkeypatch.setattr(
|
||||
workflow_run_billing_mod.mps_service_key_client,
|
||||
"get_billing_account_status",
|
||||
get_status,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_run_billing_mod.mps_service_key_client,
|
||||
"report_platform_usage",
|
||||
report_usage,
|
||||
)
|
||||
|
||||
await report_workflow_run_platform_usage(workflow_run)
|
||||
|
||||
report_usage.assert_awaited_once_with(
|
||||
organization_id=42,
|
||||
correlation_id="mps-corr-123",
|
||||
duration_seconds=None,
|
||||
workflow_run_id=workflow_run.id,
|
||||
metadata={
|
||||
"source": "workflow_run_completion",
|
||||
"workflow_id": workflow_run.workflow_id,
|
||||
"duration_source": "mps_correlation",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_report_workflow_run_platform_usage_reports_duration_without_correlation(
|
||||
monkeypatch,
|
||||
):
|
||||
workflow_run = _make_workflow_run()
|
||||
workflow_run.initial_context = {}
|
||||
get_status = AsyncMock(return_value={"billing_mode": "v2"})
|
||||
report_usage = AsyncMock(return_value={"metered": True})
|
||||
|
||||
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
|
||||
monkeypatch.setattr(
|
||||
workflow_run_billing_mod.mps_service_key_client,
|
||||
"get_billing_account_status",
|
||||
get_status,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_run_billing_mod.mps_service_key_client,
|
||||
"report_platform_usage",
|
||||
report_usage,
|
||||
)
|
||||
|
||||
await report_workflow_run_platform_usage(workflow_run)
|
||||
|
||||
report_usage.assert_awaited_once_with(
|
||||
organization_id=42,
|
||||
correlation_id=None,
|
||||
duration_seconds=87.0,
|
||||
workflow_run_id=workflow_run.id,
|
||||
metadata={
|
||||
"source": "workflow_run_completion",
|
||||
"workflow_id": workflow_run.workflow_id,
|
||||
"duration_source": "dograh_usage_info",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_report_workflow_run_platform_usage_skips_non_v2_account(monkeypatch):
|
||||
workflow_run = _make_workflow_run()
|
||||
get_status = AsyncMock(return_value={"billing_mode": "v1"})
|
||||
report_usage = AsyncMock()
|
||||
|
||||
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
|
||||
monkeypatch.setattr(
|
||||
workflow_run_billing_mod.mps_service_key_client,
|
||||
"get_billing_account_status",
|
||||
get_status,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_run_billing_mod.mps_service_key_client,
|
||||
"report_platform_usage",
|
||||
report_usage,
|
||||
)
|
||||
|
||||
await report_workflow_run_platform_usage(workflow_run)
|
||||
|
||||
get_status.assert_awaited_once_with(organization_id=42)
|
||||
report_usage.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_report_workflow_run_platform_usage_skips_missing_duration_without_correlation(
|
||||
monkeypatch,
|
||||
):
|
||||
workflow_run = _make_workflow_run()
|
||||
workflow_run.initial_context = {}
|
||||
workflow_run.usage_info = {}
|
||||
get_status = AsyncMock(return_value={"billing_mode": "v2"})
|
||||
report_usage = AsyncMock()
|
||||
|
||||
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
|
||||
monkeypatch.setattr(
|
||||
workflow_run_billing_mod.mps_service_key_client,
|
||||
"get_billing_account_status",
|
||||
get_status,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_run_billing_mod.mps_service_key_client,
|
||||
"report_platform_usage",
|
||||
report_usage,
|
||||
)
|
||||
|
||||
await report_workflow_run_platform_usage(workflow_run)
|
||||
|
||||
get_status.assert_not_awaited()
|
||||
report_usage.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_report_workflow_run_platform_usage_skips_oss(monkeypatch):
|
||||
workflow_run = _make_workflow_run()
|
||||
report_usage = AsyncMock()
|
||||
|
||||
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "oss")
|
||||
monkeypatch.setattr(
|
||||
workflow_run_billing_mod.mps_service_key_client,
|
||||
"report_platform_usage",
|
||||
report_usage,
|
||||
)
|
||||
|
||||
await report_workflow_run_platform_usage(workflow_run)
|
||||
|
||||
report_usage.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_report_workflow_run_platform_usage_skips_incomplete(monkeypatch):
|
||||
workflow_run = _make_workflow_run()
|
||||
workflow_run.is_completed = False
|
||||
report_usage = AsyncMock()
|
||||
|
||||
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
|
||||
monkeypatch.setattr(
|
||||
workflow_run_billing_mod.mps_service_key_client,
|
||||
"report_platform_usage",
|
||||
report_usage,
|
||||
)
|
||||
|
||||
await report_workflow_run_platform_usage(workflow_run)
|
||||
|
||||
report_usage.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_report_completed_workflow_run_platform_usage_loads_run(monkeypatch):
|
||||
workflow_run = _make_workflow_run()
|
||||
get_run = AsyncMock(return_value=workflow_run)
|
||||
get_status = AsyncMock(return_value={"billing_mode": "v2"})
|
||||
report_usage = AsyncMock(return_value={"metered": True})
|
||||
|
||||
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
|
||||
monkeypatch.setattr(
|
||||
workflow_run_billing_mod.db_client,
|
||||
"get_workflow_run_by_id",
|
||||
get_run,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_run_billing_mod.mps_service_key_client,
|
||||
"get_billing_account_status",
|
||||
get_status,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_run_billing_mod.mps_service_key_client,
|
||||
"report_platform_usage",
|
||||
report_usage,
|
||||
)
|
||||
|
||||
await report_completed_workflow_run_platform_usage(workflow_run.id)
|
||||
|
||||
get_run.assert_awaited_once_with(workflow_run.id)
|
||||
report_usage.assert_awaited_once()
|
||||
|
|
@ -1,181 +0,0 @@
|
|||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from api.services.pricing import workflow_run_cost as workflow_run_cost_mod
|
||||
from api.services.pricing.workflow_run_cost import (
|
||||
apply_usage_delta_to_organization,
|
||||
build_workflow_run_cost_info,
|
||||
calculate_workflow_run_cost,
|
||||
)
|
||||
|
||||
|
||||
def _make_workflow_run():
|
||||
return SimpleNamespace(
|
||||
id=123,
|
||||
workflow_id=456,
|
||||
mode="textchat",
|
||||
created_at=datetime.now(UTC),
|
||||
usage_info={
|
||||
"llm": {},
|
||||
"tts": {},
|
||||
"stt": {},
|
||||
"call_duration_seconds": 7,
|
||||
},
|
||||
cost_info={},
|
||||
workflow=SimpleNamespace(
|
||||
organization_id=42,
|
||||
user=SimpleNamespace(selected_organization_id=42),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_workflow_run_cost_info_does_not_update_org_usage(monkeypatch):
|
||||
workflow_run = _make_workflow_run()
|
||||
get_org = AsyncMock(return_value=SimpleNamespace(id=42, price_per_second_usd=1.5))
|
||||
update_usage = AsyncMock()
|
||||
|
||||
monkeypatch.setattr(
|
||||
workflow_run_cost_mod.db_client, "get_organization_by_id", get_org
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_run_cost_mod.db_client, "update_usage_after_run", update_usage
|
||||
)
|
||||
|
||||
cost_info = await build_workflow_run_cost_info(workflow_run)
|
||||
|
||||
assert cost_info is not None
|
||||
assert cost_info["call_duration_seconds"] == 7
|
||||
assert "cost_breakdown" in cost_info
|
||||
assert "dograh_token_usage" in cost_info
|
||||
assert cost_info["charge_usd"] == 10.5
|
||||
update_usage.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_workflow_run_cost_keeps_org_usage_side_effect_in_wrapper(
|
||||
monkeypatch,
|
||||
):
|
||||
workflow_run = _make_workflow_run()
|
||||
get_org = AsyncMock(return_value=SimpleNamespace(id=42, price_per_second_usd=None))
|
||||
update_run = AsyncMock()
|
||||
update_usage = AsyncMock()
|
||||
|
||||
monkeypatch.setattr(
|
||||
workflow_run_cost_mod.db_client,
|
||||
"get_workflow_run_by_id",
|
||||
AsyncMock(return_value=workflow_run),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_run_cost_mod.db_client, "get_organization_by_id", get_org
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_run_cost_mod.db_client, "update_workflow_run", update_run
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_run_cost_mod.db_client, "update_usage_after_run", update_usage
|
||||
)
|
||||
|
||||
await calculate_workflow_run_cost(workflow_run.id)
|
||||
|
||||
update_run.assert_awaited_once()
|
||||
saved_kwargs = update_run.await_args.kwargs
|
||||
assert saved_kwargs["run_id"] == workflow_run.id
|
||||
assert "cost_breakdown" in saved_kwargs["cost_info"]
|
||||
update_usage.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_apply_usage_delta_to_organization_uses_incremental_costs(
|
||||
monkeypatch,
|
||||
):
|
||||
workflow_run = _make_workflow_run()
|
||||
workflow_run.cost_info = {"call_id": "preserve-me"}
|
||||
|
||||
usage_delta_one = {
|
||||
"llm": {
|
||||
"OpenAILLMService#0|||gpt-4.1-mini": {
|
||||
"prompt_tokens": 1_000,
|
||||
"completion_tokens": 100,
|
||||
"total_tokens": 1_100,
|
||||
"cache_read_input_tokens": 0,
|
||||
"cache_creation_input_tokens": 0,
|
||||
}
|
||||
},
|
||||
"tts": {},
|
||||
"stt": {},
|
||||
"call_duration_seconds": 3,
|
||||
}
|
||||
usage_delta_two = {
|
||||
"llm": {
|
||||
"OpenAILLMService#0|||gpt-4.1-mini": {
|
||||
"prompt_tokens": 2_000,
|
||||
"completion_tokens": 50,
|
||||
"total_tokens": 2_050,
|
||||
"cache_read_input_tokens": 0,
|
||||
"cache_creation_input_tokens": 0,
|
||||
}
|
||||
},
|
||||
"tts": {},
|
||||
"stt": {},
|
||||
"call_duration_seconds": 4,
|
||||
}
|
||||
merged_usage = {
|
||||
"llm": {
|
||||
"OpenAILLMService#0|||gpt-4.1-mini": {
|
||||
"prompt_tokens": 3_000,
|
||||
"completion_tokens": 150,
|
||||
"total_tokens": 3_150,
|
||||
"cache_read_input_tokens": 0,
|
||||
"cache_creation_input_tokens": 0,
|
||||
}
|
||||
},
|
||||
"tts": {},
|
||||
"stt": {},
|
||||
"call_duration_seconds": 7,
|
||||
}
|
||||
|
||||
get_org = AsyncMock(return_value=SimpleNamespace(id=42, price_per_second_usd=1.5))
|
||||
update_usage = AsyncMock()
|
||||
|
||||
monkeypatch.setattr(
|
||||
workflow_run_cost_mod.db_client, "get_organization_by_id", get_org
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_run_cost_mod.db_client, "update_usage_after_run", update_usage
|
||||
)
|
||||
|
||||
first_delta = await apply_usage_delta_to_organization(workflow_run, usage_delta_one)
|
||||
second_delta = await apply_usage_delta_to_organization(
|
||||
workflow_run, usage_delta_two
|
||||
)
|
||||
total_workflow_run = SimpleNamespace(**workflow_run.__dict__)
|
||||
total_workflow_run.usage_info = merged_usage
|
||||
total_cost = await build_workflow_run_cost_info(total_workflow_run)
|
||||
|
||||
assert first_delta is not None
|
||||
assert second_delta is not None
|
||||
assert total_cost is not None
|
||||
assert update_usage.await_count == 2
|
||||
assert update_usage.await_args_list[0].args == (
|
||||
42,
|
||||
first_delta["dograh_token_usage"],
|
||||
3.0,
|
||||
first_delta["charge_usd"],
|
||||
)
|
||||
assert update_usage.await_args_list[1].args == (
|
||||
42,
|
||||
second_delta["dograh_token_usage"],
|
||||
4.0,
|
||||
second_delta["charge_usd"],
|
||||
)
|
||||
assert (
|
||||
first_delta["dograh_token_usage"] + second_delta["dograh_token_usage"]
|
||||
) == pytest.approx(total_cost["dograh_token_usage"])
|
||||
assert (
|
||||
first_delta["charge_usd"] + second_delta["charge_usd"]
|
||||
== total_cost["charge_usd"]
|
||||
)
|
||||
|
|
@ -4,7 +4,7 @@ from unittest.mock import AsyncMock, patch
|
|||
import pytest
|
||||
|
||||
from api.db.models import OrganizationModel, UserModel
|
||||
from api.schemas.user_configuration import EffectiveAIModelConfiguration
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
from api.tests.integrations._run_pipeline_helpers import USER_CONFIGURATION
|
||||
from pipecat.tests import MockLLMService
|
||||
|
||||
|
|
@ -176,11 +176,7 @@ async def test_text_chat_session_creation_executes_initial_assistant_turn(
|
|||
assert "Start" in (created["gathered_context"] or {}).get("nodes_visited", [])
|
||||
workflow_run = await db_session.get_workflow_run_by_id(created["workflow_run_id"])
|
||||
assert workflow_run is not None
|
||||
assert workflow_run.cost_info[
|
||||
"call_duration_seconds"
|
||||
] == workflow_run.usage_info.get("call_duration_seconds", 0)
|
||||
assert "cost_breakdown" in workflow_run.cost_info
|
||||
assert "dograh_token_usage" in workflow_run.cost_info
|
||||
assert "call_duration_seconds" in workflow_run.usage_info
|
||||
assert _log_texts(run_payload["logs"], "rtf-bot-text") == [
|
||||
"Hello from the workflow tester."
|
||||
]
|
||||
|
|
@ -296,11 +292,7 @@ async def test_text_chat_message_executes_assistant_turn(
|
|||
assert "Start" in (payload["gathered_context"] or {}).get("nodes_visited", [])
|
||||
workflow_run = await db_session.get_workflow_run_by_id(created["workflow_run_id"])
|
||||
assert workflow_run is not None
|
||||
assert workflow_run.cost_info[
|
||||
"call_duration_seconds"
|
||||
] == workflow_run.usage_info.get("call_duration_seconds", 0)
|
||||
assert "cost_breakdown" in workflow_run.cost_info
|
||||
assert "dograh_token_usage" in workflow_run.cost_info
|
||||
assert "call_duration_seconds" in workflow_run.usage_info
|
||||
assert _log_texts(run_payload["logs"], "rtf-user-transcription") == ["Hi there"]
|
||||
assert _log_texts(run_payload["logs"], "rtf-bot-text") == [
|
||||
"Welcome to the workflow tester.",
|
||||
|
|
|
|||
2
pipecat
2
pipecat
|
|
@ -1 +1 @@
|
|||
Subproject commit 228324a146a6765c6b8d610963bc80d7bc8cb9f7
|
||||
Subproject commit 0d64dc6e0e3e6b3c46cc66373e34b4f54f980268
|
||||
416
ui/src/app/billing/page.tsx
Normal file
416
ui/src/app/billing/page.tsx
Normal file
|
|
@ -0,0 +1,416 @@
|
|||
"use client";
|
||||
|
||||
import {
|
||||
ChevronLeft,
|
||||
ChevronRight,
|
||||
CircleDollarSign,
|
||||
CreditCard,
|
||||
RefreshCw,
|
||||
} from "lucide-react";
|
||||
import Link from "next/link";
|
||||
import { useRouter, useSearchParams } from "next/navigation";
|
||||
import { useCallback, useEffect, useMemo, useState } from "react";
|
||||
import { toast } from "sonner";
|
||||
|
||||
import { createMpsCreditPurchaseUrlApiV1OrganizationsUsageMpsCreditsPurchaseUrlPost, getBillingCreditsApiV1OrganizationsBillingCreditsGet } from "@/client/sdk.gen";
|
||||
import type { MpsBillingCreditsResponse, MpsCreditLedgerEntryResponse } from "@/client/types.gen";
|
||||
import { Badge } from "@/components/ui/badge";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card";
|
||||
import { Progress } from "@/components/ui/progress";
|
||||
import { Skeleton } from "@/components/ui/skeleton";
|
||||
import {
|
||||
Table,
|
||||
TableBody,
|
||||
TableCell,
|
||||
TableHead,
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from "@/components/ui/table";
|
||||
import { useAppConfig } from "@/context/AppConfigContext";
|
||||
import { useAuth } from "@/lib/auth";
|
||||
|
||||
const LEDGER_PAGE_SIZE = 50;
|
||||
|
||||
const formatCredits = (value: number | null | undefined) => (
|
||||
(value ?? 0).toLocaleString(undefined, {
|
||||
maximumFractionDigits: 2,
|
||||
minimumFractionDigits: 0,
|
||||
})
|
||||
);
|
||||
|
||||
const formatAmount = (amountMinor?: number | null, currency?: string | null) => {
|
||||
if (amountMinor == null) {
|
||||
return "-";
|
||||
}
|
||||
|
||||
return new Intl.NumberFormat(undefined, {
|
||||
style: "currency",
|
||||
currency: currency || "USD",
|
||||
}).format(amountMinor / 100);
|
||||
};
|
||||
|
||||
const formatDate = (value: string) => (
|
||||
new Date(value).toLocaleString(undefined, {
|
||||
month: "short",
|
||||
day: "numeric",
|
||||
year: "numeric",
|
||||
hour: "2-digit",
|
||||
minute: "2-digit",
|
||||
})
|
||||
);
|
||||
|
||||
const metricLabels: Record<string, string> = {
|
||||
voice_minutes: "Voice usage",
|
||||
platform_usage: "Platform usage",
|
||||
};
|
||||
|
||||
const formatTitleCase = (value: string | null | undefined) => (
|
||||
value ? value.replaceAll("_", " ").replace(/\b\w/g, (letter) => letter.toUpperCase()) : "-"
|
||||
);
|
||||
|
||||
const getLedgerEntryLabel = (entry: MpsCreditLedgerEntryResponse) => {
|
||||
if (entry.metric_code) {
|
||||
return metricLabels[entry.metric_code] ?? formatTitleCase(entry.metric_code);
|
||||
}
|
||||
|
||||
if (entry.entry_type === "grant") {
|
||||
return "Credit grant";
|
||||
}
|
||||
|
||||
if (entry.entry_type === "purchase") {
|
||||
return "Credit purchase";
|
||||
}
|
||||
|
||||
return formatTitleCase(entry.entry_type);
|
||||
};
|
||||
|
||||
const formatBillableQuantity = (entry: MpsCreditLedgerEntryResponse) => {
|
||||
if (entry.billable_quantity == null || !entry.quantity_unit) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const unit = entry.quantity_unit === "minute" ? "min" : entry.quantity_unit;
|
||||
return `${formatCredits(entry.billable_quantity)} ${unit}`;
|
||||
};
|
||||
|
||||
const getRunHref = (entry: MpsCreditLedgerEntryResponse) => {
|
||||
if (!entry.workflow_id || !entry.workflow_run_id) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return `/workflow/${entry.workflow_id}/run/${entry.workflow_run_id}`;
|
||||
};
|
||||
|
||||
const getPageFromSearchParams = (
|
||||
searchParams: { get: (name: string) => string | null },
|
||||
) => {
|
||||
const pageParam = searchParams.get("page");
|
||||
const page = pageParam ? Number.parseInt(pageParam, 10) : 1;
|
||||
return Number.isFinite(page) && page > 0 ? page : 1;
|
||||
};
|
||||
|
||||
export default function BillingPage() {
|
||||
const router = useRouter();
|
||||
const searchParams = useSearchParams();
|
||||
const auth = useAuth();
|
||||
const { config } = useAppConfig();
|
||||
const [credits, setCredits] = useState<MpsBillingCreditsResponse | null>(null);
|
||||
const [loading, setLoading] = useState(true);
|
||||
const [refreshing, setRefreshing] = useState(false);
|
||||
const [purchasing, setPurchasing] = useState(false);
|
||||
const [currentPage, setCurrentPage] = useState(
|
||||
() => getPageFromSearchParams(searchParams),
|
||||
);
|
||||
|
||||
const isBillingV2 = credits?.billing_version === "v2";
|
||||
const canPurchaseCredits = isBillingV2 && config?.deploymentMode !== "oss";
|
||||
const totalQuota = credits?.total_quota ?? 0;
|
||||
const remainingCredits = credits?.remaining_credits ?? 0;
|
||||
const usedCredits = credits?.total_credits_used ?? 0;
|
||||
const usagePercent = totalQuota > 0 ? Math.min(100, Math.round((usedCredits / totalQuota) * 100)) : 0;
|
||||
|
||||
const ledgerEntries = useMemo(() => credits?.ledger_entries ?? [], [credits?.ledger_entries]);
|
||||
const ledgerPage = credits?.page ?? currentPage;
|
||||
const ledgerTotalCount = credits?.total_count ?? ledgerEntries.length;
|
||||
const ledgerTotalPages = credits?.total_pages ?? 0;
|
||||
|
||||
const fetchCredits = useCallback(async (
|
||||
page: number,
|
||||
{ silent = false }: { silent?: boolean } = {},
|
||||
) => {
|
||||
if (auth.loading) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!auth.isAuthenticated) {
|
||||
setLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
if (silent) {
|
||||
setRefreshing(true);
|
||||
} else {
|
||||
setLoading(true);
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await getBillingCreditsApiV1OrganizationsBillingCreditsGet({
|
||||
query: { page, limit: LEDGER_PAGE_SIZE },
|
||||
});
|
||||
|
||||
if (response.error) {
|
||||
throw new Error("Failed to fetch billing credits");
|
||||
}
|
||||
|
||||
setCredits(response.data ?? null);
|
||||
} catch (error) {
|
||||
console.error("Failed to fetch billing credits:", error);
|
||||
toast.error("Failed to fetch billing credits");
|
||||
} finally {
|
||||
setLoading(false);
|
||||
setRefreshing(false);
|
||||
}
|
||||
}, [auth.isAuthenticated, auth.loading]);
|
||||
|
||||
useEffect(() => {
|
||||
const nextPage = getPageFromSearchParams(searchParams);
|
||||
setCurrentPage((previousPage) => (
|
||||
previousPage === nextPage ? previousPage : nextPage
|
||||
));
|
||||
}, [searchParams]);
|
||||
|
||||
useEffect(() => {
|
||||
fetchCredits(currentPage);
|
||||
}, [currentPage, fetchCredits]);
|
||||
|
||||
const handleRefresh = () => {
|
||||
fetchCredits(currentPage, { silent: true });
|
||||
};
|
||||
|
||||
const updateUrlPage = useCallback((page: number) => {
|
||||
const newParams = new URLSearchParams(searchParams.toString());
|
||||
if (page > 1) {
|
||||
newParams.set("page", page.toString());
|
||||
} else {
|
||||
newParams.delete("page");
|
||||
}
|
||||
|
||||
const queryString = newParams.toString();
|
||||
router.push(queryString ? `/billing?${queryString}` : "/billing");
|
||||
}, [router, searchParams]);
|
||||
|
||||
const handlePageChange = (page: number) => {
|
||||
const nextPage = Math.max(1, page);
|
||||
setCurrentPage(nextPage);
|
||||
updateUrlPage(nextPage);
|
||||
};
|
||||
|
||||
const handlePurchaseCredits = async () => {
|
||||
if (!canPurchaseCredits) {
|
||||
return;
|
||||
}
|
||||
|
||||
setPurchasing(true);
|
||||
try {
|
||||
const response = await createMpsCreditPurchaseUrlApiV1OrganizationsUsageMpsCreditsPurchaseUrlPost();
|
||||
const checkoutUrl = response.data?.checkout_url;
|
||||
if (!checkoutUrl) {
|
||||
throw new Error("Missing checkout URL");
|
||||
}
|
||||
window.location.href = checkoutUrl;
|
||||
} catch (error) {
|
||||
console.error("Failed to create credit purchase URL:", error);
|
||||
toast.error("Failed to open checkout");
|
||||
setPurchasing(false);
|
||||
}
|
||||
};
|
||||
|
||||
if (loading) {
|
||||
return (
|
||||
<div className="container mx-auto p-6 space-y-6">
|
||||
<div className="space-y-2">
|
||||
<Skeleton className="h-9 w-40" />
|
||||
<Skeleton className="h-5 w-96 max-w-full" />
|
||||
</div>
|
||||
<div className="grid gap-4 md:grid-cols-2">
|
||||
<Skeleton className="h-36 rounded-lg" />
|
||||
<Skeleton className="h-36 rounded-lg" />
|
||||
</div>
|
||||
<Skeleton className="h-80 rounded-lg" />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="container mx-auto p-6 space-y-6">
|
||||
<div className="flex flex-col gap-4 md:flex-row md:items-start md:justify-between">
|
||||
<div>
|
||||
<h1 className="text-3xl font-bold mb-2">Billing</h1>
|
||||
<p className="text-muted-foreground">
|
||||
Credits, balance, and account usage for your organization.
|
||||
</p>
|
||||
</div>
|
||||
<div className="flex items-center gap-2">
|
||||
<Button variant="outline" onClick={handleRefresh} disabled={refreshing}>
|
||||
<RefreshCw className={`h-4 w-4 mr-2 ${refreshing ? "animate-spin" : ""}`} />
|
||||
Refresh
|
||||
</Button>
|
||||
{canPurchaseCredits && (
|
||||
<Button onClick={handlePurchaseCredits} disabled={purchasing}>
|
||||
<CreditCard className="h-4 w-4 mr-2" />
|
||||
{purchasing ? "Opening..." : "Add Credits"}
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="grid gap-4 md:grid-cols-2">
|
||||
<Card>
|
||||
<CardHeader className="pb-2">
|
||||
<CardDescription>{isBillingV2 ? "Credit balance" : "Credits remaining"}</CardDescription>
|
||||
<CardTitle className="flex items-center gap-2 text-3xl">
|
||||
<CircleDollarSign className="h-6 w-6 text-muted-foreground" />
|
||||
{formatCredits(remainingCredits)}
|
||||
</CardTitle>
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
<p className="text-sm text-muted-foreground">1 credit = 1 cent</p>
|
||||
</CardContent>
|
||||
</Card>
|
||||
|
||||
<Card>
|
||||
<CardHeader className="pb-2">
|
||||
<CardDescription>Credits used</CardDescription>
|
||||
<CardTitle className="text-3xl">{formatCredits(usedCredits)}</CardTitle>
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
<p className="text-sm text-muted-foreground">
|
||||
{isBillingV2 ? "Total ledger debits" : "Current allocation usage"}
|
||||
</p>
|
||||
</CardContent>
|
||||
</Card>
|
||||
</div>
|
||||
|
||||
{isBillingV2 ? (
|
||||
<Card>
|
||||
<CardHeader>
|
||||
<CardTitle>Credit Ledger</CardTitle>
|
||||
<CardDescription>Recent grants, purchases, and usage debits.</CardDescription>
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
{ledgerEntries.length > 0 ? (
|
||||
<div className="bg-card border rounded-lg overflow-x-auto shadow-sm">
|
||||
<Table>
|
||||
<TableHeader>
|
||||
<TableRow className="bg-muted/50">
|
||||
<TableHead>Date</TableHead>
|
||||
<TableHead>Activity</TableHead>
|
||||
<TableHead>Origin</TableHead>
|
||||
<TableHead>Run</TableHead>
|
||||
<TableHead className="text-right">Delta</TableHead>
|
||||
<TableHead className="text-right">Balance</TableHead>
|
||||
<TableHead className="text-right">Amount</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{ledgerEntries.map((entry) => {
|
||||
const delta = entry.credits_delta ?? 0;
|
||||
const runHref = getRunHref(entry);
|
||||
const billableQuantity = formatBillableQuantity(entry);
|
||||
return (
|
||||
<TableRow key={entry.id}>
|
||||
<TableCell>{formatDate(entry.created_at)}</TableCell>
|
||||
<TableCell>
|
||||
<div className="flex flex-col gap-1">
|
||||
<span className="font-medium">{getLedgerEntryLabel(entry)}</span>
|
||||
{billableQuantity && (
|
||||
<span className="text-xs text-muted-foreground">{billableQuantity}</span>
|
||||
)}
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
{entry.origin ? (
|
||||
<Badge variant="secondary">{formatTitleCase(entry.origin)}</Badge>
|
||||
) : (
|
||||
"-"
|
||||
)}
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
{entry.workflow_run_id ? (
|
||||
runHref ? (
|
||||
<Link className="font-medium text-primary hover:underline" href={runHref}>
|
||||
#{entry.workflow_run_id}
|
||||
</Link>
|
||||
) : (
|
||||
<span>#{entry.workflow_run_id}</span>
|
||||
)
|
||||
) : (
|
||||
"-"
|
||||
)}
|
||||
</TableCell>
|
||||
<TableCell className={`text-right font-medium ${delta >= 0 ? "text-green-600" : "text-destructive"}`}>
|
||||
{delta >= 0 ? "+" : ""}
|
||||
{formatCredits(delta)}
|
||||
</TableCell>
|
||||
<TableCell className="text-right">{formatCredits(entry.balance_after)}</TableCell>
|
||||
<TableCell className="text-right">
|
||||
{formatAmount(entry.amount_minor, entry.amount_currency)}
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
);
|
||||
})}
|
||||
</TableBody>
|
||||
</Table>
|
||||
</div>
|
||||
) : (
|
||||
<div className="rounded-lg border border-dashed p-8 text-center text-muted-foreground">
|
||||
No ledger entries yet
|
||||
</div>
|
||||
)}
|
||||
{ledgerTotalPages > 1 && (
|
||||
<div className="flex items-center justify-between mt-6">
|
||||
<p className="text-sm text-muted-foreground">
|
||||
Page {ledgerPage} of {ledgerTotalPages} ({ledgerTotalCount} total entries)
|
||||
</p>
|
||||
<div className="flex gap-2">
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={() => handlePageChange(ledgerPage - 1)}
|
||||
disabled={ledgerPage <= 1 || loading || refreshing}
|
||||
>
|
||||
<ChevronLeft className="h-4 w-4" />
|
||||
Previous
|
||||
</Button>
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={() => handlePageChange(ledgerPage + 1)}
|
||||
disabled={ledgerPage >= ledgerTotalPages || loading || refreshing}
|
||||
>
|
||||
Next
|
||||
<ChevronRight className="h-4 w-4" />
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</CardContent>
|
||||
</Card>
|
||||
) : (
|
||||
<Card>
|
||||
<CardHeader>
|
||||
<CardTitle>Credit Usage</CardTitle>
|
||||
</CardHeader>
|
||||
<CardContent className="space-y-4">
|
||||
<Progress value={usagePercent} />
|
||||
<div className="flex justify-between text-sm text-muted-foreground">
|
||||
<span>{usagePercent}% used</span>
|
||||
<span>{formatCredits(remainingCredits)} of {formatCredits(totalQuota)} remaining</span>
|
||||
</div>
|
||||
</CardContent>
|
||||
</Card>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
|
@ -12,8 +12,8 @@ import SpinLoader from "@/components/SpinLoader";
|
|||
import { Toaster } from "@/components/ui/sonner";
|
||||
import { AppConfigProvider } from "@/context/AppConfigContext";
|
||||
import { OnboardingProvider } from "@/context/OnboardingContext";
|
||||
import { OrgConfigProvider } from "@/context/OrgConfigContext";
|
||||
import { TelephonyConfigWarningsProvider } from "@/context/TelephonyConfigWarningsContext";
|
||||
import { UserConfigProvider } from "@/context/UserConfigContext";
|
||||
import { AuthProvider } from "@/lib/auth";
|
||||
|
||||
|
||||
|
|
@ -65,7 +65,7 @@ export default function RootLayout({
|
|||
<AuthProvider>
|
||||
<AppConfigProvider>
|
||||
<Suspense fallback={<SpinLoader />}>
|
||||
<UserConfigProvider>
|
||||
<OrgConfigProvider>
|
||||
<TelephonyConfigWarningsProvider>
|
||||
<OnboardingProvider>
|
||||
<PostHogIdentify />
|
||||
|
|
@ -76,7 +76,7 @@ export default function RootLayout({
|
|||
<ChatwootWidget />
|
||||
</OnboardingProvider>
|
||||
</TelephonyConfigWarningsProvider>
|
||||
</UserConfigProvider>
|
||||
</OrgConfigProvider>
|
||||
</Suspense>
|
||||
</AppConfigProvider>
|
||||
</AuthProvider>
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
import { addDays, format, subDays } from 'date-fns';
|
||||
import { Calendar, ChevronLeft, ChevronRight, Download } from 'lucide-react';
|
||||
import { useEffect,useState } from 'react';
|
||||
import { useEffect, useState } from 'react';
|
||||
|
||||
import {
|
||||
getDailyReportApiV1OrganizationsReportsDailyGet,
|
||||
|
|
@ -201,7 +201,9 @@ export default function ReportsPage() {
|
|||
<div className="container mx-auto p-6 space-y-6">
|
||||
{/* Header */}
|
||||
<div className="flex flex-col sm:flex-row justify-between items-start sm:items-center gap-4">
|
||||
<h1 className="text-3xl font-bold">Daily Reports</h1>
|
||||
<div className="space-y-2">
|
||||
<h1 className="text-3xl font-bold">Daily Reports</h1>
|
||||
</div>
|
||||
|
||||
{/* Date Navigation & Workflow Selector */}
|
||||
<div className="flex flex-col sm:flex-row gap-4 items-start sm:items-center">
|
||||
|
|
|
|||
|
|
@ -6,8 +6,8 @@ import { useCallback, useEffect, useId, useState } from 'react';
|
|||
import TimezoneSelect, { type ITimezoneOption } from 'react-timezone-select';
|
||||
import { toast } from 'sonner';
|
||||
|
||||
import { downloadUsageRunsReportApiV1OrganizationsUsageRunsReportGet, getDailyUsageBreakdownApiV1OrganizationsUsageDailyBreakdownGet, getMpsCreditsApiV1OrganizationsUsageMpsCreditsGet, getPreferencesApiV1OrganizationsPreferencesGet, getUsageHistoryApiV1OrganizationsUsageRunsGet, savePreferencesApiV1OrganizationsPreferencesPut } from '@/client/sdk.gen';
|
||||
import type { DailyUsageBreakdownResponse, MpsCreditsResponse, OrganizationPreferences, UsageHistoryResponse, WorkflowRunUsageResponse } from '@/client/types.gen';
|
||||
import { downloadUsageRunsReportApiV1OrganizationsUsageRunsReportGet, getDailyUsageBreakdownApiV1OrganizationsUsageDailyBreakdownGet, getPreferencesApiV1OrganizationsPreferencesGet, getUsageHistoryApiV1OrganizationsUsageRunsGet, savePreferencesApiV1OrganizationsPreferencesPut } from '@/client/sdk.gen';
|
||||
import type { DailyUsageBreakdownResponse, OrganizationPreferences, UsageHistoryResponse, WorkflowRunUsageResponse } from '@/client/types.gen';
|
||||
import { CallTypeCell } from '@/components/CallTypeCell';
|
||||
import { DailyUsageTable } from '@/components/DailyUsageTable';
|
||||
import { FilterBuilder } from '@/components/filters/FilterBuilder';
|
||||
|
|
@ -15,7 +15,6 @@ import { MediaPreviewButton, MediaPreviewDialog } from '@/components/MediaPrevie
|
|||
import { Badge } from '@/components/ui/badge';
|
||||
import { Button } from '@/components/ui/button';
|
||||
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card';
|
||||
import { Progress } from '@/components/ui/progress';
|
||||
import {
|
||||
Table,
|
||||
TableBody,
|
||||
|
|
@ -39,10 +38,6 @@ export default function UsagePage() {
|
|||
const { organizationPricing } = useUserConfig();
|
||||
const auth = useAuth();
|
||||
|
||||
// MPS credits state
|
||||
const [mpsCredits, setMpsCredits] = useState<MpsCreditsResponse | null>(null);
|
||||
const [isLoadingCredits, setIsLoadingCredits] = useState(true);
|
||||
|
||||
// Usage history state
|
||||
const [usageHistory, setUsageHistory] = useState<UsageHistoryResponse | null>(null);
|
||||
const [isLoadingHistory, setIsLoadingHistory] = useState(false);
|
||||
|
|
@ -78,21 +73,6 @@ export default function UsagePage() {
|
|||
const [preferencesLoading, setPreferencesLoading] = useState(true);
|
||||
const timezoneSelectId = useId(); // Stable ID for react-select to prevent hydration mismatch
|
||||
|
||||
// Fetch MPS credits
|
||||
const fetchMpsCredits = useCallback(async () => {
|
||||
if (!auth.isAuthenticated) return;
|
||||
try {
|
||||
const response = await getMpsCreditsApiV1OrganizationsUsageMpsCreditsGet();
|
||||
if (response.data) {
|
||||
setMpsCredits(response.data);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Failed to fetch MPS credits:', error);
|
||||
} finally {
|
||||
setIsLoadingCredits(false);
|
||||
}
|
||||
}, [auth.isAuthenticated]);
|
||||
|
||||
// Translate the FilterBuilder state into the query-param shape the
|
||||
// backend expects. Shared between the listing fetch and the CSV export
|
||||
// so they stay in lockstep.
|
||||
|
|
@ -251,10 +231,9 @@ export default function UsagePage() {
|
|||
// Initial load - fetch when auth becomes available
|
||||
useEffect(() => {
|
||||
if (auth.isAuthenticated) {
|
||||
fetchMpsCredits();
|
||||
fetchUsageHistory(currentPage, appliedFilters);
|
||||
}
|
||||
}, [auth.isAuthenticated, currentPage, appliedFilters, fetchUsageHistory, fetchMpsCredits]);
|
||||
}, [auth.isAuthenticated, currentPage, appliedFilters, fetchUsageHistory]);
|
||||
|
||||
// Fetch daily usage when organizationPricing becomes available
|
||||
useEffect(() => {
|
||||
|
|
@ -428,46 +407,6 @@ export default function UsagePage() {
|
|||
</div>
|
||||
</div>
|
||||
|
||||
{/* MPS Credits Card */}
|
||||
<Card className="mb-6">
|
||||
<CardHeader>
|
||||
<CardTitle>Dograh Model Credits</CardTitle>
|
||||
<CardDescription>
|
||||
These track usage of Dograh models using Dograh Service Keys.
|
||||
</CardDescription>
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
{isLoadingCredits ? (
|
||||
<div className="animate-pulse space-y-4">
|
||||
<div className="h-4 bg-muted rounded w-1/4"></div>
|
||||
<div className="h-8 bg-muted rounded"></div>
|
||||
<div className="h-4 bg-muted rounded w-1/3"></div>
|
||||
</div>
|
||||
) : mpsCredits ? (
|
||||
<div className="space-y-4">
|
||||
<div className="flex justify-between items-baseline">
|
||||
<div>
|
||||
<p className="text-2xl font-bold">
|
||||
{mpsCredits.total_credits_used.toFixed(2)} <span className="text-lg font-normal text-muted-foreground">/ {mpsCredits.total_quota.toFixed(2)}</span>
|
||||
</p>
|
||||
<p className="text-sm text-muted-foreground">Credits Used</p>
|
||||
</div>
|
||||
<div className="text-right">
|
||||
<p className="text-lg font-semibold">{mpsCredits.remaining_credits.toFixed(2)}</p>
|
||||
<p className="text-sm text-muted-foreground">Remaining</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{mpsCredits.total_quota > 0 && (
|
||||
<Progress value={(mpsCredits.total_credits_used / mpsCredits.total_quota) * 100} className="h-3" />
|
||||
)}
|
||||
</div>
|
||||
) : (
|
||||
<p className="text-muted-foreground">No Dograh service keys configured. Set up a service key in your model configuration to see usage.</p>
|
||||
)}
|
||||
</CardContent>
|
||||
</Card>
|
||||
|
||||
{/* Daily Usage Table - Only for paid organizations */}
|
||||
{organizationPricing?.price_per_second_usd && (
|
||||
<div className="mb-6">
|
||||
|
|
@ -535,9 +474,9 @@ export default function UsagePage() {
|
|||
<TableHead className="font-semibold">Disposition</TableHead>
|
||||
<TableHead className="font-semibold">Date</TableHead>
|
||||
<TableHead className="font-semibold text-right">Duration</TableHead>
|
||||
<TableHead className="font-semibold text-right">
|
||||
{organizationPricing?.price_per_second_usd ? 'Cost (USD)' : 'Tokens'}
|
||||
</TableHead>
|
||||
{organizationPricing?.price_per_second_usd && (
|
||||
<TableHead className="font-semibold text-right">Cost (USD)</TableHead>
|
||||
)}
|
||||
<TableHead className="font-semibold">Actions</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
|
|
@ -574,12 +513,14 @@ export default function UsagePage() {
|
|||
<TableCell className="text-right">
|
||||
{formatDuration(run.call_duration_seconds)}
|
||||
</TableCell>
|
||||
<TableCell className="text-right font-medium">
|
||||
{organizationPricing?.price_per_second_usd && run.charge_usd !== undefined && run.charge_usd !== null
|
||||
? `$${run.charge_usd.toFixed(2)}`
|
||||
: run.dograh_token_usage.toLocaleString()
|
||||
}
|
||||
</TableCell>
|
||||
{organizationPricing?.price_per_second_usd && (
|
||||
<TableCell className="text-right font-medium">
|
||||
{run.charge_usd !== undefined && run.charge_usd !== null
|
||||
? `$${run.charge_usd.toFixed(2)}`
|
||||
: '-'
|
||||
}
|
||||
</TableCell>
|
||||
)}
|
||||
<TableCell>
|
||||
<MediaPreviewButton
|
||||
recordingUrl={run.recording_url}
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
|
|
@ -1642,22 +1642,6 @@ export type CurrentUsageResponse = {
|
|||
* Used Dograh Tokens
|
||||
*/
|
||||
used_dograh_tokens: number;
|
||||
/**
|
||||
* Quota Dograh Tokens
|
||||
*/
|
||||
quota_dograh_tokens: number;
|
||||
/**
|
||||
* Percentage Used
|
||||
*/
|
||||
percentage_used: number;
|
||||
/**
|
||||
* Next Refresh Date
|
||||
*/
|
||||
next_refresh_date: string;
|
||||
/**
|
||||
* Quota Enabled
|
||||
*/
|
||||
quota_enabled: boolean;
|
||||
/**
|
||||
* Total Duration Seconds
|
||||
*/
|
||||
|
|
@ -1666,10 +1650,6 @@ export type CurrentUsageResponse = {
|
|||
* Used Amount Usd
|
||||
*/
|
||||
used_amount_usd?: number | null;
|
||||
/**
|
||||
* Quota Amount Usd
|
||||
*/
|
||||
quota_amount_usd?: number | null;
|
||||
/**
|
||||
* Currency
|
||||
*/
|
||||
|
|
@ -3107,6 +3087,165 @@ export type LoginRequest = {
|
|||
password: string;
|
||||
};
|
||||
|
||||
/**
|
||||
* MPSBillingAccountResponse
|
||||
*/
|
||||
export type MpsBillingAccountResponse = {
|
||||
/**
|
||||
* Id
|
||||
*/
|
||||
id: number;
|
||||
/**
|
||||
* Organization Id
|
||||
*/
|
||||
organization_id: number;
|
||||
/**
|
||||
* Billing Mode
|
||||
*/
|
||||
billing_mode: string;
|
||||
/**
|
||||
* Cached Balance Credits
|
||||
*/
|
||||
cached_balance_credits: number;
|
||||
/**
|
||||
* Currency
|
||||
*/
|
||||
currency: string;
|
||||
};
|
||||
|
||||
/**
|
||||
* MPSBillingCreditsResponse
|
||||
*/
|
||||
export type MpsBillingCreditsResponse = {
|
||||
/**
|
||||
* Billing Version
|
||||
*/
|
||||
billing_version: 'legacy' | 'v2';
|
||||
/**
|
||||
* Total Credits Used
|
||||
*/
|
||||
total_credits_used?: number;
|
||||
/**
|
||||
* Remaining Credits
|
||||
*/
|
||||
remaining_credits?: number;
|
||||
/**
|
||||
* Total Quota
|
||||
*/
|
||||
total_quota?: number;
|
||||
account?: MpsBillingAccountResponse | null;
|
||||
/**
|
||||
* Ledger Entries
|
||||
*/
|
||||
ledger_entries?: Array<MpsCreditLedgerEntryResponse>;
|
||||
/**
|
||||
* Total Count
|
||||
*/
|
||||
total_count?: number;
|
||||
/**
|
||||
* Page
|
||||
*/
|
||||
page?: number;
|
||||
/**
|
||||
* Limit
|
||||
*/
|
||||
limit?: number;
|
||||
/**
|
||||
* Total Pages
|
||||
*/
|
||||
total_pages?: number;
|
||||
};
|
||||
|
||||
/**
|
||||
* MPSCreditLedgerEntryResponse
|
||||
*/
|
||||
export type MpsCreditLedgerEntryResponse = {
|
||||
/**
|
||||
* Id
|
||||
*/
|
||||
id: number;
|
||||
/**
|
||||
* Entry Type
|
||||
*/
|
||||
entry_type: string;
|
||||
/**
|
||||
* Origin
|
||||
*/
|
||||
origin?: string | null;
|
||||
/**
|
||||
* Credits Delta
|
||||
*/
|
||||
credits_delta: number;
|
||||
/**
|
||||
* Balance After
|
||||
*/
|
||||
balance_after: number;
|
||||
/**
|
||||
* Amount Minor
|
||||
*/
|
||||
amount_minor?: number | null;
|
||||
/**
|
||||
* Amount Currency
|
||||
*/
|
||||
amount_currency?: string | null;
|
||||
/**
|
||||
* Payment Order Id
|
||||
*/
|
||||
payment_order_id?: number | null;
|
||||
/**
|
||||
* Metric Code
|
||||
*/
|
||||
metric_code?: string | null;
|
||||
/**
|
||||
* Correlation Id
|
||||
*/
|
||||
correlation_id?: string | null;
|
||||
/**
|
||||
* Aggregation Key
|
||||
*/
|
||||
aggregation_key?: string | null;
|
||||
/**
|
||||
* Usage Event Id
|
||||
*/
|
||||
usage_event_id?: number | null;
|
||||
/**
|
||||
* Workflow Run Id
|
||||
*/
|
||||
workflow_run_id?: number | null;
|
||||
/**
|
||||
* Workflow Id
|
||||
*/
|
||||
workflow_id?: number | null;
|
||||
/**
|
||||
* Billable Quantity
|
||||
*/
|
||||
billable_quantity?: number | null;
|
||||
/**
|
||||
* Quantity Unit
|
||||
*/
|
||||
quantity_unit?: string | null;
|
||||
/**
|
||||
* Metadata
|
||||
*/
|
||||
metadata?: {
|
||||
[key: string]: unknown;
|
||||
};
|
||||
/**
|
||||
* Created At
|
||||
*/
|
||||
created_at: string;
|
||||
};
|
||||
|
||||
/**
|
||||
* MPSCreditPurchaseUrlResponse
|
||||
*/
|
||||
export type MpsCreditPurchaseUrlResponse = {
|
||||
/**
|
||||
* Checkout Url
|
||||
*/
|
||||
checkout_url: string;
|
||||
};
|
||||
|
||||
/**
|
||||
* MPSCreditsResponse
|
||||
*/
|
||||
|
|
@ -3618,6 +3757,43 @@ export type OrganizationAiModelConfigurationV2 = {
|
|||
byok?: ByokaiModelConfiguration | null;
|
||||
};
|
||||
|
||||
/**
|
||||
* OrganizationContextResponse
|
||||
*/
|
||||
export type OrganizationContextResponse = {
|
||||
/**
|
||||
* Organization Id
|
||||
*/
|
||||
organization_id?: number | null;
|
||||
/**
|
||||
* Organization Provider Id
|
||||
*/
|
||||
organization_provider_id?: string | null;
|
||||
model_services: OrganizationModelServicesContext;
|
||||
};
|
||||
|
||||
/**
|
||||
* OrganizationModelServicesContext
|
||||
*/
|
||||
export type OrganizationModelServicesContext = {
|
||||
/**
|
||||
* Config Source
|
||||
*/
|
||||
config_source: 'organization_v2' | 'legacy_user_v1' | 'empty';
|
||||
/**
|
||||
* Has Model Configuration V2
|
||||
*/
|
||||
has_model_configuration_v2: boolean;
|
||||
/**
|
||||
* Managed Service Version
|
||||
*/
|
||||
managed_service_version?: number | null;
|
||||
/**
|
||||
* Uses Managed Service V2
|
||||
*/
|
||||
uses_managed_service_v2: boolean;
|
||||
};
|
||||
|
||||
/**
|
||||
* OrganizationPreferences
|
||||
*/
|
||||
|
|
@ -9750,6 +9926,45 @@ export type UnarchiveToolApiV1ToolsToolUuidUnarchivePostResponses = {
|
|||
|
||||
export type UnarchiveToolApiV1ToolsToolUuidUnarchivePostResponse = UnarchiveToolApiV1ToolsToolUuidUnarchivePostResponses[keyof UnarchiveToolApiV1ToolsToolUuidUnarchivePostResponses];
|
||||
|
||||
export type GetCurrentOrganizationContextApiV1OrganizationsContextGetData = {
|
||||
body?: never;
|
||||
headers?: {
|
||||
/**
|
||||
* Authorization
|
||||
*/
|
||||
authorization?: string | null;
|
||||
/**
|
||||
* X-Api-Key
|
||||
*/
|
||||
'X-API-Key'?: string | null;
|
||||
};
|
||||
path?: never;
|
||||
query?: never;
|
||||
url: '/api/v1/organizations/context';
|
||||
};
|
||||
|
||||
export type GetCurrentOrganizationContextApiV1OrganizationsContextGetErrors = {
|
||||
/**
|
||||
* Not found
|
||||
*/
|
||||
404: unknown;
|
||||
/**
|
||||
* Validation Error
|
||||
*/
|
||||
422: HttpValidationError;
|
||||
};
|
||||
|
||||
export type GetCurrentOrganizationContextApiV1OrganizationsContextGetError = GetCurrentOrganizationContextApiV1OrganizationsContextGetErrors[keyof GetCurrentOrganizationContextApiV1OrganizationsContextGetErrors];
|
||||
|
||||
export type GetCurrentOrganizationContextApiV1OrganizationsContextGetResponses = {
|
||||
/**
|
||||
* Successful Response
|
||||
*/
|
||||
200: OrganizationContextResponse;
|
||||
};
|
||||
|
||||
export type GetCurrentOrganizationContextApiV1OrganizationsContextGetResponse = GetCurrentOrganizationContextApiV1OrganizationsContextGetResponses[keyof GetCurrentOrganizationContextApiV1OrganizationsContextGetResponses];
|
||||
|
||||
export type GetTelephonyProvidersMetadataApiV1OrganizationsTelephonyProvidersMetadataGetData = {
|
||||
body?: never;
|
||||
headers?: {
|
||||
|
|
@ -11269,6 +11484,93 @@ export type GetMpsCreditsApiV1OrganizationsUsageMpsCreditsGetResponses = {
|
|||
|
||||
export type GetMpsCreditsApiV1OrganizationsUsageMpsCreditsGetResponse = GetMpsCreditsApiV1OrganizationsUsageMpsCreditsGetResponses[keyof GetMpsCreditsApiV1OrganizationsUsageMpsCreditsGetResponses];
|
||||
|
||||
export type GetBillingCreditsApiV1OrganizationsBillingCreditsGetData = {
|
||||
body?: never;
|
||||
headers?: {
|
||||
/**
|
||||
* Authorization
|
||||
*/
|
||||
authorization?: string | null;
|
||||
/**
|
||||
* X-Api-Key
|
||||
*/
|
||||
'X-API-Key'?: string | null;
|
||||
};
|
||||
path?: never;
|
||||
query?: {
|
||||
/**
|
||||
* Page
|
||||
*/
|
||||
page?: number;
|
||||
/**
|
||||
* Limit
|
||||
*/
|
||||
limit?: number;
|
||||
};
|
||||
url: '/api/v1/organizations/billing/credits';
|
||||
};
|
||||
|
||||
export type GetBillingCreditsApiV1OrganizationsBillingCreditsGetErrors = {
|
||||
/**
|
||||
* Not found
|
||||
*/
|
||||
404: unknown;
|
||||
/**
|
||||
* Validation Error
|
||||
*/
|
||||
422: HttpValidationError;
|
||||
};
|
||||
|
||||
export type GetBillingCreditsApiV1OrganizationsBillingCreditsGetError = GetBillingCreditsApiV1OrganizationsBillingCreditsGetErrors[keyof GetBillingCreditsApiV1OrganizationsBillingCreditsGetErrors];
|
||||
|
||||
export type GetBillingCreditsApiV1OrganizationsBillingCreditsGetResponses = {
|
||||
/**
|
||||
* Successful Response
|
||||
*/
|
||||
200: MpsBillingCreditsResponse;
|
||||
};
|
||||
|
||||
export type GetBillingCreditsApiV1OrganizationsBillingCreditsGetResponse = GetBillingCreditsApiV1OrganizationsBillingCreditsGetResponses[keyof GetBillingCreditsApiV1OrganizationsBillingCreditsGetResponses];
|
||||
|
||||
export type CreateMpsCreditPurchaseUrlApiV1OrganizationsUsageMpsCreditsPurchaseUrlPostData = {
|
||||
body?: never;
|
||||
headers?: {
|
||||
/**
|
||||
* Authorization
|
||||
*/
|
||||
authorization?: string | null;
|
||||
/**
|
||||
* X-Api-Key
|
||||
*/
|
||||
'X-API-Key'?: string | null;
|
||||
};
|
||||
path?: never;
|
||||
query?: never;
|
||||
url: '/api/v1/organizations/usage/mps-credits/purchase-url';
|
||||
};
|
||||
|
||||
export type CreateMpsCreditPurchaseUrlApiV1OrganizationsUsageMpsCreditsPurchaseUrlPostErrors = {
|
||||
/**
|
||||
* Not found
|
||||
*/
|
||||
404: unknown;
|
||||
/**
|
||||
* Validation Error
|
||||
*/
|
||||
422: HttpValidationError;
|
||||
};
|
||||
|
||||
export type CreateMpsCreditPurchaseUrlApiV1OrganizationsUsageMpsCreditsPurchaseUrlPostError = CreateMpsCreditPurchaseUrlApiV1OrganizationsUsageMpsCreditsPurchaseUrlPostErrors[keyof CreateMpsCreditPurchaseUrlApiV1OrganizationsUsageMpsCreditsPurchaseUrlPostErrors];
|
||||
|
||||
export type CreateMpsCreditPurchaseUrlApiV1OrganizationsUsageMpsCreditsPurchaseUrlPostResponses = {
|
||||
/**
|
||||
* Successful Response
|
||||
*/
|
||||
200: MpsCreditPurchaseUrlResponse;
|
||||
};
|
||||
|
||||
export type CreateMpsCreditPurchaseUrlApiV1OrganizationsUsageMpsCreditsPurchaseUrlPostResponse = CreateMpsCreditPurchaseUrlApiV1OrganizationsUsageMpsCreditsPurchaseUrlPostResponses[keyof CreateMpsCreditPurchaseUrlApiV1OrganizationsUsageMpsCreditsPurchaseUrlPostResponses];
|
||||
|
||||
export type GetUsageHistoryApiV1OrganizationsUsageRunsGetData = {
|
||||
body?: never;
|
||||
headers?: {
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from "@
|
|||
import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs";
|
||||
import { LANGUAGE_DISPLAY_NAMES } from "@/constants/languages";
|
||||
|
||||
type ModelMode = "dograh" | "byok";
|
||||
type ModelMode = "realtime" | "dograh" | "byok";
|
||||
|
||||
interface DograhDefaults {
|
||||
voices: string[];
|
||||
|
|
@ -125,24 +125,35 @@ function effectiveConfigToLegacyShape(config: Record<string, unknown> | null): R
|
|||
};
|
||||
}
|
||||
|
||||
function emptyByokInitialConfig(): Record<string, unknown> {
|
||||
function emptyByokInitialConfig(isRealtime: boolean): Record<string, unknown> {
|
||||
return {
|
||||
is_realtime: false,
|
||||
is_realtime: isRealtime,
|
||||
};
|
||||
}
|
||||
|
||||
// The v2 editor surfaces realtime ("Speech to Speech") and pipeline (BYOK) as
|
||||
// separate tabs, so each tab gets its own initial config. A tab is pre-filled
|
||||
// only when the saved (or effective) configuration matches that tab's mode;
|
||||
// otherwise it starts empty so the other tab's data does not leak across.
|
||||
function getByokInitialConfig(
|
||||
configuration: Record<string, unknown> | null,
|
||||
effectiveConfiguration: Record<string, unknown> | null,
|
||||
wantRealtime: boolean,
|
||||
): Record<string, unknown> {
|
||||
const byokConfiguration = byokConfigToLegacyShape(configuration);
|
||||
if (byokConfiguration) return byokConfiguration;
|
||||
const matchesTab = (config: Record<string, unknown> | null) =>
|
||||
config ? Boolean(config.is_realtime) === wantRealtime : false;
|
||||
|
||||
if (configuration?.mode === "dograh" || isDograhEffectiveConfig(effectiveConfiguration)) {
|
||||
return emptyByokInitialConfig();
|
||||
const byokConfiguration = byokConfigToLegacyShape(configuration);
|
||||
if (byokConfiguration) {
|
||||
return matchesTab(byokConfiguration) ? byokConfiguration : emptyByokInitialConfig(wantRealtime);
|
||||
}
|
||||
|
||||
return effectiveConfigToLegacyShape(effectiveConfiguration) || emptyByokInitialConfig();
|
||||
if (configuration?.mode === "dograh" || isDograhEffectiveConfig(effectiveConfiguration)) {
|
||||
return emptyByokInitialConfig(wantRealtime);
|
||||
}
|
||||
|
||||
const effective = effectiveConfigToLegacyShape(effectiveConfiguration);
|
||||
return matchesTab(effective) ? (effective as Record<string, unknown>) : emptyByokInitialConfig(wantRealtime);
|
||||
}
|
||||
|
||||
function buildDograhState(
|
||||
|
|
@ -185,10 +196,12 @@ function preferredMode(
|
|||
configuration: Record<string, unknown> | null,
|
||||
effectiveConfiguration: Record<string, unknown> | null,
|
||||
): ModelMode {
|
||||
if (configuration?.mode === "dograh" || configuration?.mode === "byok") {
|
||||
return configuration.mode;
|
||||
if (configuration?.mode === "dograh") return "dograh";
|
||||
if (configuration?.mode === "byok") {
|
||||
return asRecord(configuration.byok)?.mode === "realtime" ? "realtime" : "byok";
|
||||
}
|
||||
return isDograhEffectiveConfig(effectiveConfiguration) ? "dograh" : "byok";
|
||||
if (isDograhEffectiveConfig(effectiveConfiguration)) return "dograh";
|
||||
return Boolean(effectiveConfiguration?.is_realtime) ? "realtime" : "byok";
|
||||
}
|
||||
|
||||
function hasRequiredApiKey(
|
||||
|
|
@ -249,7 +262,8 @@ export function AIModelConfigurationV2Editor({
|
|||
speed: defaults.dograh.defaults.speed,
|
||||
language: defaults.dograh.defaults.language,
|
||||
}));
|
||||
const [byokInitialConfig, setByokInitialConfig] = useState<Record<string, unknown> | null>(null);
|
||||
const [realtimeInitialConfig, setRealtimeInitialConfig] = useState<Record<string, unknown> | null>(null);
|
||||
const [pipelineInitialConfig, setPipelineInitialConfig] = useState<Record<string, unknown> | null>(null);
|
||||
const [isSavingDograh, setIsSavingDograh] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
|
||||
|
|
@ -258,7 +272,8 @@ export function AIModelConfigurationV2Editor({
|
|||
const rawEffectiveConfiguration = asRecord(effectiveConfiguration);
|
||||
setMode(preferredMode(rawConfiguration, rawEffectiveConfiguration));
|
||||
setDograh(buildDograhState(defaults, rawConfiguration, rawEffectiveConfiguration));
|
||||
setByokInitialConfig(getByokInitialConfig(rawConfiguration, rawEffectiveConfiguration));
|
||||
setRealtimeInitialConfig(getByokInitialConfig(rawConfiguration, rawEffectiveConfiguration, true));
|
||||
setPipelineInitialConfig(getByokInitialConfig(rawConfiguration, rawEffectiveConfiguration, false));
|
||||
}, [configuration, defaults, effectiveConfiguration]);
|
||||
|
||||
const saveDograhConfiguration = async () => {
|
||||
|
|
@ -322,28 +337,30 @@ export function AIModelConfigurationV2Editor({
|
|||
)}
|
||||
|
||||
<Tabs value={mode} onValueChange={(value) => setMode(value as ModelMode)} className="space-y-6">
|
||||
<TabsList className="grid w-full grid-cols-2">
|
||||
<TabsList className="grid w-full grid-cols-3">
|
||||
<TabsTrigger value="realtime">Speech to Speech</TabsTrigger>
|
||||
<TabsTrigger value="dograh">Dograh</TabsTrigger>
|
||||
<TabsTrigger value="byok">BYOK</TabsTrigger>
|
||||
</TabsList>
|
||||
|
||||
<TabsContent value="realtime" className="mt-0">
|
||||
<p className="mb-4 text-sm text-muted-foreground">
|
||||
A single speech-to-speech model handles the conversation in realtime (no separate transcriber or voice). An LLM is still required for variable extraction and QA.
|
||||
</p>
|
||||
<ServiceConfigurationForm
|
||||
key={`realtime-${JSON.stringify(realtimeInitialConfig)}`}
|
||||
mode="global"
|
||||
forceRealtime
|
||||
configurationDefaults={defaultsForByok}
|
||||
initialConfig={realtimeInitialConfig}
|
||||
submitLabel={submitLabel}
|
||||
onSave={saveByokConfiguration}
|
||||
/>
|
||||
</TabsContent>
|
||||
|
||||
<TabsContent value="dograh" className="mt-0">
|
||||
<div className="rounded-lg border p-5">
|
||||
<div className="grid gap-4 sm:grid-cols-2">
|
||||
<div className="space-y-2 sm:col-span-2">
|
||||
<Label htmlFor="dograh-api-key">API Key</Label>
|
||||
<div className="relative">
|
||||
<KeyRound className="pointer-events-none absolute left-3 top-1/2 h-4 w-4 -translate-y-1/2 text-muted-foreground" />
|
||||
<Input
|
||||
id="dograh-api-key"
|
||||
className="pl-9"
|
||||
value={dograh.api_key}
|
||||
onChange={(event) => setDograh({ ...dograh, api_key: event.target.value })}
|
||||
placeholder="Enter API key"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="space-y-2">
|
||||
<Label>Voice</Label>
|
||||
<Select value={dograh.voice} onValueChange={(voice) => setDograh({ ...dograh, voice })}>
|
||||
|
|
@ -394,6 +411,20 @@ export function AIModelConfigurationV2Editor({
|
|||
</SelectContent>
|
||||
</Select>
|
||||
</div>
|
||||
|
||||
<div className="space-y-2 sm:col-span-2">
|
||||
<Label htmlFor="dograh-api-key">API Key</Label>
|
||||
<div className="relative">
|
||||
<KeyRound className="pointer-events-none absolute left-3 top-1/2 h-4 w-4 -translate-y-1/2 text-muted-foreground" />
|
||||
<Input
|
||||
id="dograh-api-key"
|
||||
className="pl-9"
|
||||
value={dograh.api_key}
|
||||
onChange={(event) => setDograh({ ...dograh, api_key: event.target.value })}
|
||||
placeholder="Enter API key"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<Button type="button" className="mt-6 w-full" onClick={saveDograhConfiguration} disabled={isSavingDograh}>
|
||||
|
|
@ -405,10 +436,11 @@ export function AIModelConfigurationV2Editor({
|
|||
|
||||
<TabsContent value="byok" className="mt-0">
|
||||
<ServiceConfigurationForm
|
||||
key={JSON.stringify(byokInitialConfig)}
|
||||
key={`byok-${JSON.stringify(pipelineInitialConfig)}`}
|
||||
mode="global"
|
||||
forceRealtime={false}
|
||||
configurationDefaults={defaultsForByok}
|
||||
initialConfig={byokInitialConfig}
|
||||
initialConfig={pipelineInitialConfig}
|
||||
submitLabel={submitLabel}
|
||||
onSave={saveByokConfiguration}
|
||||
/>
|
||||
|
|
|
|||
|
|
@ -101,6 +101,13 @@ export interface ServiceConfigurationFormProps {
|
|||
submitLabel?: string;
|
||||
configurationDefaults?: ServiceConfigurationDefaults | null;
|
||||
initialConfig?: Record<string, unknown> | null;
|
||||
/**
|
||||
* When set, locks the realtime/pipeline mode to this value and hides the
|
||||
* in-form toggle. The v2 editor uses this to surface realtime
|
||||
* ("Speech to Speech") and pipeline (BYOK) as separate top-level tabs.
|
||||
* Leave undefined to keep the user-controllable toggle (legacy + overrides).
|
||||
*/
|
||||
forceRealtime?: boolean;
|
||||
}
|
||||
|
||||
function getProviderDisplayName(
|
||||
|
|
@ -130,10 +137,11 @@ export function ServiceConfigurationForm({
|
|||
submitLabel,
|
||||
configurationDefaults,
|
||||
initialConfig,
|
||||
forceRealtime,
|
||||
}: ServiceConfigurationFormProps) {
|
||||
const [apiError, setApiError] = useState<string | null>(null);
|
||||
const [isSaving, setIsSaving] = useState(false);
|
||||
const [isRealtime, setIsRealtime] = useState(false);
|
||||
const [isRealtime, setIsRealtime] = useState(forceRealtime ?? false);
|
||||
const { userConfig } = useUserConfig();
|
||||
const [schemas, setSchemas] = useState<Record<ServiceSegment, Record<string, ProviderSchema>>>({
|
||||
llm: {},
|
||||
|
|
@ -227,9 +235,9 @@ export function ServiceConfigurationForm({
|
|||
realtime: realtimeSchemas,
|
||||
});
|
||||
|
||||
// Restore realtime toggle
|
||||
// Restore realtime toggle (skip when the parent locks the mode)
|
||||
const configData = configSource as Record<string, unknown> | null;
|
||||
if (configData?.is_realtime) {
|
||||
if (forceRealtime === undefined && configData?.is_realtime) {
|
||||
setIsRealtime(true);
|
||||
}
|
||||
|
||||
|
|
@ -867,22 +875,24 @@ export function ServiceConfigurationForm({
|
|||
|
||||
return (
|
||||
<form onSubmit={handleSubmit(onSubmit)}>
|
||||
{/* Realtime toggle */}
|
||||
<div className="flex items-center justify-between mb-4 p-4 border rounded-lg">
|
||||
<div>
|
||||
<Label htmlFor="realtime-toggle" className="text-sm font-medium">
|
||||
Realtime Mode
|
||||
</Label>
|
||||
<p className="text-xs text-muted-foreground mt-0.5">
|
||||
Uses a single speech-to-speech model (no separate STT/TTS). An LLM is still required for variable extraction and QA.
|
||||
</p>
|
||||
{/* Realtime toggle — hidden when the parent locks the mode (v2 tabs) */}
|
||||
{forceRealtime === undefined && (
|
||||
<div className="flex items-center justify-between mb-4 p-4 border rounded-lg">
|
||||
<div>
|
||||
<Label htmlFor="realtime-toggle" className="text-sm font-medium">
|
||||
Realtime Mode
|
||||
</Label>
|
||||
<p className="text-xs text-muted-foreground mt-0.5">
|
||||
Uses a single speech-to-speech model (no separate STT/TTS). An LLM is still required for variable extraction and QA.
|
||||
</p>
|
||||
</div>
|
||||
<Switch
|
||||
id="realtime-toggle"
|
||||
checked={isRealtime}
|
||||
onCheckedChange={setIsRealtime}
|
||||
/>
|
||||
</div>
|
||||
<Switch
|
||||
id="realtime-toggle"
|
||||
checked={isRealtime}
|
||||
onCheckedChange={setIsRealtime}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<Card>
|
||||
<CardContent className="pt-6">
|
||||
|
|
|
|||
|
|
@ -136,6 +136,11 @@ const NAV_SECTIONS: SidebarNavSection[] = [
|
|||
url: "/usage",
|
||||
icon: TrendingUp,
|
||||
},
|
||||
{
|
||||
title: "Billing",
|
||||
url: "/billing",
|
||||
icon: CircleDollarSign,
|
||||
},
|
||||
{
|
||||
title: "Reports",
|
||||
url: "/reports",
|
||||
|
|
|
|||
192
ui/src/context/OrgConfigContext.tsx
Normal file
192
ui/src/context/OrgConfigContext.tsx
Normal file
|
|
@ -0,0 +1,192 @@
|
|||
'use client';
|
||||
|
||||
import { createContext, ReactNode, useCallback, useContext, useEffect, useRef, useState } from 'react';
|
||||
|
||||
import { client } from '@/client/client.gen';
|
||||
import { getCurrentOrganizationContextApiV1OrganizationsContextGet, getUserConfigurationsApiV1UserConfigurationsUserGet, updateUserConfigurationsApiV1UserConfigurationsUserPut } from '@/client/sdk.gen';
|
||||
import type { OrganizationContextResponse, UserConfigurationRequestResponseSchema } from '@/client/types.gen';
|
||||
import { setupAuthInterceptor } from '@/lib/apiClient';
|
||||
import type { AuthUser } from '@/lib/auth';
|
||||
import { useAuth } from '@/lib/auth';
|
||||
|
||||
interface TeamPermission {
|
||||
id: string;
|
||||
}
|
||||
|
||||
interface OrganizationPricing {
|
||||
price_per_second_usd: number | null;
|
||||
currency: string;
|
||||
billing_enabled: boolean;
|
||||
}
|
||||
|
||||
interface OrgConfigContextType {
|
||||
orgContext: OrganizationContextResponse | null;
|
||||
userConfig: UserConfigurationRequestResponseSchema | null;
|
||||
saveUserConfig: (userConfig: UserConfigurationRequestResponseSchema) => Promise<void>;
|
||||
loading: boolean;
|
||||
error: Error | null;
|
||||
refreshConfig: () => Promise<void>;
|
||||
permissions: TeamPermission[];
|
||||
user: AuthUser | null;
|
||||
organizationPricing: OrganizationPricing | null;
|
||||
}
|
||||
|
||||
const OrgConfigContext = createContext<OrgConfigContextType | null>(null);
|
||||
|
||||
const pricingFromUserConfig = (
|
||||
userConfig: UserConfigurationRequestResponseSchema,
|
||||
): OrganizationPricing | null => {
|
||||
if (!userConfig.organization_pricing) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return {
|
||||
price_per_second_usd: userConfig.organization_pricing.price_per_second_usd as number | null,
|
||||
currency: (userConfig.organization_pricing.currency as string) || 'USD',
|
||||
billing_enabled: (userConfig.organization_pricing.billing_enabled as boolean) || false,
|
||||
};
|
||||
};
|
||||
|
||||
export function OrgConfigProvider({ children }: { children: ReactNode }) {
|
||||
const [orgContext, setOrgContext] = useState<OrganizationContextResponse | null>(null);
|
||||
const [userConfig, setUserConfig] = useState<UserConfigurationRequestResponseSchema | null>(null);
|
||||
const [loading, setLoading] = useState(true);
|
||||
const [error, setError] = useState<Error | null>(null);
|
||||
const [organizationPricing, setOrganizationPricing] = useState<OrganizationPricing | null>(null);
|
||||
const [permissions, setPermissions] = useState<TeamPermission[]>([]);
|
||||
|
||||
const auth = useAuth();
|
||||
|
||||
const authRef = useRef(auth);
|
||||
authRef.current = auth;
|
||||
|
||||
const hasFetchedConfig = useRef(false);
|
||||
const hasFetchedPermissions = useRef(false);
|
||||
|
||||
if (!auth.loading && auth.isAuthenticated) {
|
||||
setupAuthInterceptor(client, auth.getAccessToken);
|
||||
}
|
||||
|
||||
useEffect(() => {
|
||||
if (auth.loading || hasFetchedPermissions.current) {
|
||||
return;
|
||||
}
|
||||
hasFetchedPermissions.current = true;
|
||||
|
||||
const fetchPermissions = async () => {
|
||||
const currentAuth = authRef.current;
|
||||
if (currentAuth.provider === 'stack' && currentAuth.getSelectedTeam && currentAuth.listPermissions) {
|
||||
const selectedTeam = currentAuth.getSelectedTeam();
|
||||
if (selectedTeam) {
|
||||
try {
|
||||
const perms = await currentAuth.listPermissions(selectedTeam);
|
||||
setPermissions(Array.isArray(perms) ? perms : []);
|
||||
} catch {
|
||||
setPermissions([]);
|
||||
}
|
||||
} else {
|
||||
setPermissions([]);
|
||||
}
|
||||
} else {
|
||||
setPermissions([{ id: 'admin' }]);
|
||||
}
|
||||
};
|
||||
|
||||
fetchPermissions();
|
||||
}, [auth.loading, auth.provider]);
|
||||
|
||||
const fetchConfig = useCallback(async () => {
|
||||
const currentAuth = authRef.current;
|
||||
if (!currentAuth.isAuthenticated) {
|
||||
return;
|
||||
}
|
||||
|
||||
setLoading(true);
|
||||
try {
|
||||
const [orgContextResponse, userConfigResponse] = await Promise.all([
|
||||
getCurrentOrganizationContextApiV1OrganizationsContextGet(),
|
||||
getUserConfigurationsApiV1UserConfigurationsUserGet(),
|
||||
]);
|
||||
|
||||
if (orgContextResponse.data) {
|
||||
setOrgContext(orgContextResponse.data);
|
||||
}
|
||||
|
||||
if (userConfigResponse.data) {
|
||||
setUserConfig(userConfigResponse.data);
|
||||
setOrganizationPricing(pricingFromUserConfig(userConfigResponse.data));
|
||||
}
|
||||
|
||||
setError(null);
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err : new Error('Failed to fetch organization configuration'));
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
if (auth.loading || !auth.isAuthenticated || hasFetchedConfig.current) {
|
||||
return;
|
||||
}
|
||||
hasFetchedConfig.current = true;
|
||||
fetchConfig();
|
||||
}, [auth.loading, auth.isAuthenticated, fetchConfig]);
|
||||
|
||||
const saveUserConfig = useCallback(async (userConfigRequest: UserConfigurationRequestResponseSchema) => {
|
||||
if (!authRef.current.isAuthenticated) throw new Error('No authentication available');
|
||||
const response = await updateUserConfigurationsApiV1UserConfigurationsUserPut({
|
||||
body: {
|
||||
...userConfig,
|
||||
...userConfigRequest,
|
||||
} as UserConfigurationRequestResponseSchema,
|
||||
});
|
||||
if (response.error) {
|
||||
let msg = 'Failed to save user configuration';
|
||||
const detail = (response.error as unknown as { detail?: string | { errors: { model: string; message: string }[] } }).detail;
|
||||
if (typeof detail === 'string') {
|
||||
msg = detail;
|
||||
} else if (Array.isArray(detail)) {
|
||||
msg = detail
|
||||
.map((e: { model: string; message: string }) => `${e.model}: ${e.message}`)
|
||||
.join('\n');
|
||||
}
|
||||
throw new Error(msg);
|
||||
}
|
||||
|
||||
if (response.data) {
|
||||
setUserConfig(response.data);
|
||||
setOrganizationPricing(pricingFromUserConfig(response.data));
|
||||
}
|
||||
}, [userConfig]);
|
||||
|
||||
const refreshConfig = useCallback(async () => {
|
||||
await fetchConfig();
|
||||
}, [fetchConfig]);
|
||||
|
||||
return (
|
||||
<OrgConfigContext.Provider
|
||||
value={{
|
||||
orgContext,
|
||||
userConfig,
|
||||
saveUserConfig,
|
||||
loading,
|
||||
error,
|
||||
refreshConfig,
|
||||
permissions,
|
||||
user: auth.user,
|
||||
organizationPricing,
|
||||
}}
|
||||
>
|
||||
{children}
|
||||
</OrgConfigContext.Provider>
|
||||
);
|
||||
}
|
||||
|
||||
export function useOrgConfig() {
|
||||
const context = useContext(OrgConfigContext);
|
||||
if (!context) {
|
||||
throw new Error('useOrgConfig must be used within an OrgConfigProvider');
|
||||
}
|
||||
return context;
|
||||
}
|
||||
|
|
@ -1,205 +1,3 @@
|
|||
'use client';
|
||||
|
||||
import { createContext, ReactNode, useCallback, useContext, useEffect, useRef, useState } from 'react';
|
||||
|
||||
import { client } from '@/client/client.gen';
|
||||
import { getUserConfigurationsApiV1UserConfigurationsUserGet, updateUserConfigurationsApiV1UserConfigurationsUserPut } from '@/client/sdk.gen';
|
||||
import type { UserConfigurationRequestResponseSchema } from '@/client/types.gen';
|
||||
import { setupAuthInterceptor } from '@/lib/apiClient';
|
||||
import type { AuthUser } from '@/lib/auth';
|
||||
import { useAuth } from '@/lib/auth';
|
||||
|
||||
|
||||
interface TeamPermission {
|
||||
id: string;
|
||||
}
|
||||
|
||||
interface OrganizationPricing {
|
||||
price_per_second_usd: number | null;
|
||||
currency: string;
|
||||
billing_enabled: boolean;
|
||||
}
|
||||
|
||||
interface UserConfigContextType {
|
||||
userConfig: UserConfigurationRequestResponseSchema | null;
|
||||
saveUserConfig: (userConfig: UserConfigurationRequestResponseSchema) => Promise<void>;
|
||||
loading: boolean;
|
||||
error: Error | null;
|
||||
refreshConfig: () => Promise<void>;
|
||||
permissions: TeamPermission[];
|
||||
user: AuthUser | null;
|
||||
organizationPricing: OrganizationPricing | null;
|
||||
}
|
||||
|
||||
const UserConfigContext = createContext<UserConfigContextType | null>(null);
|
||||
|
||||
export function UserConfigProvider({ children }: { children: ReactNode }) {
|
||||
const [userConfig, setUserConfig] = useState<UserConfigurationRequestResponseSchema | null>(null);
|
||||
const [loading, setLoading] = useState(true);
|
||||
const [error, setError] = useState<Error | null>(null);
|
||||
const [organizationPricing, setOrganizationPricing] = useState<OrganizationPricing | null>(null);
|
||||
const [permissions, setPermissions] = useState<TeamPermission[]>([]);
|
||||
|
||||
const auth = useAuth();
|
||||
|
||||
// Store auth functions in refs to avoid dependency issues
|
||||
const authRef = useRef(auth);
|
||||
authRef.current = auth;
|
||||
|
||||
// Track initialization
|
||||
const hasFetchedConfig = useRef(false);
|
||||
const hasFetchedPermissions = useRef(false);
|
||||
|
||||
// Register the auth interceptor synchronously during render (not in useEffect)
|
||||
// so it's in place before any child effects fire API calls.
|
||||
// setupAuthInterceptor is idempotent — safe for strict mode double-renders.
|
||||
if (!auth.loading && auth.isAuthenticated) {
|
||||
setupAuthInterceptor(client, auth.getAccessToken);
|
||||
}
|
||||
|
||||
// Fetch permissions once when auth is ready
|
||||
useEffect(() => {
|
||||
if (auth.loading || hasFetchedPermissions.current) {
|
||||
return;
|
||||
}
|
||||
hasFetchedPermissions.current = true;
|
||||
|
||||
const fetchPermissions = async () => {
|
||||
const currentAuth = authRef.current;
|
||||
if (currentAuth.provider === 'stack' && currentAuth.getSelectedTeam && currentAuth.listPermissions) {
|
||||
const selectedTeam = currentAuth.getSelectedTeam();
|
||||
if (selectedTeam) {
|
||||
try {
|
||||
const perms = await currentAuth.listPermissions(selectedTeam);
|
||||
setPermissions(Array.isArray(perms) ? perms : []);
|
||||
} catch {
|
||||
setPermissions([]);
|
||||
}
|
||||
} else {
|
||||
setPermissions([]);
|
||||
}
|
||||
} else {
|
||||
setPermissions([{ id: 'admin' }]);
|
||||
}
|
||||
};
|
||||
|
||||
fetchPermissions();
|
||||
}, [auth.loading, auth.provider]);
|
||||
|
||||
// Fetch user config once when auth is ready
|
||||
useEffect(() => {
|
||||
if (auth.loading || !auth.isAuthenticated || hasFetchedConfig.current) {
|
||||
return;
|
||||
}
|
||||
hasFetchedConfig.current = true;
|
||||
|
||||
const fetchUserConfig = async () => {
|
||||
setLoading(true);
|
||||
try {
|
||||
const response = await getUserConfigurationsApiV1UserConfigurationsUserGet();
|
||||
|
||||
if (response.data) {
|
||||
setUserConfig(response.data);
|
||||
if (response.data.organization_pricing) {
|
||||
setOrganizationPricing({
|
||||
price_per_second_usd: response.data.organization_pricing.price_per_second_usd as number | null,
|
||||
currency: response.data.organization_pricing.currency as string || 'USD',
|
||||
billing_enabled: response.data.organization_pricing.billing_enabled as boolean || false
|
||||
});
|
||||
} else {
|
||||
setOrganizationPricing(null);
|
||||
}
|
||||
}
|
||||
setError(null);
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err : new Error('Failed to fetch user configuration'));
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
fetchUserConfig();
|
||||
}, [auth.loading, auth.isAuthenticated]);
|
||||
|
||||
const saveUserConfig = useCallback(async (userConfigRequest: UserConfigurationRequestResponseSchema) => {
|
||||
if (!authRef.current.isAuthenticated) throw new Error('No authentication available');
|
||||
const response = await updateUserConfigurationsApiV1UserConfigurationsUserPut({
|
||||
body: {
|
||||
...userConfig,
|
||||
...userConfigRequest
|
||||
} as UserConfigurationRequestResponseSchema,
|
||||
});
|
||||
if (response.error) {
|
||||
let msg = 'Failed to save user configuration';
|
||||
const detail = (response.error as unknown as { detail?: string | { errors: { model: string; message: string }[] } }).detail;
|
||||
if (typeof detail === 'string') {
|
||||
msg = detail;
|
||||
} else if (Array.isArray(detail)) {
|
||||
msg = detail
|
||||
.map((e: { model: string; message: string }) => `${e.model}: ${e.message}`)
|
||||
.join('\n');
|
||||
}
|
||||
throw new Error(msg);
|
||||
}
|
||||
setUserConfig(response.data!);
|
||||
|
||||
if (response.data?.organization_pricing) {
|
||||
setOrganizationPricing({
|
||||
price_per_second_usd: response.data.organization_pricing.price_per_second_usd as number | null,
|
||||
currency: response.data.organization_pricing.currency as string || 'USD',
|
||||
billing_enabled: response.data.organization_pricing.billing_enabled as boolean || false
|
||||
});
|
||||
}
|
||||
}, [userConfig]);
|
||||
|
||||
const refreshConfig = useCallback(async () => {
|
||||
const currentAuth = authRef.current;
|
||||
if (!currentAuth.isAuthenticated) return;
|
||||
|
||||
setLoading(true);
|
||||
try {
|
||||
const response = await getUserConfigurationsApiV1UserConfigurationsUserGet();
|
||||
|
||||
if (response.data) {
|
||||
setUserConfig(response.data);
|
||||
if (response.data.organization_pricing) {
|
||||
setOrganizationPricing({
|
||||
price_per_second_usd: response.data.organization_pricing.price_per_second_usd as number | null,
|
||||
currency: response.data.organization_pricing.currency as string || 'USD',
|
||||
billing_enabled: response.data.organization_pricing.billing_enabled as boolean || false
|
||||
});
|
||||
}
|
||||
}
|
||||
setError(null);
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err : new Error('Failed to fetch user configuration'));
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<UserConfigContext.Provider
|
||||
value={{
|
||||
userConfig,
|
||||
saveUserConfig,
|
||||
loading,
|
||||
error,
|
||||
refreshConfig,
|
||||
permissions,
|
||||
user: auth.user,
|
||||
organizationPricing,
|
||||
}}
|
||||
>
|
||||
{children}
|
||||
</UserConfigContext.Provider>
|
||||
);
|
||||
}
|
||||
|
||||
export function useUserConfig() {
|
||||
const context = useContext(UserConfigContext);
|
||||
if (!context) {
|
||||
throw new Error('useUserConfig must be used within a UserConfigProvider');
|
||||
}
|
||||
return context;
|
||||
}
|
||||
export { OrgConfigProvider as UserConfigProvider, useOrgConfig as useUserConfig } from './OrgConfigContext';
|
||||
|
|
|
|||
|
|
@ -2,18 +2,33 @@
|
|||
* Extract a human-readable message from a backend error response.
|
||||
*
|
||||
* The generated API client returns `{ error }` on failure (it does not throw),
|
||||
* and FastAPI shapes that error as either `{ detail: string }` (HTTPException)
|
||||
* or `{ detail: [{ msg, loc, ... }] }` (422 validation). This normalizes both
|
||||
* to a single string so it can be rendered or thrown directly — never pass the
|
||||
* raw `detail` to React, as the 422 array crashes rendering.
|
||||
* and FastAPI shapes that error as `{ detail: string }`, `{ detail:
|
||||
* [{ msg, loc, ... }] }`, or backend validation arrays like `{ detail:
|
||||
* [{ model, message }] }`. This normalizes those to a single string so it can
|
||||
* be rendered or thrown directly.
|
||||
*/
|
||||
export function detailFromError(err: unknown, fallback = "Request failed"): string {
|
||||
if (typeof err === "string") return err;
|
||||
const e = err as { detail?: unknown };
|
||||
if (typeof e?.detail === "string") return e.detail;
|
||||
if (Array.isArray(e?.detail) && e.detail.length > 0) {
|
||||
const first = e.detail[0] as { msg?: string };
|
||||
if (first?.msg) return first.msg;
|
||||
const messages = e.detail
|
||||
.map((item) => {
|
||||
if (typeof item === "string") return item;
|
||||
if (!item || typeof item !== "object") return null;
|
||||
const detail = item as { message?: unknown; msg?: unknown; model?: unknown };
|
||||
const message = typeof detail.message === "string"
|
||||
? detail.message
|
||||
: typeof detail.msg === "string"
|
||||
? detail.msg
|
||||
: null;
|
||||
if (!message) return null;
|
||||
return typeof detail.model === "string" && detail.model
|
||||
? `${detail.model}: ${message}`
|
||||
: message;
|
||||
})
|
||||
.filter((message): message is string => Boolean(message));
|
||||
if (messages.length > 0) return messages.join("\n");
|
||||
}
|
||||
return fallback;
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue