mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-13 08:15:21 +02:00
feat: billing and credit management v2 (#429)
* feat: use mps generated correlation ID * chore: update pipecat submodule * feat: add credit purchase URL * feat: carve out billing page and show credit ledger * feat: deprecate dograh based quota tracking * fix: remove cost calculation from dograh codebase * fix: create mps account on migrate to v2 * chore: update pipecat
This commit is contained in:
parent
97d7103480
commit
1f1149f4d5
80 changed files with 3335 additions and 2057 deletions
|
|
@ -18,6 +18,9 @@ branch_labels: Union[str, Sequence[str], None] = None
|
|||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
DEPRECATED_QUOTA_COMMENT = "Deprecated. MPS owns quota and credit ledger state."
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
# 1) Create the `quota_type` enum *before* we add the column that references it.
|
||||
|
|
@ -34,7 +37,12 @@ def upgrade() -> None:
|
|||
sa.Column("organization_id", sa.Integer(), nullable=False),
|
||||
sa.Column("period_start", sa.DateTime(), nullable=False),
|
||||
sa.Column("period_end", sa.DateTime(), nullable=False),
|
||||
sa.Column("quota_dograh_tokens", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"quota_dograh_tokens",
|
||||
sa.Integer(),
|
||||
nullable=False,
|
||||
comment=DEPRECATED_QUOTA_COMMENT,
|
||||
),
|
||||
sa.Column("used_dograh_tokens", sa.Integer(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True),
|
||||
|
|
@ -63,7 +71,11 @@ def upgrade() -> None:
|
|||
op.add_column(
|
||||
"organizations",
|
||||
sa.Column(
|
||||
"quota_type", quota_type_enum, nullable=False, server_default="monthly"
|
||||
"quota_type",
|
||||
quota_type_enum,
|
||||
nullable=False,
|
||||
server_default="monthly",
|
||||
comment=DEPRECATED_QUOTA_COMMENT,
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
|
|
@ -73,6 +85,7 @@ def upgrade() -> None:
|
|||
sa.Integer(),
|
||||
nullable=False,
|
||||
server_default=sa.text("0"),
|
||||
comment=DEPRECATED_QUOTA_COMMENT,
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
|
|
@ -82,10 +95,17 @@ def upgrade() -> None:
|
|||
sa.Integer(),
|
||||
nullable=False,
|
||||
server_default=sa.text("LEAST(EXTRACT(DAY FROM CURRENT_DATE)::int, 28)"),
|
||||
comment=DEPRECATED_QUOTA_COMMENT,
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"organizations", sa.Column("quota_start_date", sa.DateTime(), nullable=True)
|
||||
"organizations",
|
||||
sa.Column(
|
||||
"quota_start_date",
|
||||
sa.DateTime(),
|
||||
nullable=True,
|
||||
comment=DEPRECATED_QUOTA_COMMENT,
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"organizations",
|
||||
|
|
@ -94,6 +114,7 @@ def upgrade() -> None:
|
|||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.text("false"),
|
||||
comment=DEPRECATED_QUOTA_COMMENT,
|
||||
),
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
|
|
|||
|
|
@ -18,6 +18,9 @@ branch_labels: Union[str, Sequence[str], None] = None
|
|||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
DEPRECATED_QUOTA_COMMENT = "Deprecated. MPS owns quota and credit ledger state."
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column(
|
||||
|
|
@ -26,7 +29,12 @@ def upgrade() -> None:
|
|||
)
|
||||
op.add_column(
|
||||
"organization_usage_cycles",
|
||||
sa.Column("quota_amount_usd", sa.Float(), nullable=True),
|
||||
sa.Column(
|
||||
"quota_amount_usd",
|
||||
sa.Float(),
|
||||
nullable=True,
|
||||
comment=DEPRECATED_QUOTA_COMMENT,
|
||||
),
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from api.db.base_client import BaseDBClient
|
|||
from api.db.filters import apply_workflow_run_filters, get_workflow_run_order_clause
|
||||
from api.db.models import CampaignModel, QueuedRunModel, WorkflowRunModel
|
||||
from api.schemas.workflow import WorkflowRunResponseSchema
|
||||
from api.services.workflow.run_usage_response import format_public_cost_info
|
||||
|
||||
|
||||
class CampaignClient(BaseDBClient):
|
||||
|
|
@ -215,26 +216,9 @@ class CampaignClient(BaseDBClient):
|
|||
"is_completed": run.is_completed,
|
||||
"recording_url": run.recording_url,
|
||||
"transcript_url": run.transcript_url,
|
||||
"cost_info": {
|
||||
"dograh_token_usage": (
|
||||
run.cost_info.get("dograh_token_usage")
|
||||
if run.cost_info
|
||||
and "dograh_token_usage" in run.cost_info
|
||||
else round(
|
||||
float(run.cost_info.get("total_cost_usd", 0)) * 100,
|
||||
2,
|
||||
)
|
||||
if run.cost_info and "total_cost_usd" in run.cost_info
|
||||
else 0
|
||||
),
|
||||
"call_duration_seconds": int(
|
||||
round(run.cost_info.get("call_duration_seconds") or 0)
|
||||
)
|
||||
if run.cost_info
|
||||
else None,
|
||||
}
|
||||
if run.cost_info
|
||||
else None,
|
||||
"cost_info": format_public_cost_info(
|
||||
run.cost_info, run.usage_info
|
||||
),
|
||||
"definition_id": run.definition_id,
|
||||
"initial_context": run.initial_context,
|
||||
"gathered_context": run.gathered_context,
|
||||
|
|
@ -662,7 +646,7 @@ class CampaignClient(BaseDBClient):
|
|||
async with self.async_session() as session:
|
||||
conditions = [
|
||||
WorkflowRunModel.is_completed.is_(True),
|
||||
WorkflowRunModel.cost_info["call_duration_seconds"]
|
||||
WorkflowRunModel.usage_info["call_duration_seconds"]
|
||||
.as_string()
|
||||
.isnot(None),
|
||||
]
|
||||
|
|
@ -685,6 +669,7 @@ class CampaignClient(BaseDBClient):
|
|||
WorkflowRunModel.initial_context,
|
||||
WorkflowRunModel.gathered_context,
|
||||
WorkflowRunModel.cost_info,
|
||||
WorkflowRunModel.usage_info,
|
||||
WorkflowRunModel.public_access_token,
|
||||
)
|
||||
.where(*conditions)
|
||||
|
|
|
|||
|
|
@ -53,7 +53,7 @@ class DBClient(
|
|||
- UserClient: handles user and user configuration operations
|
||||
- OrganizationClient: handles organization operations
|
||||
- OrganizationConfigurationClient: handles organization configuration operations
|
||||
- OrganizationUsageClient: handles organization usage and quota operations
|
||||
- OrganizationUsageClient: handles organization usage reporting aggregates
|
||||
- IntegrationClient: handles integration operations
|
||||
- WorkflowTemplateClient: handles workflow template operations
|
||||
- CampaignClient: handles campaign operations
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ def get_workflow_run_order_clause(
|
|||
"""
|
||||
# Determine sort column
|
||||
if sort_by == "duration":
|
||||
sort_column = WorkflowRunModel.cost_info.op("->>")(
|
||||
sort_column = WorkflowRunModel.usage_info.op("->>")(
|
||||
"call_duration_seconds"
|
||||
).cast(Float)
|
||||
else:
|
||||
|
|
@ -43,7 +43,7 @@ def get_workflow_run_order_clause(
|
|||
ATTRIBUTE_FIELD_MAPPING = {
|
||||
"dateRange": "created_at",
|
||||
"dispositionCode": "gathered_context.mapped_call_disposition",
|
||||
"duration": "cost_info.call_duration_seconds",
|
||||
"duration": "usage_info.call_duration_seconds",
|
||||
"status": "is_completed",
|
||||
"tokenUsage": "cost_info.total_cost_usd",
|
||||
"runId": "id",
|
||||
|
|
@ -208,7 +208,7 @@ def apply_workflow_run_filters(
|
|||
min_val = value.get("min")
|
||||
max_val = value.get("max")
|
||||
|
||||
if field == "cost_info.call_duration_seconds":
|
||||
if field == "usage_info.call_duration_seconds":
|
||||
# Use ->> operator for compatibility with all PostgreSQL versions
|
||||
# (subscript [] only works in PostgreSQL 14+)
|
||||
duration_text = cast(WorkflowRunModel.usage_info, JSONB).op("->>")(
|
||||
|
|
|
|||
|
|
@ -97,22 +97,44 @@ class OrganizationModel(Base):
|
|||
provider_id = Column(String, unique=True, index=True, nullable=False)
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
|
||||
|
||||
# Quota fields
|
||||
# Deprecated: MPS owns quota and credit ledger state.
|
||||
quota_type = Column(
|
||||
Enum("monthly", "annual", name="quota_type"),
|
||||
nullable=False,
|
||||
default="monthly",
|
||||
server_default=text("'monthly'::quota_type"),
|
||||
comment="Deprecated. MPS owns quota and credit ledger state.",
|
||||
info={"deprecated": True},
|
||||
)
|
||||
quota_dograh_tokens = Column(
|
||||
Integer, nullable=False, default=0, server_default=text("0")
|
||||
Integer,
|
||||
nullable=False,
|
||||
default=0,
|
||||
server_default=text("0"),
|
||||
comment="Deprecated. MPS owns quota and credit ledger state.",
|
||||
info={"deprecated": True},
|
||||
)
|
||||
quota_reset_day = Column(
|
||||
Integer, nullable=False, default=1, server_default=text("1")
|
||||
) # 1-28, only for monthly
|
||||
quota_start_date = Column(DateTime(timezone=True), nullable=True) # Only for annual
|
||||
Integer,
|
||||
nullable=False,
|
||||
default=1,
|
||||
server_default=text("1"),
|
||||
comment="Deprecated. MPS owns quota and credit ledger state.",
|
||||
info={"deprecated": True},
|
||||
)
|
||||
quota_start_date = Column(
|
||||
DateTime(timezone=True),
|
||||
nullable=True,
|
||||
comment="Deprecated. MPS owns quota and credit ledger state.",
|
||||
info={"deprecated": True},
|
||||
)
|
||||
quota_enabled = Column(
|
||||
Boolean, nullable=False, default=False, server_default=text("false")
|
||||
Boolean,
|
||||
nullable=False,
|
||||
default=False,
|
||||
server_default=text("false"),
|
||||
comment="Deprecated. MPS owns quota and credit ledger state.",
|
||||
info={"deprecated": True},
|
||||
)
|
||||
|
||||
price_per_second_usd = Column(Float, nullable=True)
|
||||
|
|
@ -593,8 +615,9 @@ class WorkflowRunTextSessionModel(Base):
|
|||
|
||||
class OrganizationUsageCycleModel(Base):
|
||||
"""
|
||||
This model is used to track the usage of Dograh tokens for an organization for a given usage
|
||||
cycle.
|
||||
This model is used to track reporting aggregates for an organization for a given
|
||||
usage cycle. Quota fields on this model are deprecated; MPS owns quota and
|
||||
credit ledger state.
|
||||
"""
|
||||
|
||||
__tablename__ = "organization_usage_cycles"
|
||||
|
|
@ -603,14 +626,24 @@ class OrganizationUsageCycleModel(Base):
|
|||
organization_id = Column(Integer, ForeignKey("organizations.id"), nullable=False)
|
||||
period_start = Column(DateTime(timezone=True), nullable=False)
|
||||
period_end = Column(DateTime(timezone=True), nullable=False)
|
||||
quota_dograh_tokens = Column(Integer, nullable=False)
|
||||
quota_dograh_tokens = Column(
|
||||
Integer,
|
||||
nullable=False,
|
||||
comment="Deprecated. MPS owns quota and credit ledger state.",
|
||||
info={"deprecated": True},
|
||||
)
|
||||
used_dograh_tokens = Column(Float, nullable=False, default=0)
|
||||
total_duration_seconds = Column(
|
||||
Integer, nullable=False, default=0, server_default=text("0")
|
||||
)
|
||||
# New USD tracking fields
|
||||
used_amount_usd = Column(Float, nullable=True, default=0)
|
||||
quota_amount_usd = Column(Float, nullable=True)
|
||||
quota_amount_usd = Column(
|
||||
Float,
|
||||
nullable=True,
|
||||
comment="Deprecated. MPS owns quota and credit ledger state.",
|
||||
info={"deprecated": True},
|
||||
)
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True),
|
||||
|
|
|
|||
|
|
@ -19,11 +19,11 @@ from api.db.models import (
|
|||
WorkflowRunModel,
|
||||
)
|
||||
from api.enums import OrganizationConfigurationKey
|
||||
from api.schemas.user_configuration import EffectiveAIModelConfiguration
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
|
||||
|
||||
class OrganizationUsageClient(BaseDBClient):
|
||||
"""Client for managing organization usage and quota operations."""
|
||||
"""Client for managing organization usage reporting aggregates."""
|
||||
|
||||
async def get_or_create_current_cycle(
|
||||
self, organization_id: int, session=None
|
||||
|
|
@ -49,14 +49,7 @@ class OrganizationUsageClient(BaseDBClient):
|
|||
self, organization_id: int, session, commit: bool
|
||||
) -> OrganizationUsageCycleModel:
|
||||
"""Internal implementation for get_or_create_current_cycle."""
|
||||
# Get organization to determine quota type
|
||||
org_result = await session.execute(
|
||||
select(OrganizationModel).where(OrganizationModel.id == organization_id)
|
||||
)
|
||||
org = org_result.scalar_one()
|
||||
|
||||
# Calculate current period
|
||||
period_start, period_end = self._calculate_current_period(org)
|
||||
period_start, period_end = self._calculate_current_period()
|
||||
|
||||
# Try to get existing cycle
|
||||
cycle_result = await session.execute(
|
||||
|
|
@ -78,7 +71,8 @@ class OrganizationUsageClient(BaseDBClient):
|
|||
organization_id=organization_id,
|
||||
period_start=period_start,
|
||||
period_end=period_end,
|
||||
quota_dograh_tokens=org.quota_dograh_tokens,
|
||||
# Deprecated non-null column retained for historical schema compatibility.
|
||||
quota_dograh_tokens=0,
|
||||
)
|
||||
# Handle concurrent inserts gracefully
|
||||
stmt = stmt.on_conflict_do_nothing(
|
||||
|
|
@ -102,95 +96,9 @@ class OrganizationUsageClient(BaseDBClient):
|
|||
)
|
||||
return cycle_result.scalar_one()
|
||||
|
||||
async def check_and_reserve_quota(
|
||||
self, organization_id: int, estimated_tokens: int = 0
|
||||
) -> bool:
|
||||
"""
|
||||
Check if organization has sufficient quota and optionally reserve tokens.
|
||||
Returns True if quota is available, False otherwise.
|
||||
|
||||
This method is fully atomic and safe for concurrent access from multiple processes.
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
# Get organization
|
||||
org_result = await session.execute(
|
||||
select(OrganizationModel).where(OrganizationModel.id == organization_id)
|
||||
)
|
||||
org = org_result.scalar_one_or_none()
|
||||
|
||||
if not org or not org.quota_enabled:
|
||||
# No quota enforcement if not enabled
|
||||
return True
|
||||
|
||||
# Get or create current cycle within the same session/transaction
|
||||
cycle = await self._get_or_create_current_cycle_impl(
|
||||
organization_id, session, commit=False
|
||||
)
|
||||
|
||||
# Atomic check and update with row-level lock
|
||||
result = await session.execute(
|
||||
select(OrganizationUsageCycleModel)
|
||||
.where(
|
||||
and_(
|
||||
OrganizationUsageCycleModel.id == cycle.id,
|
||||
OrganizationUsageCycleModel.used_dograh_tokens
|
||||
+ estimated_tokens
|
||||
<= OrganizationUsageCycleModel.quota_dograh_tokens,
|
||||
)
|
||||
)
|
||||
.with_for_update(skip_locked=False)
|
||||
)
|
||||
|
||||
cycle_locked = result.scalar_one_or_none()
|
||||
if cycle_locked:
|
||||
# Update the usage atomically
|
||||
cycle_locked.used_dograh_tokens += estimated_tokens
|
||||
await session.commit()
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def update_usage_after_run(
|
||||
self,
|
||||
organization_id: int,
|
||||
actual_tokens: float,
|
||||
duration_seconds: float = 0,
|
||||
charge_usd: float | None = None,
|
||||
) -> None:
|
||||
"""Update usage after a workflow run completes with actual token count and duration.
|
||||
|
||||
This method is fully atomic and safe for concurrent access from multiple processes.
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
# Get or create current cycle within the same session/transaction
|
||||
cycle = await self._get_or_create_current_cycle_impl(
|
||||
organization_id, session, commit=False
|
||||
)
|
||||
|
||||
# Acquire a row-level lock for atomic update
|
||||
result = await session.execute(
|
||||
select(OrganizationUsageCycleModel)
|
||||
.where(OrganizationUsageCycleModel.id == cycle.id)
|
||||
.with_for_update(skip_locked=False)
|
||||
)
|
||||
cycle_locked = result.scalar_one()
|
||||
|
||||
# Update usage atomically
|
||||
cycle_locked.used_dograh_tokens += actual_tokens
|
||||
cycle_locked.total_duration_seconds += int(round(duration_seconds))
|
||||
|
||||
# Update USD amount if provided
|
||||
if charge_usd is not None:
|
||||
if cycle_locked.used_amount_usd is None:
|
||||
cycle_locked.used_amount_usd = 0
|
||||
cycle_locked.used_amount_usd += charge_usd
|
||||
|
||||
await session.commit()
|
||||
|
||||
async def get_current_usage(self, organization_id: int) -> dict:
|
||||
"""Get current period usage information."""
|
||||
"""Get current reporting-period usage information."""
|
||||
async with self.async_session() as session:
|
||||
# Get organization
|
||||
org_result = await session.execute(
|
||||
select(OrganizationModel).where(OrganizationModel.id == organization_id)
|
||||
)
|
||||
|
|
@ -201,42 +109,19 @@ class OrganizationUsageClient(BaseDBClient):
|
|||
organization_id, session, commit=False
|
||||
)
|
||||
|
||||
# Calculate next refresh date
|
||||
if org.quota_type == "monthly":
|
||||
next_refresh = cycle.period_end + relativedelta(days=1)
|
||||
else: # annual
|
||||
next_refresh = cycle.period_end + relativedelta(days=1)
|
||||
|
||||
result = {
|
||||
"period_start": cycle.period_start.isoformat(),
|
||||
"period_end": cycle.period_end.isoformat(),
|
||||
"used_dograh_tokens": cycle.used_dograh_tokens,
|
||||
"quota_dograh_tokens": cycle.quota_dograh_tokens,
|
||||
"percentage_used": (
|
||||
round(
|
||||
(cycle.used_dograh_tokens / cycle.quota_dograh_tokens) * 100, 2
|
||||
)
|
||||
if cycle.quota_dograh_tokens > 0
|
||||
else 0
|
||||
),
|
||||
"next_refresh_date": next_refresh.date().isoformat(),
|
||||
"quota_enabled": org.quota_enabled,
|
||||
"total_duration_seconds": cycle.total_duration_seconds,
|
||||
}
|
||||
|
||||
# Add USD fields if organization has pricing
|
||||
if org.price_per_second_usd is not None:
|
||||
result["used_amount_usd"] = cycle.used_amount_usd or 0
|
||||
result["quota_amount_usd"] = cycle.quota_amount_usd
|
||||
result["currency"] = "USD"
|
||||
result["price_per_second_usd"] = org.price_per_second_usd
|
||||
|
||||
# Calculate percentage based on USD if available
|
||||
if cycle.quota_amount_usd and cycle.quota_amount_usd > 0:
|
||||
result["percentage_used"] = round(
|
||||
((cycle.used_amount_usd or 0) / cycle.quota_amount_usd) * 100, 2
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
async def get_usage_history(
|
||||
|
|
@ -256,7 +141,7 @@ class OrganizationUsageClient(BaseDBClient):
|
|||
.join(UserModel, WorkflowModel.user_id == UserModel.id)
|
||||
.where(
|
||||
UserModel.selected_organization_id == organization_id,
|
||||
WorkflowRunModel.cost_info.isnot(None),
|
||||
WorkflowRunModel.usage_info.isnot(None),
|
||||
)
|
||||
.order_by(WorkflowRunModel.created_at.desc())
|
||||
)
|
||||
|
|
@ -309,19 +194,8 @@ class OrganizationUsageClient(BaseDBClient):
|
|||
total_tokens = 0
|
||||
total_duration_seconds = 0
|
||||
for run in runs:
|
||||
if run.cost_info:
|
||||
# Try to get dograh_token_usage first (new format)
|
||||
dograh_tokens = run.cost_info.get("dograh_token_usage", 0)
|
||||
# If not present, calculate from total_cost_usd (old format)
|
||||
if dograh_tokens == 0 and "total_cost_usd" in run.cost_info:
|
||||
dograh_tokens = round(
|
||||
float(run.cost_info["total_cost_usd"]) * 100, 2
|
||||
)
|
||||
# Get call duration
|
||||
call_duration = run.cost_info.get("call_duration_seconds", 0)
|
||||
else:
|
||||
dograh_tokens = 0
|
||||
call_duration = 0
|
||||
dograh_tokens = 0
|
||||
call_duration = (run.usage_info or {}).get("call_duration_seconds", 0)
|
||||
total_tokens += dograh_tokens
|
||||
total_duration_seconds += int(round(call_duration))
|
||||
|
||||
|
|
@ -395,13 +269,14 @@ class OrganizationUsageClient(BaseDBClient):
|
|||
WorkflowRunModel.initial_context,
|
||||
WorkflowRunModel.gathered_context,
|
||||
WorkflowRunModel.cost_info,
|
||||
WorkflowRunModel.usage_info,
|
||||
WorkflowRunModel.public_access_token,
|
||||
)
|
||||
.join(WorkflowModel, WorkflowRunModel.workflow_id == WorkflowModel.id)
|
||||
.join(UserModel, WorkflowModel.user_id == UserModel.id)
|
||||
.where(
|
||||
UserModel.selected_organization_id == organization_id,
|
||||
WorkflowRunModel.cost_info.isnot(None),
|
||||
WorkflowRunModel.usage_info.isnot(None),
|
||||
)
|
||||
.order_by(WorkflowRunModel.created_at.desc())
|
||||
)
|
||||
|
|
@ -473,11 +348,11 @@ class OrganizationUsageClient(BaseDBClient):
|
|||
)
|
||||
config_obj = config_result.scalar_one_or_none()
|
||||
if config_obj and config_obj.configuration:
|
||||
user_config = EffectiveAIModelConfiguration.model_validate(
|
||||
effective_config = EffectiveAIModelConfiguration.model_validate(
|
||||
config_obj.configuration
|
||||
)
|
||||
if user_config.timezone and user_timezone == "UTC":
|
||||
user_timezone = user_config.timezone
|
||||
if effective_config.timezone and user_timezone == "UTC":
|
||||
user_timezone = effective_config.timezone
|
||||
|
||||
# Validate timezone string
|
||||
try:
|
||||
|
|
@ -496,7 +371,7 @@ class OrganizationUsageClient(BaseDBClient):
|
|||
select(
|
||||
date_expr.label("date"),
|
||||
func.sum(
|
||||
WorkflowRunModel.cost_info["call_duration_seconds"].as_float()
|
||||
WorkflowRunModel.usage_info["call_duration_seconds"].as_float()
|
||||
).label("total_seconds"),
|
||||
func.count(WorkflowRunModel.id).label("call_count"),
|
||||
)
|
||||
|
|
@ -545,83 +420,11 @@ class OrganizationUsageClient(BaseDBClient):
|
|||
"currency": "USD",
|
||||
}
|
||||
|
||||
async def update_organization_quota(
|
||||
self,
|
||||
organization_id: int,
|
||||
quota_type: str,
|
||||
quota_dograh_tokens: int,
|
||||
quota_reset_day: Optional[int] = None,
|
||||
quota_start_date: Optional[datetime] = None,
|
||||
) -> OrganizationModel:
|
||||
"""Update organization quota settings."""
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(OrganizationModel).where(OrganizationModel.id == organization_id)
|
||||
)
|
||||
org = result.scalar_one()
|
||||
|
||||
org.quota_type = quota_type
|
||||
org.quota_dograh_tokens = quota_dograh_tokens
|
||||
org.quota_enabled = True
|
||||
|
||||
if quota_type == "monthly" and quota_reset_day:
|
||||
org.quota_reset_day = quota_reset_day
|
||||
elif quota_type == "annual" and quota_start_date:
|
||||
org.quota_start_date = quota_start_date
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(org)
|
||||
return org
|
||||
|
||||
def _calculate_current_period(
|
||||
self, org: OrganizationModel
|
||||
) -> tuple[datetime, datetime]:
|
||||
"""Calculate the current billing period based on organization settings."""
|
||||
def _calculate_current_period(self) -> tuple[datetime, datetime]:
|
||||
"""Calculate the current calendar-month reporting period."""
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
if org.quota_type == "monthly":
|
||||
# Find the start of the current billing month
|
||||
reset_day = org.quota_reset_day
|
||||
|
||||
# Handle month boundaries
|
||||
if now.day >= reset_day:
|
||||
period_start = now.replace(
|
||||
day=reset_day, hour=0, minute=0, second=0, microsecond=0
|
||||
)
|
||||
else:
|
||||
# Previous month
|
||||
period_start = (now - relativedelta(months=1)).replace(
|
||||
day=reset_day, hour=0, minute=0, second=0, microsecond=0
|
||||
)
|
||||
|
||||
# End is one month later minus 1 second
|
||||
period_end = (
|
||||
period_start + relativedelta(months=1) - relativedelta(seconds=1)
|
||||
)
|
||||
|
||||
else: # annual
|
||||
if not org.quota_start_date:
|
||||
# Default to calendar year
|
||||
period_start = now.replace(
|
||||
month=1, day=1, hour=0, minute=0, second=0, microsecond=0
|
||||
)
|
||||
period_end = (
|
||||
period_start + relativedelta(years=1) - relativedelta(seconds=1)
|
||||
)
|
||||
else:
|
||||
# Find current annual period
|
||||
start_date = org.quota_start_date.replace(tzinfo=timezone.utc)
|
||||
years_diff = now.year - start_date.year
|
||||
|
||||
# Adjust for whether we've passed the anniversary
|
||||
if now.month < start_date.month or (
|
||||
now.month == start_date.month and now.day < start_date.day
|
||||
):
|
||||
years_diff -= 1
|
||||
|
||||
period_start = start_date + relativedelta(years=years_diff)
|
||||
period_end = (
|
||||
period_start + relativedelta(years=1) - relativedelta(seconds=1)
|
||||
)
|
||||
period_start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
|
||||
period_end = period_start + relativedelta(months=1) - relativedelta(seconds=1)
|
||||
|
||||
return period_start, period_end
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from sqlalchemy.future import select
|
|||
|
||||
from api.db.base_client import BaseDBClient
|
||||
from api.db.models import UserConfigurationModel, UserModel
|
||||
from api.schemas.user_configuration import EffectiveAIModelConfiguration
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
|
||||
|
||||
class UserClient(BaseDBClient):
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from api.db.models import (
|
|||
)
|
||||
from api.enums import CallType, StorageBackend
|
||||
from api.schemas.workflow import WorkflowRunResponseSchema
|
||||
from api.services.workflow.run_usage_response import format_public_cost_info
|
||||
|
||||
|
||||
class WorkflowRunClient(BaseDBClient):
|
||||
|
|
@ -312,26 +313,9 @@ class WorkflowRunClient(BaseDBClient):
|
|||
"is_completed": run.is_completed,
|
||||
"recording_url": run.recording_url,
|
||||
"transcript_url": run.transcript_url,
|
||||
"cost_info": {
|
||||
"dograh_token_usage": (
|
||||
run.cost_info.get("dograh_token_usage")
|
||||
if run.cost_info
|
||||
and "dograh_token_usage" in run.cost_info
|
||||
else round(
|
||||
float(run.cost_info.get("total_cost_usd", 0)) * 100,
|
||||
2,
|
||||
)
|
||||
if run.cost_info and "total_cost_usd" in run.cost_info
|
||||
else 0
|
||||
),
|
||||
"call_duration_seconds": int(
|
||||
round(run.cost_info.get("call_duration_seconds") or 0)
|
||||
)
|
||||
if run.cost_info
|
||||
else None,
|
||||
}
|
||||
if run.cost_info
|
||||
else None,
|
||||
"cost_info": format_public_cost_info(
|
||||
run.cost_info, run.usage_info
|
||||
),
|
||||
"definition_id": run.definition_id,
|
||||
"initial_context": run.initial_context,
|
||||
"gathered_context": run.gathered_context,
|
||||
|
|
|
|||
|
|
@ -384,7 +384,7 @@ async def search_chunks(
|
|||
user_id=user.id,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
user_config = resolved_config.effective
|
||||
effective_config = resolved_config.effective
|
||||
embeddings_api_key = None
|
||||
embeddings_model = None
|
||||
embeddings_provider = None
|
||||
|
|
@ -392,17 +392,17 @@ async def search_chunks(
|
|||
embeddings_endpoint = None
|
||||
embeddings_api_version = None
|
||||
|
||||
if user_config.embeddings:
|
||||
embeddings_api_key = user_config.embeddings.api_key
|
||||
embeddings_model = user_config.embeddings.model
|
||||
embeddings_provider = getattr(user_config.embeddings, "provider", None)
|
||||
embeddings_endpoint = getattr(user_config.embeddings, "endpoint", None)
|
||||
if effective_config.embeddings:
|
||||
embeddings_api_key = effective_config.embeddings.api_key
|
||||
embeddings_model = effective_config.embeddings.model
|
||||
embeddings_provider = getattr(effective_config.embeddings, "provider", None)
|
||||
embeddings_endpoint = getattr(effective_config.embeddings, "endpoint", None)
|
||||
embeddings_base_url = apply_managed_embeddings_base_url(
|
||||
provider=embeddings_provider,
|
||||
base_url=getattr(user_config.embeddings, "base_url", None),
|
||||
base_url=getattr(effective_config.embeddings, "base_url", None),
|
||||
)
|
||||
embeddings_api_version = getattr(
|
||||
user_config.embeddings, "api_version", None
|
||||
effective_config.embeddings, "api_version", None
|
||||
)
|
||||
|
||||
# Initialize embedding service based on provider
|
||||
|
|
|
|||
|
|
@ -5,7 +5,11 @@ from loguru import logger
|
|||
from pydantic import BaseModel
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from api.constants import DEFAULT_CAMPAIGN_RETRY_CONFIG, DEFAULT_ORG_CONCURRENCY_LIMIT
|
||||
from api.constants import (
|
||||
DEFAULT_CAMPAIGN_RETRY_CONFIG,
|
||||
DEFAULT_ORG_CONCURRENCY_LIMIT,
|
||||
DEPLOYMENT_MODE,
|
||||
)
|
||||
from api.db import db_client
|
||||
from api.db.models import UserModel
|
||||
from api.db.telephony_configuration_client import TelephonyConfigurationInUseError
|
||||
|
|
@ -55,6 +59,11 @@ from api.services.configuration.registry import (
|
|||
ServiceProviders,
|
||||
ServiceType,
|
||||
)
|
||||
from api.services.mps_billing import ensure_hosted_mps_billing_account_v2
|
||||
from api.services.organization_context import (
|
||||
OrganizationContextResponse,
|
||||
get_organization_context,
|
||||
)
|
||||
from api.services.organization_preferences import (
|
||||
get_organization_preferences,
|
||||
upsert_organization_preferences,
|
||||
|
|
@ -129,6 +138,12 @@ class TelephonyConfigWarningsResponse(BaseModel):
|
|||
telnyx_missing_webhook_public_key_count: int
|
||||
|
||||
|
||||
@router.get("/context", response_model=OrganizationContextResponse)
|
||||
async def get_current_organization_context(user: UserModel = Depends(get_user)):
|
||||
"""Return organization-scoped configuration signals owned by Dograh."""
|
||||
return await get_organization_context(user)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/telephony-providers/metadata",
|
||||
response_model=TelephonyProvidersMetadataResponse,
|
||||
|
|
@ -349,6 +364,23 @@ async def migrate_model_configuration_v2(
|
|||
except ValueError as exc:
|
||||
raise HTTPException(status_code=422, detail=exc.args[0])
|
||||
|
||||
if DEPLOYMENT_MODE != "oss":
|
||||
try:
|
||||
await ensure_hosted_mps_billing_account_v2(
|
||||
organization_id,
|
||||
created_by=str(user.provider_id),
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"Failed to initialize MPS billing v2 account for organization {}: {}",
|
||||
organization_id,
|
||||
exc,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="Failed to initialize MPS billing v2 account",
|
||||
)
|
||||
|
||||
await upsert_organization_ai_model_configuration_v2(
|
||||
organization_id,
|
||||
configuration,
|
||||
|
|
|
|||
|
|
@ -1,16 +1,16 @@
|
|||
import json
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from api.constants import DEPLOYMENT_MODE
|
||||
from api.constants import DEPLOYMENT_MODE, UI_APP_URL
|
||||
from api.db import db_client
|
||||
from api.db.models import UserModel
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.auth.depends import get_user, get_user_with_selected_organization
|
||||
from api.services.mps_service_key_client import mps_service_key_client
|
||||
from api.services.reports import generate_usage_runs_report_csv
|
||||
from api.utils.artifacts import artifact_url
|
||||
|
|
@ -22,14 +22,8 @@ class CurrentUsageResponse(BaseModel):
|
|||
period_start: str
|
||||
period_end: str
|
||||
used_dograh_tokens: float
|
||||
quota_dograh_tokens: int
|
||||
percentage_used: float
|
||||
next_refresh_date: str
|
||||
quota_enabled: bool
|
||||
total_duration_seconds: int
|
||||
# New USD fields
|
||||
used_amount_usd: Optional[float] = None
|
||||
quota_amount_usd: Optional[float] = None
|
||||
currency: Optional[str] = None
|
||||
price_per_second_usd: Optional[float] = None
|
||||
|
||||
|
|
@ -40,6 +34,61 @@ class MPSCreditsResponse(BaseModel):
|
|||
total_quota: float
|
||||
|
||||
|
||||
class MPSCreditPurchaseUrlResponse(BaseModel):
|
||||
checkout_url: str
|
||||
|
||||
|
||||
class MPSBillingAccountResponse(BaseModel):
|
||||
id: int
|
||||
organization_id: int
|
||||
billing_mode: str
|
||||
cached_balance_credits: float
|
||||
currency: str
|
||||
|
||||
|
||||
class MPSCreditLedgerEntryResponse(BaseModel):
|
||||
id: int
|
||||
entry_type: str
|
||||
origin: Optional[str] = None
|
||||
credits_delta: float
|
||||
balance_after: float
|
||||
amount_minor: Optional[int] = None
|
||||
amount_currency: Optional[str] = None
|
||||
payment_order_id: Optional[int] = None
|
||||
metric_code: Optional[str] = None
|
||||
correlation_id: Optional[str] = None
|
||||
aggregation_key: Optional[str] = None
|
||||
usage_event_id: Optional[int] = None
|
||||
workflow_run_id: Optional[int] = None
|
||||
workflow_id: Optional[int] = None
|
||||
billable_quantity: Optional[float] = None
|
||||
quantity_unit: Optional[str] = None
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
created_at: str
|
||||
|
||||
|
||||
class MPSBillingCreditsResponse(BaseModel):
|
||||
billing_version: Literal["legacy", "v2"]
|
||||
total_credits_used: float = 0.0
|
||||
remaining_credits: float = 0.0
|
||||
total_quota: float = 0.0
|
||||
account: Optional[MPSBillingAccountResponse] = None
|
||||
ledger_entries: List[MPSCreditLedgerEntryResponse] = Field(default_factory=list)
|
||||
total_count: int = 0
|
||||
page: int = 1
|
||||
limit: int = 50
|
||||
total_pages: int = 0
|
||||
|
||||
|
||||
def _optional_int(value: Any) -> Optional[int]:
|
||||
if value is None:
|
||||
return None
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
|
||||
class WorkflowRunUsageResponse(BaseModel):
|
||||
id: int
|
||||
workflow_id: int
|
||||
|
|
@ -97,7 +146,7 @@ class DailyUsageBreakdownResponse(BaseModel):
|
|||
|
||||
@router.get("/usage/current-period", response_model=CurrentUsageResponse)
|
||||
async def get_current_period_usage(user: UserModel = Depends(get_user)):
|
||||
"""Get current billing period usage for the user's organization."""
|
||||
"""Get current reporting-period usage for the user's organization."""
|
||||
if not user.selected_organization_id:
|
||||
raise HTTPException(status_code=400, detail="No organization selected")
|
||||
|
||||
|
|
@ -142,6 +191,202 @@ async def get_mps_credits(user: UserModel = Depends(get_user)):
|
|||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
async def _get_mps_billing_account_status(
|
||||
user: UserModel, organization_id: int
|
||||
) -> Optional[dict]:
|
||||
return await mps_service_key_client.get_billing_account_status(
|
||||
organization_id=organization_id,
|
||||
created_by=str(user.provider_id),
|
||||
)
|
||||
|
||||
|
||||
def _is_mps_billing_v2(account: Optional[dict]) -> bool:
|
||||
return bool(account and account.get("billing_mode") == "v2")
|
||||
|
||||
|
||||
async def _legacy_mps_credits_response(user: UserModel) -> MPSBillingCreditsResponse:
|
||||
if DEPLOYMENT_MODE == "oss":
|
||||
usage = await mps_service_key_client.get_usage_by_created_by(
|
||||
str(user.provider_id)
|
||||
)
|
||||
else:
|
||||
if not user.selected_organization_id:
|
||||
raise HTTPException(status_code=400, detail="No organization selected")
|
||||
usage = await mps_service_key_client.get_usage_by_organization(
|
||||
user.selected_organization_id
|
||||
)
|
||||
|
||||
total_used = float(usage.get("total_credits_used", 0.0))
|
||||
total_remaining = float(usage.get("remaining_credits", 0.0))
|
||||
return MPSBillingCreditsResponse(
|
||||
billing_version="legacy",
|
||||
total_credits_used=total_used,
|
||||
remaining_credits=total_remaining,
|
||||
total_quota=total_used + total_remaining,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/billing/credits", response_model=MPSBillingCreditsResponse)
|
||||
async def get_billing_credits(
|
||||
page: int = Query(1, ge=1),
|
||||
limit: int = Query(50, ge=1, le=100),
|
||||
user: UserModel = Depends(get_user),
|
||||
):
|
||||
"""Return legacy MPS credits or paginated v2 billing ledger details for the org."""
|
||||
try:
|
||||
if DEPLOYMENT_MODE == "oss" or not user.selected_organization_id:
|
||||
return await _legacy_mps_credits_response(user)
|
||||
|
||||
organization_id = user.selected_organization_id
|
||||
account_status = await _get_mps_billing_account_status(user, organization_id)
|
||||
if not _is_mps_billing_v2(account_status):
|
||||
return await _legacy_mps_credits_response(user)
|
||||
|
||||
ledger = await mps_service_key_client.get_credit_ledger(
|
||||
organization_id=organization_id,
|
||||
page=page,
|
||||
limit=limit,
|
||||
created_by=str(user.provider_id),
|
||||
)
|
||||
account = ledger.get("account") or {}
|
||||
ledger_entries = ledger.get("ledger_entries") or []
|
||||
total_count = int(ledger.get("total_count") or len(ledger_entries))
|
||||
response_limit = int(ledger.get("limit") or limit)
|
||||
total_pages = int(
|
||||
ledger.get("total_pages")
|
||||
or ((total_count + response_limit - 1) // response_limit)
|
||||
)
|
||||
workflow_ids_by_run_id: dict[int, int] = {}
|
||||
workflow_run_ids = {
|
||||
workflow_run_id
|
||||
for entry in ledger_entries
|
||||
if (workflow_run_id := _optional_int(entry.get("workflow_run_id")))
|
||||
is not None
|
||||
}
|
||||
for workflow_run_id in workflow_run_ids:
|
||||
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
|
||||
if (
|
||||
workflow_run
|
||||
and workflow_run.workflow
|
||||
and workflow_run.workflow.organization_id == organization_id
|
||||
):
|
||||
workflow_ids_by_run_id[workflow_run_id] = workflow_run.workflow_id
|
||||
|
||||
balance = float(account.get("cached_balance_credits") or 0.0)
|
||||
total_debits = sum(
|
||||
abs(float(entry.get("credits_delta") or 0.0))
|
||||
for entry in ledger_entries
|
||||
if float(entry.get("credits_delta") or 0.0) < 0
|
||||
)
|
||||
if ledger.get("total_debits_credits") is not None:
|
||||
total_debits = float(ledger["total_debits_credits"])
|
||||
|
||||
return MPSBillingCreditsResponse(
|
||||
billing_version="v2",
|
||||
total_credits_used=total_debits,
|
||||
remaining_credits=balance,
|
||||
total_quota=balance + total_debits,
|
||||
account=MPSBillingAccountResponse(
|
||||
id=int(account["id"]),
|
||||
organization_id=int(account["organization_id"]),
|
||||
billing_mode=str(account["billing_mode"]),
|
||||
cached_balance_credits=balance,
|
||||
currency=str(account.get("currency") or "USD"),
|
||||
),
|
||||
ledger_entries=[
|
||||
MPSCreditLedgerEntryResponse(
|
||||
id=int(entry["id"]),
|
||||
entry_type=str(entry["entry_type"]),
|
||||
origin=entry.get("origin"),
|
||||
credits_delta=float(entry.get("credits_delta") or 0.0),
|
||||
balance_after=float(entry.get("balance_after") or 0.0),
|
||||
amount_minor=entry.get("amount_minor"),
|
||||
amount_currency=entry.get("amount_currency"),
|
||||
payment_order_id=entry.get("payment_order_id"),
|
||||
metric_code=entry.get("metric_code"),
|
||||
correlation_id=entry.get("correlation_id"),
|
||||
aggregation_key=entry.get("aggregation_key"),
|
||||
usage_event_id=_optional_int(entry.get("usage_event_id")),
|
||||
workflow_run_id=_optional_int(entry.get("workflow_run_id")),
|
||||
workflow_id=workflow_ids_by_run_id.get(
|
||||
_optional_int(entry.get("workflow_run_id"))
|
||||
)
|
||||
if entry.get("workflow_run_id") is not None
|
||||
else None,
|
||||
billable_quantity=float(entry["billable_quantity"])
|
||||
if entry.get("billable_quantity") is not None
|
||||
else None,
|
||||
quantity_unit=entry.get("quantity_unit"),
|
||||
metadata=entry.get("metadata") or {},
|
||||
created_at=str(entry["created_at"]),
|
||||
)
|
||||
for entry in ledger_entries
|
||||
],
|
||||
total_count=total_count,
|
||||
page=int(ledger.get("page") or page),
|
||||
limit=response_limit,
|
||||
total_pages=total_pages,
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"Failed to fetch billing credits: {exc}")
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.post(
|
||||
"/usage/mps-credits/purchase-url",
|
||||
response_model=MPSCreditPurchaseUrlResponse,
|
||||
)
|
||||
async def create_mps_credit_purchase_url(
|
||||
user: UserModel = Depends(get_user_with_selected_organization),
|
||||
):
|
||||
"""Create a checkout URL for organizations using Dograh-managed MPS v2."""
|
||||
if DEPLOYMENT_MODE == "oss":
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Credit purchases are not available in OSS mode",
|
||||
)
|
||||
|
||||
organization_id = user.selected_organization_id
|
||||
assert organization_id is not None
|
||||
account_status = await _get_mps_billing_account_status(user, organization_id)
|
||||
if not _is_mps_billing_v2(account_status):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=(
|
||||
"Credit purchases are available only for organizations using billing v2"
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
session = await mps_service_key_client.create_credit_purchase_url(
|
||||
organization_id=organization_id,
|
||||
created_by=str(user.provider_id),
|
||||
return_url=f"{UI_APP_URL.rstrip('/')}/billing",
|
||||
billing_details={
|
||||
"source": "dograh_billing",
|
||||
"dograh_user_id": str(user.id),
|
||||
"dograh_provider_id": str(user.provider_id),
|
||||
},
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"Failed to create MPS credit purchase URL: {exc}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="Failed to create credit purchase URL",
|
||||
)
|
||||
|
||||
checkout_url = session.get("checkout_url")
|
||||
if not checkout_url:
|
||||
logger.error(f"MPS checkout session response missing checkout_url: {session}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="MPS checkout session response missing checkout_url",
|
||||
)
|
||||
return MPSCreditPurchaseUrlResponse(checkout_url=checkout_url)
|
||||
|
||||
|
||||
FILTERS_DESCRIPTION = """\
|
||||
JSON-encoded array of filter objects. Each object has the shape:
|
||||
|
||||
|
|
|
|||
|
|
@ -41,12 +41,15 @@ from api.services.configuration.resolve import (
|
|||
)
|
||||
from api.services.mps_service_key_client import mps_service_key_client
|
||||
from api.services.posthog_client import capture_event
|
||||
from api.services.pricing.run_usage_response import format_public_usage_info
|
||||
from api.services.reports import generate_workflow_report_csv
|
||||
from api.services.storage import storage_fs
|
||||
from api.services.workflow.dto import ReactFlowDTO, sanitize_workflow_definition
|
||||
from api.services.workflow.duplicate import duplicate_workflow
|
||||
from api.services.workflow.errors import ItemKind, WorkflowError
|
||||
from api.services.workflow.run_usage_response import (
|
||||
format_public_cost_info,
|
||||
format_public_usage_info,
|
||||
)
|
||||
from api.services.workflow.trigger_paths import (
|
||||
TriggerPathIssue,
|
||||
ensure_trigger_paths,
|
||||
|
|
@ -1053,13 +1056,15 @@ async def update_workflow(
|
|||
user_id=user.id,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
user_config = resolved_config.effective
|
||||
effective_config = resolved_config.effective
|
||||
try:
|
||||
enriched_overrides = enrich_overrides_with_api_keys(
|
||||
workflow_configurations["model_overrides"],
|
||||
user_config,
|
||||
effective_config,
|
||||
)
|
||||
effective = resolve_effective_config(
|
||||
effective_config, enriched_overrides
|
||||
)
|
||||
effective = resolve_effective_config(user_config, enriched_overrides)
|
||||
if resolved_config.source == "organization_v2":
|
||||
v2_override = convert_legacy_ai_model_configuration_to_v2(effective)
|
||||
await UserConfigurationValidator().validate(
|
||||
|
|
@ -1264,22 +1269,7 @@ async def get_workflow_run(
|
|||
"transcript_public_url": artifact_url(public_access_token, "transcript"),
|
||||
"recording_public_url": artifact_url(public_access_token, "recording"),
|
||||
"public_access_token": public_access_token,
|
||||
"cost_info": {
|
||||
"dograh_token_usage": (
|
||||
run.cost_info.get("dograh_token_usage")
|
||||
if run.cost_info and "dograh_token_usage" in run.cost_info
|
||||
else round(float(run.cost_info.get("total_cost_usd", 0)) * 100, 2)
|
||||
if run.cost_info and "total_cost_usd" in run.cost_info
|
||||
else 0
|
||||
),
|
||||
"call_duration_seconds": int(
|
||||
round(run.cost_info.get("call_duration_seconds"))
|
||||
)
|
||||
if run.cost_info and run.cost_info.get("call_duration_seconds") is not None
|
||||
else None,
|
||||
}
|
||||
if run.cost_info
|
||||
else None,
|
||||
"cost_info": format_public_cost_info(run.cost_info, run.usage_info),
|
||||
"usage_info": format_public_usage_info(run.usage_info),
|
||||
"created_at": run.created_at,
|
||||
"definition_id": run.definition_id,
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from api.schemas.user_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.configuration.registry import (
|
||||
DograhEmbeddingsConfiguration,
|
||||
DograhLLMService,
|
||||
|
|
@ -23,6 +23,29 @@ DOGRAH_DEFAULT_VOICE = "default"
|
|||
DOGRAH_DEFAULT_LANGUAGE = "multi"
|
||||
|
||||
|
||||
class EffectiveAIModelConfiguration(BaseModel):
|
||||
llm: LLMConfig | None = None
|
||||
stt: STTConfig | None = None
|
||||
tts: TTSConfig | None = None
|
||||
embeddings: EmbeddingsConfig | None = None
|
||||
realtime: RealtimeConfig | None = None
|
||||
is_realtime: bool = False
|
||||
managed_service_version: int | None = None
|
||||
test_phone_number: str | None = None
|
||||
timezone: str | None = None
|
||||
last_validated_at: datetime | None = None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def strip_incomplete_realtime_when_disabled(cls, data):
|
||||
"""Skip realtime validation when is_realtime is False and api_key is missing."""
|
||||
if isinstance(data, dict) and not data.get("is_realtime", False):
|
||||
realtime = data.get("realtime")
|
||||
if isinstance(realtime, dict) and not realtime.get("api_key"):
|
||||
data.pop("realtime", None)
|
||||
return data
|
||||
|
||||
|
||||
class DograhManagedAIModelConfiguration(BaseModel):
|
||||
api_key: str
|
||||
voice: str = DOGRAH_DEFAULT_VOICE
|
||||
|
|
@ -160,6 +183,7 @@ def _compile_dograh_configuration(
|
|||
model="default",
|
||||
),
|
||||
is_realtime=False,
|
||||
managed_service_version=2,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,33 +0,0 @@
|
|||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from api.services.configuration.registry import (
|
||||
EmbeddingsConfig,
|
||||
LLMConfig,
|
||||
RealtimeConfig,
|
||||
STTConfig,
|
||||
TTSConfig,
|
||||
)
|
||||
|
||||
|
||||
class EffectiveAIModelConfiguration(BaseModel):
|
||||
llm: LLMConfig | None = None
|
||||
stt: STTConfig | None = None
|
||||
tts: TTSConfig | None = None
|
||||
embeddings: EmbeddingsConfig | None = None
|
||||
realtime: RealtimeConfig | None = None
|
||||
is_realtime: bool = False
|
||||
test_phone_number: str | None = None
|
||||
timezone: str | None = None
|
||||
last_validated_at: datetime | None = None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def strip_incomplete_realtime_when_disabled(cls, data):
|
||||
"""Skip realtime validation when is_realtime is False and api_key is missing."""
|
||||
if isinstance(data, dict) and not data.get("is_realtime", False):
|
||||
realtime = data.get("realtime")
|
||||
if isinstance(realtime, dict) and not realtime.get("api_key"):
|
||||
data.pop("realtime", None)
|
||||
return data
|
||||
|
|
@ -9,9 +9,10 @@ from api.constants import AUTH_PROVIDER, DOGRAH_MPS_SECRET_KEY, MPS_API_URL
|
|||
from api.db import db_client
|
||||
from api.db.models import UserModel
|
||||
from api.enums import PostHogEvent
|
||||
from api.schemas.user_configuration import EffectiveAIModelConfiguration
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.auth.stack_auth import stackauth
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from api.services.mps_billing import ensure_hosted_mps_billing_account_v2
|
||||
from api.services.posthog_client import capture_event
|
||||
from api.utils.auth import decode_jwt_token
|
||||
|
||||
|
|
@ -110,6 +111,19 @@ async def get_user(
|
|||
# This prevents race conditions where multiple concurrent requests
|
||||
# might try to create configurations
|
||||
if org_was_created:
|
||||
try:
|
||||
await ensure_hosted_mps_billing_account_v2(
|
||||
organization.id,
|
||||
created_by=str(stack_user["id"]),
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to initialize hosted MPS billing account for "
|
||||
"organization {}",
|
||||
organization.id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
existing_cfg = await db_client.get_user_configurations(user_model.id)
|
||||
if not (existing_cfg.llm or existing_cfg.tts or existing_cfg.stt):
|
||||
mps_config = await create_user_configuration_with_mps_key(
|
||||
|
|
@ -232,7 +246,7 @@ async def create_user_configuration_with_mps_key(
|
|||
response = await client.post(
|
||||
f"{MPS_API_URL}/api/v1/service-keys/",
|
||||
json={
|
||||
"name": f"Default Dograh Model Service Key",
|
||||
"name": "Default Dograh Model Service Key",
|
||||
"description": "Auto-generated key for OSS user",
|
||||
"expires_in_days": 7, # Short-lived for OSS
|
||||
"created_by": user_provider_id,
|
||||
|
|
@ -250,7 +264,7 @@ async def create_user_configuration_with_mps_key(
|
|||
response = await client.post(
|
||||
f"{MPS_API_URL}/api/v1/service-keys/",
|
||||
json={
|
||||
"name": f"Default Dograh Model Service Key",
|
||||
"name": "Default Dograh Model Service Key",
|
||||
"description": f"Auto-generated key for organization {organization_id}",
|
||||
"organization_id": organization_id,
|
||||
"expires_in_days": 90, # Longer-lived for authenticated users
|
||||
|
|
@ -285,8 +299,8 @@ async def create_user_configuration_with_mps_key(
|
|||
"model": "default",
|
||||
},
|
||||
}
|
||||
user_config = EffectiveAIModelConfiguration(**configuration)
|
||||
return user_config
|
||||
effective_config = EffectiveAIModelConfiguration(**configuration)
|
||||
return effective_config
|
||||
else:
|
||||
logger.warning(
|
||||
f"Failed to get MPS service key: {response.status_code} - {response.text}"
|
||||
|
|
|
|||
|
|
@ -21,10 +21,10 @@ from api.schemas.ai_model_configuration import (
|
|||
BYOKPipelineAIModelConfiguration,
|
||||
BYOKRealtimeAIModelConfiguration,
|
||||
DograhManagedAIModelConfiguration,
|
||||
EffectiveAIModelConfiguration,
|
||||
OrganizationAIModelConfigurationV2,
|
||||
compile_ai_model_configuration_v2,
|
||||
)
|
||||
from api.schemas.user_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.configuration.masking import (
|
||||
SERVICE_SECRET_FIELDS,
|
||||
contains_masked_key,
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from groq import Groq
|
|||
# from pyneuphonic import Neuphonic
|
||||
# except ImportError:
|
||||
# Neuphonic = None
|
||||
from api.schemas.user_configuration import (
|
||||
from api.schemas.ai_model_configuration import (
|
||||
EffectiveAIModelConfiguration,
|
||||
)
|
||||
from api.services.configuration.registry import ServiceConfig, ServiceProviders
|
||||
|
|
@ -75,21 +75,21 @@ class UserConfigurationValidator:
|
|||
status_list = []
|
||||
|
||||
status_list.extend(self._validate_service(configuration.llm, "llm"))
|
||||
status_list.extend(self._validate_service(configuration.stt, "stt"))
|
||||
status_list.extend(self._validate_service(configuration.tts, "tts"))
|
||||
# Embeddings is optional - only validate if configured
|
||||
status_list.extend(
|
||||
self._validate_service(
|
||||
configuration.embeddings, "embeddings", required=False
|
||||
)
|
||||
)
|
||||
# Realtime is optional - only validate if is_realtime is enabled
|
||||
if configuration.is_realtime:
|
||||
status_list.extend(
|
||||
self._validate_service(
|
||||
configuration.realtime, "realtime", required=True
|
||||
)
|
||||
)
|
||||
else:
|
||||
status_list.extend(self._validate_service(configuration.stt, "stt"))
|
||||
status_list.extend(self._validate_service(configuration.tts, "tts"))
|
||||
# Embeddings is optional - only validate if configured
|
||||
status_list.extend(
|
||||
self._validate_service(
|
||||
configuration.embeddings, "embeddings", required=False
|
||||
)
|
||||
)
|
||||
|
||||
if status_list:
|
||||
raise ValueError(status_list)
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ The rules are simple:
|
|||
import copy
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from api.schemas.user_configuration import EffectiveAIModelConfiguration
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.configuration.registry import ServiceConfig
|
||||
from api.services.integrations import get_node_secret_fields
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ stored, while honouring masked API keys.
|
|||
import copy
|
||||
from typing import Dict
|
||||
|
||||
from api.schemas.user_configuration import EffectiveAIModelConfiguration
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.configuration.masking import (
|
||||
MODEL_OVERRIDE_FIELDS,
|
||||
SERVICE_SECRET_FIELDS,
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from __future__ import annotations
|
|||
|
||||
import copy
|
||||
|
||||
from api.schemas.user_configuration import EffectiveAIModelConfiguration
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.configuration.registry import (
|
||||
REGISTRY,
|
||||
ServiceType,
|
||||
|
|
|
|||
|
|
@ -38,6 +38,7 @@ class OpenAIEmbeddingService(BaseEmbeddingService):
|
|||
api_key: Optional[str] = None,
|
||||
model_id: str = DEFAULT_MODEL_ID,
|
||||
base_url: Optional[str] = None,
|
||||
default_headers: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
"""Initialize the OpenAI embedding service.
|
||||
|
||||
|
|
@ -60,6 +61,8 @@ class OpenAIEmbeddingService(BaseEmbeddingService):
|
|||
field_name="base_url",
|
||||
)
|
||||
client_kwargs["base_url"] = base_url
|
||||
if default_headers:
|
||||
client_kwargs["default_headers"] = default_headers
|
||||
self.client = AsyncOpenAI(**client_kwargs)
|
||||
logger.info(f"OpenAI embedding service initialized with model: {model_id}")
|
||||
else:
|
||||
|
|
|
|||
98
api/services/managed_model_services.py
Normal file
98
api/services/managed_model_services.py
Normal file
|
|
@ -0,0 +1,98 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from api.services.mps_service_key_client import mps_service_key_client
|
||||
|
||||
MPS_CORRELATION_ID_CONTEXT_KEY = "mps_correlation_id"
|
||||
|
||||
|
||||
def uses_managed_model_services_v2(
|
||||
ai_model_config: EffectiveAIModelConfiguration | None,
|
||||
) -> bool:
|
||||
if (
|
||||
ai_model_config is None
|
||||
or getattr(ai_model_config, "managed_service_version", None) != 2
|
||||
):
|
||||
return False
|
||||
|
||||
return any(
|
||||
_is_dograh_service(getattr(ai_model_config, section_name, None))
|
||||
for section_name in ("llm", "tts", "stt", "embeddings")
|
||||
)
|
||||
|
||||
|
||||
def get_mps_correlation_id(initial_context: dict[str, Any] | None) -> str | None:
|
||||
if not initial_context:
|
||||
return None
|
||||
correlation_id = initial_context.get(MPS_CORRELATION_ID_CONTEXT_KEY)
|
||||
if correlation_id is None:
|
||||
return None
|
||||
return str(correlation_id)
|
||||
|
||||
|
||||
async def ensure_mps_correlation_id(
|
||||
*,
|
||||
ai_model_config: EffectiveAIModelConfiguration,
|
||||
workflow_run_id: int,
|
||||
initial_context: dict[str, Any] | None,
|
||||
) -> str | None:
|
||||
existing = get_mps_correlation_id(initial_context)
|
||||
if existing:
|
||||
return existing
|
||||
|
||||
if not uses_managed_model_services_v2(ai_model_config):
|
||||
return None
|
||||
|
||||
service_key = _get_dograh_service_api_key(ai_model_config)
|
||||
if not service_key:
|
||||
raise ValueError(
|
||||
"Managed model services v2 requires a Dograh service key before the run starts."
|
||||
)
|
||||
|
||||
response = await mps_service_key_client.create_correlation_id(
|
||||
service_key=service_key,
|
||||
workflow_run_id=workflow_run_id,
|
||||
)
|
||||
correlation_id = response.get("correlation_id")
|
||||
if not correlation_id:
|
||||
raise ValueError("MPS correlation-id response did not include correlation_id")
|
||||
|
||||
correlation_id = str(correlation_id)
|
||||
logger.info(
|
||||
"Minted MPS correlation id {} for workflow run {}",
|
||||
correlation_id,
|
||||
workflow_run_id,
|
||||
)
|
||||
return correlation_id
|
||||
|
||||
|
||||
def _is_dograh_service(service: Any) -> bool:
|
||||
provider = getattr(service, "provider", None)
|
||||
return (
|
||||
provider == ServiceProviders.DOGRAH or provider == ServiceProviders.DOGRAH.value
|
||||
)
|
||||
|
||||
|
||||
def _get_dograh_service_api_key(
|
||||
ai_model_config: EffectiveAIModelConfiguration,
|
||||
) -> str | None:
|
||||
for section_name in ("llm", "tts", "stt", "embeddings"):
|
||||
service = getattr(ai_model_config, section_name, None)
|
||||
if not _is_dograh_service(service):
|
||||
continue
|
||||
|
||||
if hasattr(service, "get_all_api_keys"):
|
||||
keys = service.get_all_api_keys()
|
||||
if keys:
|
||||
return keys[0]
|
||||
|
||||
api_key = getattr(service, "api_key", None)
|
||||
if isinstance(api_key, str) and api_key:
|
||||
return api_key
|
||||
|
||||
return None
|
||||
23
api/services/mps_billing.py
Normal file
23
api/services/mps_billing.py
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
from typing import Optional
|
||||
|
||||
from api.constants import DEPLOYMENT_MODE
|
||||
from api.services.mps_service_key_client import mps_service_key_client
|
||||
|
||||
|
||||
async def ensure_hosted_mps_billing_account_v2(
|
||||
organization_id: int,
|
||||
*,
|
||||
created_by: Optional[str] = None,
|
||||
) -> Optional[dict]:
|
||||
"""Ensure hosted orgs have an MPS billing v2 account.
|
||||
|
||||
OSS deployments use legacy per-key quota accounting and do not create MPS
|
||||
billing accounts.
|
||||
"""
|
||||
if DEPLOYMENT_MODE == "oss":
|
||||
return None
|
||||
|
||||
return await mps_service_key_client.ensure_billing_account_v2(
|
||||
organization_id=organization_id,
|
||||
created_by=created_by,
|
||||
)
|
||||
|
|
@ -4,6 +4,7 @@ This client communicates with the Model Proxy Service (MPS) for service key mana
|
|||
Service keys are stored and managed entirely in MPS, not in the local database.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import List, Optional
|
||||
|
||||
import httpx
|
||||
|
|
@ -353,6 +354,234 @@ class MPSServiceKeyClient:
|
|||
response=response,
|
||||
)
|
||||
|
||||
async def create_credit_purchase_url(
|
||||
self,
|
||||
organization_id: int,
|
||||
created_by: Optional[str] = None,
|
||||
return_url: Optional[str] = None,
|
||||
billing_details: Optional[dict] = None,
|
||||
) -> dict:
|
||||
"""Create a short-lived MPS checkout URL for adding organization credits."""
|
||||
payload = {
|
||||
"created_by": created_by,
|
||||
"return_url": return_url,
|
||||
"billing_details": billing_details or {},
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
response = await client.post(
|
||||
f"{self.base_url}/api/v1/billing/accounts/{organization_id}/checkout-sessions",
|
||||
json=payload,
|
||||
headers=self._get_headers(
|
||||
organization_id=organization_id,
|
||||
created_by=created_by,
|
||||
),
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
|
||||
logger.error(
|
||||
"Failed to create MPS credit purchase URL: "
|
||||
f"{response.status_code} - {response.text}"
|
||||
)
|
||||
raise httpx.HTTPStatusError(
|
||||
f"Failed to create MPS credit purchase URL: {response.text}",
|
||||
request=response.request,
|
||||
response=response,
|
||||
)
|
||||
|
||||
async def get_credit_ledger(
|
||||
self,
|
||||
organization_id: int,
|
||||
page: int = 1,
|
||||
limit: int = 50,
|
||||
created_by: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""Get the MPS v2 billing account balance and recent credit ledger."""
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
response = await client.get(
|
||||
f"{self.base_url}/api/v1/billing/accounts/{organization_id}/ledger",
|
||||
params={"page": page, "limit": limit},
|
||||
headers=self._get_headers(
|
||||
organization_id=organization_id,
|
||||
created_by=created_by,
|
||||
),
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
|
||||
logger.error(
|
||||
"Failed to get MPS credit ledger: "
|
||||
f"{response.status_code} - {response.text}"
|
||||
)
|
||||
raise httpx.HTTPStatusError(
|
||||
f"Failed to get MPS credit ledger: {response.text}",
|
||||
request=response.request,
|
||||
response=response,
|
||||
)
|
||||
|
||||
async def get_billing_account_status(
|
||||
self,
|
||||
organization_id: int,
|
||||
created_by: Optional[str] = None,
|
||||
) -> Optional[dict]:
|
||||
"""Get an existing MPS v2 billing account without creating one."""
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
response = await client.get(
|
||||
f"{self.base_url}/api/v1/billing/accounts/{organization_id}/status",
|
||||
headers=self._get_headers(
|
||||
organization_id=organization_id,
|
||||
created_by=created_by,
|
||||
),
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
|
||||
logger.error(
|
||||
"Failed to get MPS billing account status: "
|
||||
f"{response.status_code} - {response.text}"
|
||||
)
|
||||
raise httpx.HTTPStatusError(
|
||||
f"Failed to get MPS billing account status: {response.text}",
|
||||
request=response.request,
|
||||
response=response,
|
||||
)
|
||||
|
||||
async def ensure_billing_account_v2(
|
||||
self,
|
||||
organization_id: int,
|
||||
created_by: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""Create or return the MPS v2 billing account for an organization."""
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
response = await client.get(
|
||||
f"{self.base_url}/api/v1/billing/accounts/{organization_id}/balance",
|
||||
headers=self._get_headers(
|
||||
organization_id=organization_id,
|
||||
created_by=created_by,
|
||||
),
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
|
||||
logger.error(
|
||||
"Failed to ensure MPS billing account v2: "
|
||||
f"{response.status_code} - {response.text}"
|
||||
)
|
||||
raise httpx.HTTPStatusError(
|
||||
f"Failed to ensure MPS billing account v2: {response.text}",
|
||||
request=response.request,
|
||||
response=response,
|
||||
)
|
||||
|
||||
async def create_correlation_id(
|
||||
self,
|
||||
*,
|
||||
service_key: str,
|
||||
workflow_run_id: int | None = None,
|
||||
) -> dict:
|
||||
"""Mint a server-generated correlation ID for managed model services."""
|
||||
payload: dict[str, int] = {}
|
||||
if workflow_run_id is not None:
|
||||
payload["workflow_run_id"] = workflow_run_id
|
||||
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
response = await client.post(
|
||||
f"{self.base_url}/api/v1/service-keys/correlation-id/self",
|
||||
json=payload,
|
||||
headers={
|
||||
"Authorization": f"Bearer {service_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
|
||||
logger.error(
|
||||
"Failed to create correlation ID: "
|
||||
f"{response.status_code} - {response.text}"
|
||||
)
|
||||
raise httpx.HTTPStatusError(
|
||||
f"Failed to create correlation ID: {response.text}",
|
||||
request=response.request,
|
||||
response=response,
|
||||
)
|
||||
|
||||
async def report_platform_usage(
|
||||
self,
|
||||
*,
|
||||
organization_id: int,
|
||||
correlation_id: Optional[str] = None,
|
||||
duration_seconds: Optional[float] = None,
|
||||
workflow_run_id: int | None = None,
|
||||
metadata: Optional[dict] = None,
|
||||
max_attempts: int = 3,
|
||||
) -> dict:
|
||||
"""Report hosted Dograh platform usage for a completed workflow run."""
|
||||
if DEPLOYMENT_MODE == "oss":
|
||||
raise ValueError("OSS deployments must not report platform usage to MPS")
|
||||
if not correlation_id and duration_seconds is None:
|
||||
raise ValueError(
|
||||
"Platform usage reports require correlation_id or duration_seconds"
|
||||
)
|
||||
|
||||
payload: dict = {
|
||||
"metadata": metadata or {},
|
||||
}
|
||||
if correlation_id:
|
||||
payload["correlation_id"] = correlation_id
|
||||
if duration_seconds is not None:
|
||||
payload["duration_seconds"] = duration_seconds
|
||||
if workflow_run_id is not None:
|
||||
payload["workflow_run_id"] = workflow_run_id
|
||||
|
||||
max_attempts = max(1, max_attempts)
|
||||
last_response: httpx.Response | None = None
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
for attempt in range(1, max_attempts + 1):
|
||||
response = await client.post(
|
||||
(
|
||||
f"{self.base_url}/api/v1/billing/accounts/"
|
||||
f"{organization_id}/platform-usage"
|
||||
),
|
||||
json=payload,
|
||||
headers=self._get_headers(organization_id=organization_id),
|
||||
)
|
||||
last_response = response
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
|
||||
should_retry = (
|
||||
response.status_code == 409
|
||||
and "usage_not_ready" in response.text
|
||||
and attempt < max_attempts
|
||||
)
|
||||
if should_retry:
|
||||
await asyncio.sleep(attempt)
|
||||
continue
|
||||
|
||||
logger.error(
|
||||
"Failed to report platform usage: "
|
||||
f"{response.status_code} - {response.text}"
|
||||
)
|
||||
raise httpx.HTTPStatusError(
|
||||
f"Failed to report platform usage: {response.text}",
|
||||
request=response.request,
|
||||
response=response,
|
||||
)
|
||||
|
||||
raise httpx.HTTPStatusError(
|
||||
"Failed to report platform usage",
|
||||
request=last_response.request,
|
||||
response=last_response,
|
||||
)
|
||||
|
||||
async def transcribe_audio(
|
||||
self,
|
||||
audio_data: bytes,
|
||||
|
|
|
|||
50
api/services/organization_context.py
Normal file
50
api/services/organization_context.py
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from api.db import db_client
|
||||
from api.db.models import UserModel
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
get_resolved_ai_model_configuration,
|
||||
)
|
||||
|
||||
|
||||
class OrganizationModelServicesContext(BaseModel):
|
||||
config_source: Literal["organization_v2", "legacy_user_v1", "empty"]
|
||||
has_model_configuration_v2: bool
|
||||
managed_service_version: Optional[int] = None
|
||||
uses_managed_service_v2: bool
|
||||
|
||||
|
||||
class OrganizationContextResponse(BaseModel):
|
||||
organization_id: Optional[int] = None
|
||||
organization_provider_id: Optional[str] = None
|
||||
model_services: OrganizationModelServicesContext
|
||||
|
||||
|
||||
async def get_organization_context(user: UserModel) -> OrganizationContextResponse:
|
||||
organization_id = user.selected_organization_id
|
||||
organization = (
|
||||
await db_client.get_organization_by_id(organization_id)
|
||||
if organization_id
|
||||
else None
|
||||
)
|
||||
|
||||
resolved = await get_resolved_ai_model_configuration(
|
||||
user_id=user.id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
managed_service_version = resolved.effective.managed_service_version
|
||||
|
||||
return OrganizationContextResponse(
|
||||
organization_id=organization_id,
|
||||
organization_provider_id=organization.provider_id if organization else None,
|
||||
model_services=OrganizationModelServicesContext(
|
||||
config_source=resolved.source,
|
||||
has_model_configuration_v2=resolved.source == "organization_v2",
|
||||
managed_service_version=managed_service_version,
|
||||
uses_managed_service_v2=(
|
||||
resolved.source == "organization_v2" and managed_service_version == 2
|
||||
),
|
||||
),
|
||||
)
|
||||
|
|
@ -162,15 +162,13 @@ async def run_pipeline_telephony(
|
|||
workflow_id: Workflow being executed.
|
||||
workflow_run_id: Workflow run row.
|
||||
user_id: Owner of the workflow.
|
||||
call_id: Provider call identifier (stored in cost_info for billing).
|
||||
call_id: Provider call identifier.
|
||||
transport_kwargs: Provider-specific kwargs forwarded to the transport
|
||||
factory (e.g. stream_sid + call_sid for Twilio).
|
||||
"""
|
||||
logger.debug(f"Running {provider_name} pipeline for workflow_run {workflow_run_id}")
|
||||
set_current_run_id(workflow_run_id)
|
||||
|
||||
await db_client.update_workflow_run(workflow_run_id, cost_info={"call_id": call_id})
|
||||
|
||||
workflow = await db_client.get_workflow(workflow_id, user_id)
|
||||
if workflow:
|
||||
set_current_org_id(workflow.organization_id)
|
||||
|
|
@ -340,7 +338,7 @@ async def _run_pipeline(
|
|||
if workflow_run.is_completed:
|
||||
raise HTTPException(status_code=400, detail="Workflow run already completed")
|
||||
|
||||
merged_call_context_vars = workflow_run.initial_context
|
||||
merged_call_context_vars = dict(workflow_run.initial_context or {})
|
||||
# If there is some extra call_context_vars, fold them in. Persistence
|
||||
# happens once below, after runtime_configuration is also resolved.
|
||||
if call_context_vars:
|
||||
|
|
@ -398,6 +396,19 @@ async def _run_pipeline(
|
|||
else:
|
||||
user_config = resolved_user_config
|
||||
|
||||
from api.services.managed_model_services import (
|
||||
MPS_CORRELATION_ID_CONTEXT_KEY,
|
||||
ensure_mps_correlation_id,
|
||||
)
|
||||
|
||||
mps_correlation_id = await ensure_mps_correlation_id(
|
||||
ai_model_config=user_config,
|
||||
workflow_run_id=workflow_run_id,
|
||||
initial_context=merged_call_context_vars,
|
||||
)
|
||||
if mps_correlation_id:
|
||||
merged_call_context_vars[MPS_CORRELATION_ID_CONTEXT_KEY] = mps_correlation_id
|
||||
|
||||
# Detect realtime mode (speech-to-speech services like OpenAI Realtime, Gemini Live)
|
||||
is_realtime = user_config.is_realtime and user_config.realtime is not None
|
||||
|
||||
|
|
@ -409,11 +420,23 @@ async def _run_pipeline(
|
|||
# Realtime services don't implement run_inference, so create a
|
||||
# separate text LLM for variable extraction and other out-of-band
|
||||
# inference calls.
|
||||
inference_llm = create_llm_service(user_config)
|
||||
inference_llm = create_llm_service(
|
||||
user_config,
|
||||
correlation_id=mps_correlation_id,
|
||||
)
|
||||
else:
|
||||
stt = create_stt_service(user_config, audio_config, keyterms=keyterms)
|
||||
tts = create_tts_service(user_config, audio_config)
|
||||
llm = create_llm_service(user_config)
|
||||
stt = create_stt_service(
|
||||
user_config,
|
||||
audio_config,
|
||||
keyterms=keyterms,
|
||||
correlation_id=mps_correlation_id,
|
||||
)
|
||||
tts = create_tts_service(
|
||||
user_config,
|
||||
audio_config,
|
||||
correlation_id=mps_correlation_id,
|
||||
)
|
||||
llm = create_llm_service(user_config, correlation_id=mps_correlation_id)
|
||||
inference_llm = None
|
||||
|
||||
# Stamp the providers/models actually resolved for this run onto
|
||||
|
|
@ -695,7 +718,10 @@ async def _run_pipeline(
|
|||
# Create a separate LLM instance for the voicemail sub-pipeline
|
||||
# (can't share with main pipeline as it would mess up frame linking)
|
||||
if voicemail_config.get("use_workflow_llm", True):
|
||||
voicemail_llm = create_llm_service(user_config)
|
||||
voicemail_llm = create_llm_service(
|
||||
user_config,
|
||||
correlation_id=mps_correlation_id,
|
||||
)
|
||||
else:
|
||||
voicemail_llm = create_llm_service_from_provider(
|
||||
provider=voicemail_config.get("provider", "openai"),
|
||||
|
|
|
|||
|
|
@ -78,7 +78,10 @@ def _validate_runtime_service_url(url: str, field_name: str) -> None:
|
|||
|
||||
|
||||
def create_stt_service(
|
||||
user_config, audio_config: "AudioConfig", keyterms: list[str] | None = None
|
||||
user_config,
|
||||
audio_config: "AudioConfig",
|
||||
keyterms: list[str] | None = None,
|
||||
correlation_id: str | None = None,
|
||||
):
|
||||
"""Create and return appropriate STT service based on user configuration
|
||||
|
||||
|
|
@ -160,6 +163,7 @@ def create_stt_service(
|
|||
return DograhSTTService(
|
||||
base_url=base_url,
|
||||
api_key=user_config.stt.api_key,
|
||||
correlation_id=correlation_id,
|
||||
settings=DograhSTTSettings(
|
||||
model=user_config.stt.model,
|
||||
language=language,
|
||||
|
|
@ -286,7 +290,9 @@ def create_stt_service(
|
|||
)
|
||||
|
||||
|
||||
def create_tts_service(user_config, audio_config: "AudioConfig"):
|
||||
def create_tts_service(
|
||||
user_config, audio_config: "AudioConfig", correlation_id: str | None = None
|
||||
):
|
||||
"""Create and return appropriate TTS service based on user configuration
|
||||
|
||||
Args:
|
||||
|
|
@ -404,6 +410,7 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
|
|||
return DograhTTSService(
|
||||
base_url=base_url,
|
||||
api_key=user_config.tts.api_key,
|
||||
correlation_id=correlation_id,
|
||||
settings=DograhTTSSettings(
|
||||
model=user_config.tts.model,
|
||||
voice=user_config.tts.voice,
|
||||
|
|
@ -564,6 +571,7 @@ def create_llm_service_from_provider(
|
|||
model: str,
|
||||
api_key: str | None,
|
||||
*,
|
||||
correlation_id: str | None = None,
|
||||
base_url: str | None = None,
|
||||
endpoint: str | None = None,
|
||||
aws_access_key: str | None = None,
|
||||
|
|
@ -637,6 +645,7 @@ def create_llm_service_from_provider(
|
|||
return DograhLLMService(
|
||||
base_url=f"{MPS_API_URL}/api/v1/llm",
|
||||
api_key=api_key,
|
||||
correlation_id=correlation_id,
|
||||
settings=OpenAILLMSettings(model=model),
|
||||
)
|
||||
elif provider == ServiceProviders.AWS_BEDROCK.value:
|
||||
|
|
@ -851,7 +860,7 @@ def create_realtime_llm_service(user_config, audio_config: "AudioConfig"):
|
|||
)
|
||||
|
||||
|
||||
def create_llm_service(user_config):
|
||||
def create_llm_service(user_config, correlation_id: str | None = None):
|
||||
"""Create and return appropriate LLM service based on user configuration."""
|
||||
provider = user_config.llm.provider
|
||||
model = user_config.llm.model
|
||||
|
|
@ -880,4 +889,10 @@ def create_llm_service(user_config):
|
|||
elif provider == ServiceProviders.SARVAM.value:
|
||||
kwargs["temperature"] = user_config.llm.temperature
|
||||
|
||||
return create_llm_service_from_provider(provider, model, api_key, **kwargs)
|
||||
return create_llm_service_from_provider(
|
||||
provider,
|
||||
model,
|
||||
api_key,
|
||||
correlation_id=correlation_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,76 +0,0 @@
|
|||
# Pricing Module
|
||||
|
||||
This module contains pricing models and registries for different AI services used in workflow cost calculations.
|
||||
|
||||
## Structure
|
||||
|
||||
```
|
||||
pricing/
|
||||
├── __init__.py # Main module exports
|
||||
├── models.py # Base pricing model classes
|
||||
├── llm.py # LLM pricing configurations
|
||||
├── tts.py # TTS pricing configurations
|
||||
├── stt.py # STT pricing configurations
|
||||
├── registry.py # Combined pricing registry
|
||||
└── README.md # This file
|
||||
```
|
||||
|
||||
## Pricing Models
|
||||
|
||||
### TokenPricingModel
|
||||
Used for LLM services that charge based on tokens:
|
||||
- `prompt_token_price`: Cost per prompt token
|
||||
- `completion_token_price`: Cost per completion token
|
||||
- `cache_read_discount`: Discount for cache read tokens (default 50%)
|
||||
- `cache_creation_multiplier`: Premium for cache creation tokens (default 25%)
|
||||
|
||||
### CharacterPricingModel
|
||||
Used for TTS services that charge based on character count:
|
||||
- `character_price`: Cost per character
|
||||
|
||||
### TimePricingModel
|
||||
Used for STT services that charge based on time:
|
||||
- `second_price`: Cost per second
|
||||
|
||||
## Adding New Pricing
|
||||
|
||||
### Adding a New LLM Model
|
||||
Edit `llm.py` and add the model to the appropriate provider:
|
||||
|
||||
```python
|
||||
ServiceProviders.OPENAI: {
|
||||
"new-model": TokenPricingModel(
|
||||
prompt_token_price=Decimal("2.00") / 1000000,
|
||||
completion_token_price=Decimal("8.00") / 1000000,
|
||||
),
|
||||
# ... existing models
|
||||
}
|
||||
```
|
||||
|
||||
### Adding a New Provider
|
||||
1. Add pricing configurations to the appropriate service file (llm.py, tts.py, stt.py)
|
||||
2. The registry will automatically include them
|
||||
|
||||
### Adding a New Service Type
|
||||
1. Create a new pricing file (e.g., `image.py`)
|
||||
2. Define the pricing models
|
||||
3. Import and add to `registry.py`
|
||||
|
||||
## Usage
|
||||
|
||||
The pricing registry is automatically imported and used by the cost calculator:
|
||||
|
||||
```python
|
||||
from api.services.pricing import PRICING_REGISTRY
|
||||
from api.services.workflow.cost_calculator import cost_calculator
|
||||
|
||||
# The cost calculator uses the pricing registry automatically
|
||||
result = cost_calculator.calculate_total_cost(usage_info)
|
||||
```
|
||||
|
||||
## Maintenance
|
||||
|
||||
- Update pricing when providers change their rates
|
||||
- All prices should use `Decimal` for precision
|
||||
- Include comments with current pricing from provider documentation
|
||||
- Test changes with existing test suite
|
||||
|
|
@ -1,9 +0,0 @@
|
|||
"""
|
||||
Pricing module for workflow cost calculation.
|
||||
|
||||
This module contains pricing models and registries for different AI services.
|
||||
"""
|
||||
|
||||
from .registry import PRICING_REGISTRY
|
||||
|
||||
__all__ = ["PRICING_REGISTRY"]
|
||||
|
|
@ -1,228 +0,0 @@
|
|||
"""
|
||||
Cost Calculator for Workflow Runs
|
||||
|
||||
This module provides a comprehensive cost calculation system for workflow runs based on usage metrics
|
||||
from different AI service providers (OpenAI, Groq, Deepgram, etc.).
|
||||
|
||||
Features:
|
||||
- Token-based pricing for LLM services with cache optimization support
|
||||
- Character-based pricing for TTS services
|
||||
- Time-based pricing for STT services
|
||||
- Configurable pricing models that can be updated
|
||||
- Support for multiple providers and models
|
||||
- Automatic provider inference from model names
|
||||
- JSON serialization support for database storage
|
||||
|
||||
Usage:
|
||||
from api.tasks.cost_calculator import cost_calculator
|
||||
|
||||
usage_info = {
|
||||
"llm": {
|
||||
"processor_name|||gpt-4o": {
|
||||
"prompt_tokens": 1000,
|
||||
"completion_tokens": 500,
|
||||
"total_tokens": 1500,
|
||||
"cache_read_input_tokens": 0,
|
||||
"cache_creation_input_tokens": 0
|
||||
}
|
||||
},
|
||||
"tts": {
|
||||
"processor_name|||aura-2-helena-en": 2000 # character count
|
||||
}
|
||||
}
|
||||
|
||||
cost_breakdown = cost_calculator.calculate_total_cost(usage_info)
|
||||
print(f"Total cost: ${cost_breakdown['total']:.6f}")
|
||||
"""
|
||||
|
||||
from decimal import Decimal
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from api.services.pricing import PRICING_REGISTRY
|
||||
from api.services.pricing.models import (
|
||||
PricingModel,
|
||||
)
|
||||
|
||||
|
||||
class CostCalculator:
|
||||
"""Main cost calculator class"""
|
||||
|
||||
def __init__(self, pricing_registry: Dict = None):
|
||||
self.pricing_registry = pricing_registry or PRICING_REGISTRY
|
||||
|
||||
def get_pricing_model(
|
||||
self, service_type: str, provider: str, model: str
|
||||
) -> Optional[PricingModel]:
|
||||
"""Get pricing model for a specific service, provider, and model"""
|
||||
try:
|
||||
service_pricing = self.pricing_registry.get(service_type, {})
|
||||
|
||||
# Try to get pricing for the specific provider
|
||||
provider_pricing = service_pricing.get(provider, {})
|
||||
pricing_model = provider_pricing.get(model) or provider_pricing.get(
|
||||
"default"
|
||||
)
|
||||
|
||||
if pricing_model:
|
||||
return pricing_model
|
||||
|
||||
# If not found, try the "default" provider for this service type
|
||||
default_provider_pricing = service_pricing.get("default", {})
|
||||
return default_provider_pricing.get(model) or default_provider_pricing.get(
|
||||
"default"
|
||||
)
|
||||
|
||||
except (KeyError, AttributeError):
|
||||
return None
|
||||
|
||||
def calculate_llm_cost(
|
||||
self, provider: str, model: str, usage: Dict[str, int]
|
||||
) -> Decimal:
|
||||
"""Calculate cost for LLM usage"""
|
||||
pricing_model = self.get_pricing_model("llm", provider, model)
|
||||
if not pricing_model:
|
||||
return Decimal("0")
|
||||
return pricing_model.calculate_cost(usage)
|
||||
|
||||
def calculate_tts_cost(
|
||||
self, provider: str, model: str, character_count: int
|
||||
) -> Decimal:
|
||||
"""Calculate cost for TTS usage"""
|
||||
pricing_model = self.get_pricing_model("tts", provider, model)
|
||||
if not pricing_model:
|
||||
return Decimal("0")
|
||||
return pricing_model.calculate_cost(character_count)
|
||||
|
||||
def calculate_stt_cost(self, provider: str, model: str, seconds: float) -> Decimal:
|
||||
"""Calculate cost for STT usage"""
|
||||
pricing_model = self.get_pricing_model("stt", provider, model)
|
||||
if not pricing_model:
|
||||
return Decimal("0")
|
||||
return pricing_model.calculate_cost(seconds)
|
||||
|
||||
def calculate_total_cost(self, usage_info: Dict) -> Dict[str, Any]:
|
||||
llm_cost_total = Decimal("0")
|
||||
tts_cost_total = Decimal("0")
|
||||
stt_cost_total = Decimal("0")
|
||||
|
||||
# Calculate LLM costs
|
||||
llm_usage = usage_info.get("llm", {})
|
||||
for key, usage in llm_usage.items():
|
||||
processor, model = self._parse_key(key)
|
||||
# Try to determine provider from processor name or model
|
||||
provider = self._infer_provider_from_model(model, "llm")
|
||||
cost = self.calculate_llm_cost(provider, model, usage)
|
||||
llm_cost_total += cost
|
||||
|
||||
# Calculate TTS costs
|
||||
tts_usage = usage_info.get("tts", {})
|
||||
for key, character_count in tts_usage.items():
|
||||
processor, model = self._parse_key(key)
|
||||
# Handle the case where model is "None" - infer from processor
|
||||
if model.lower() in ["none", "null", ""]:
|
||||
provider = self._infer_provider_from_processor(processor, "tts")
|
||||
model = "default" # Use default model for the provider
|
||||
else:
|
||||
provider = self._infer_provider_from_model(model, "tts")
|
||||
cost = self.calculate_tts_cost(provider, model, character_count)
|
||||
tts_cost_total += cost
|
||||
|
||||
# Calculate STT costs from explicit stt usage
|
||||
stt_usage = usage_info.get("stt", {})
|
||||
for key, seconds in stt_usage.items():
|
||||
processor, model = self._parse_key(key)
|
||||
provider = self._infer_provider_from_model(model, "stt")
|
||||
cost = self.calculate_stt_cost(provider, model, seconds)
|
||||
stt_cost_total += cost
|
||||
|
||||
total_cost = llm_cost_total + tts_cost_total + stt_cost_total
|
||||
|
||||
return {
|
||||
"llm_cost": float(llm_cost_total),
|
||||
"tts_cost": float(tts_cost_total),
|
||||
"stt_cost": float(stt_cost_total),
|
||||
"total": float(total_cost),
|
||||
}
|
||||
|
||||
def _parse_key(self, key) -> Tuple[str, str]:
|
||||
"""Parse key which is in format 'processor|||model'"""
|
||||
if isinstance(key, str) and "|||" in key:
|
||||
parts = key.split("|||", 1)
|
||||
return parts[0], parts[1]
|
||||
else:
|
||||
# Fallback for backwards compatibility or malformed keys
|
||||
return str(key), "unknown"
|
||||
|
||||
def _infer_provider_from_model(self, model: str, service_type: str) -> str:
|
||||
"""Infer provider from model name"""
|
||||
if not model:
|
||||
return "unknown"
|
||||
|
||||
model_lower = model.lower()
|
||||
|
||||
# OpenAI models
|
||||
if any(keyword in model_lower for keyword in ["gpt", "whisper", "openai"]):
|
||||
return ServiceProviders.OPENAI
|
||||
|
||||
# Groq models
|
||||
if any(keyword in model_lower for keyword in ["groq"]):
|
||||
return ServiceProviders.GROQ
|
||||
|
||||
# Elevenlabs models
|
||||
if any(keyword in model_lower for keyword in ["eleven"]):
|
||||
return ServiceProviders.ELEVENLABS
|
||||
|
||||
# Deepgram models
|
||||
if any(
|
||||
keyword in model_lower
|
||||
for keyword in ["deepgram", "nova", "phonecall", "general"]
|
||||
):
|
||||
return ServiceProviders.DEEPGRAM
|
||||
|
||||
# Default to first available provider for the service type
|
||||
service_providers = self.pricing_registry.get(service_type, {})
|
||||
if service_providers:
|
||||
return list(service_providers.keys())[0]
|
||||
|
||||
return "unknown"
|
||||
|
||||
def _infer_provider_from_processor(self, processor: str, service_type: str) -> str:
|
||||
"""Infer provider from processor name"""
|
||||
if not processor:
|
||||
return "unknown"
|
||||
|
||||
processor_lower = processor.lower()
|
||||
|
||||
# OpenAI processors
|
||||
if any(keyword in processor_lower for keyword in ["openai", "gpt"]):
|
||||
return ServiceProviders.OPENAI
|
||||
|
||||
# Groq processors
|
||||
if any(keyword in processor_lower for keyword in ["groq"]):
|
||||
return ServiceProviders.GROQ
|
||||
|
||||
# Deepgram processors
|
||||
if any(keyword in processor_lower for keyword in ["deepgram"]):
|
||||
return ServiceProviders.DEEPGRAM
|
||||
|
||||
# Default to first available provider for the service type
|
||||
service_providers = self.pricing_registry.get(service_type, {})
|
||||
if service_providers:
|
||||
return list(service_providers.keys())[0]
|
||||
|
||||
return "unknown"
|
||||
|
||||
def update_pricing(
|
||||
self, service_type: str, provider: str, model: str, pricing_model: PricingModel
|
||||
):
|
||||
"""Update pricing for a specific service/provider/model combination"""
|
||||
if service_type not in self.pricing_registry:
|
||||
self.pricing_registry[service_type] = {}
|
||||
if provider not in self.pricing_registry[service_type]:
|
||||
self.pricing_registry[service_type][provider] = {}
|
||||
self.pricing_registry[service_type][provider][model] = pricing_model
|
||||
|
||||
|
||||
# Global cost calculator instance
|
||||
cost_calculator = CostCalculator()
|
||||
|
|
@ -1,44 +0,0 @@
|
|||
"""
|
||||
Embeddings pricing models for different providers.
|
||||
|
||||
Prices are per token for embedding models.
|
||||
"""
|
||||
|
||||
from decimal import Decimal
|
||||
from typing import Dict
|
||||
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
|
||||
from .models import PricingModel
|
||||
|
||||
|
||||
class EmbeddingPricingModel(PricingModel):
|
||||
"""Pricing model for token-based embedding services."""
|
||||
|
||||
def __init__(self, token_price: Decimal):
|
||||
"""Initialize with price per token.
|
||||
|
||||
Args:
|
||||
token_price: Cost per token for embedding
|
||||
"""
|
||||
self.token_price = token_price
|
||||
|
||||
def calculate_cost(self, token_count: int) -> Decimal:
|
||||
"""Calculate cost for embedding token usage."""
|
||||
return Decimal(token_count) * self.token_price
|
||||
|
||||
|
||||
# Embeddings pricing registry
|
||||
EMBEDDINGS_PRICING: Dict[str, Dict[str, EmbeddingPricingModel]] = {
|
||||
ServiceProviders.OPENAI: {
|
||||
"text-embedding-3-small": EmbeddingPricingModel(
|
||||
token_price=Decimal("0.02") / 1_000_000, # $0.02 per 1M tokens
|
||||
),
|
||||
"text-embedding-3-large": EmbeddingPricingModel(
|
||||
token_price=Decimal("0.13") / 1_000_000, # $0.13 per 1M tokens
|
||||
),
|
||||
"text-embedding-ada-002": EmbeddingPricingModel(
|
||||
token_price=Decimal("0.10") / 1_000_000, # $0.10 per 1M tokens (legacy)
|
||||
),
|
||||
},
|
||||
}
|
||||
|
|
@ -1,143 +0,0 @@
|
|||
"""
|
||||
LLM pricing models for different providers.
|
||||
|
||||
Prices are per 1000 tokens for most models, with some newer models priced per million tokens.
|
||||
"""
|
||||
|
||||
from decimal import Decimal
|
||||
from typing import Dict
|
||||
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
|
||||
from .models import TokenPricingModel
|
||||
|
||||
# LLM pricing registry
|
||||
LLM_PRICING: Dict[str, Dict[str, TokenPricingModel]] = {
|
||||
ServiceProviders.OPENAI: {
|
||||
"gpt-3.5-turbo": TokenPricingModel(
|
||||
prompt_token_price=Decimal("0.0015") / 1000, # $0.0015 per 1K tokens
|
||||
completion_token_price=Decimal("0.002") / 1000, # $0.002 per 1K tokens
|
||||
),
|
||||
"gpt-4": TokenPricingModel(
|
||||
prompt_token_price=Decimal("0.03") / 1000, # $0.03 per 1K tokens
|
||||
completion_token_price=Decimal("0.06") / 1000, # $0.06 per 1K tokens
|
||||
),
|
||||
"gpt-4.1": TokenPricingModel(
|
||||
prompt_token_price=Decimal("2.00") / 1000000, # $2.00 per 1M tokens
|
||||
completion_token_price=Decimal("8.00") / 1000000, # $8.00 per 1M tokens
|
||||
),
|
||||
"gpt-4.1-mini": TokenPricingModel(
|
||||
prompt_token_price=Decimal("0.40") / 1000000, # $0.40 per 1M tokens
|
||||
completion_token_price=Decimal("1.60") / 1000000, # $1.60 per 1M tokens
|
||||
),
|
||||
"gpt-4.1-nano": TokenPricingModel(
|
||||
prompt_token_price=Decimal("0.10") / 1000000, # $0.10 per 1M tokens
|
||||
completion_token_price=Decimal("0.40") / 1000000, # $0.40 per 1M tokens
|
||||
),
|
||||
"gpt-4.5-preview": TokenPricingModel(
|
||||
prompt_token_price=Decimal("75.00") / 1000000, # $75.00 per 1M tokens
|
||||
completion_token_price=Decimal("150.00") / 1000000, # $150.00 per 1M tokens
|
||||
),
|
||||
"gpt-4o": TokenPricingModel(
|
||||
prompt_token_price=Decimal("2.50") / 1000000, # $2.50 per 1M tokens - FIXED
|
||||
completion_token_price=Decimal("10.00")
|
||||
/ 1000000, # $10.00 per 1M tokens - FIXED
|
||||
),
|
||||
"gpt-4o-audio-preview": TokenPricingModel(
|
||||
prompt_token_price=Decimal("2.50") / 1000000, # $2.50 per 1M tokens
|
||||
completion_token_price=Decimal("10.00") / 1000000, # $10.00 per 1M tokens
|
||||
),
|
||||
"gpt-4o-realtime-preview": TokenPricingModel(
|
||||
prompt_token_price=Decimal("5.00") / 1000000, # $5.00 per 1M tokens
|
||||
completion_token_price=Decimal("20.00") / 1000000, # $20.00 per 1M tokens
|
||||
),
|
||||
"gpt-4o-mini": TokenPricingModel(
|
||||
prompt_token_price=Decimal("0.15") / 1000000, # $0.15 per 1M tokens
|
||||
completion_token_price=Decimal("0.60") / 1000000, # $0.60 per 1M tokens
|
||||
),
|
||||
"gpt-4o-mini-audio-preview": TokenPricingModel(
|
||||
prompt_token_price=Decimal("0.15") / 1000000, # $0.15 per 1M tokens
|
||||
completion_token_price=Decimal("0.60") / 1000000, # $0.60 per 1M tokens
|
||||
),
|
||||
"gpt-4o-mini-realtime-preview": TokenPricingModel(
|
||||
prompt_token_price=Decimal("0.60") / 1000000, # $0.60 per 1M tokens
|
||||
completion_token_price=Decimal("2.40") / 1000000, # $2.40 per 1M tokens
|
||||
),
|
||||
"gpt-4o-search-preview": TokenPricingModel(
|
||||
prompt_token_price=Decimal("2.50") / 1000000, # $2.50 per 1M tokens
|
||||
completion_token_price=Decimal("10.00") / 1000000, # $10.00 per 1M tokens
|
||||
),
|
||||
"gpt-4o-mini-search-preview": TokenPricingModel(
|
||||
prompt_token_price=Decimal("0.15") / 1000000, # $0.15 per 1M tokens
|
||||
completion_token_price=Decimal("0.60") / 1000000, # $0.60 per 1M tokens
|
||||
),
|
||||
"o1": TokenPricingModel(
|
||||
prompt_token_price=Decimal("15.00") / 1000000, # $15.00 per 1M tokens
|
||||
completion_token_price=Decimal("60.00") / 1000000, # $60.00 per 1M tokens
|
||||
),
|
||||
"o1-pro": TokenPricingModel(
|
||||
prompt_token_price=Decimal("150.00") / 1000000, # $150.00 per 1M tokens
|
||||
completion_token_price=Decimal("600.00") / 1000000, # $600.00 per 1M tokens
|
||||
),
|
||||
"o1-mini": TokenPricingModel(
|
||||
prompt_token_price=Decimal("1.10") / 1000000, # $1.10 per 1M tokens
|
||||
completion_token_price=Decimal("4.40") / 1000000, # $4.40 per 1M tokens
|
||||
),
|
||||
"o3": TokenPricingModel(
|
||||
prompt_token_price=Decimal("10.00") / 1000000, # $10.00 per 1M tokens
|
||||
completion_token_price=Decimal("40.00") / 1000000, # $40.00 per 1M tokens
|
||||
),
|
||||
"o3-mini": TokenPricingModel(
|
||||
prompt_token_price=Decimal("1.10") / 1000000, # $1.10 per 1M tokens
|
||||
completion_token_price=Decimal("4.40") / 1000000, # $4.40 per 1M tokens
|
||||
),
|
||||
"o4-mini": TokenPricingModel(
|
||||
prompt_token_price=Decimal("1.10") / 1000000, # $1.10 per 1M tokens
|
||||
completion_token_price=Decimal("4.40") / 1000000, # $4.40 per 1M tokens
|
||||
),
|
||||
"computer-use-preview": TokenPricingModel(
|
||||
prompt_token_price=Decimal("3.00") / 1000000, # $3.00 per 1M tokens
|
||||
completion_token_price=Decimal("12.00") / 1000000, # $12.00 per 1M tokens
|
||||
),
|
||||
"gpt-image-1": TokenPricingModel(
|
||||
prompt_token_price=Decimal("5.00") / 1000000, # $5.00 per 1M tokens
|
||||
completion_token_price=Decimal("0") / 1000000, # No output pricing shown
|
||||
),
|
||||
"codex-mini-latest": TokenPricingModel(
|
||||
prompt_token_price=Decimal("1.50") / 1000000, # $1.50 per 1M tokens
|
||||
completion_token_price=Decimal("6.00") / 1000000, # $6.00 per 1M tokens
|
||||
),
|
||||
# Transcription models
|
||||
"gpt-4o-transcribe": TokenPricingModel(
|
||||
prompt_token_price=Decimal("2.50") / 1000000, # $2.50 per 1M tokens
|
||||
completion_token_price=Decimal("10.00") / 1000000, # $10.00 per 1M tokens
|
||||
),
|
||||
"gpt-4o-mini-transcribe": TokenPricingModel(
|
||||
prompt_token_price=Decimal("1.25") / 1000000, # $1.25 per 1M tokens
|
||||
completion_token_price=Decimal("5.00") / 1000000, # $5.00 per 1M tokens
|
||||
),
|
||||
# TTS models with token-based pricing
|
||||
"gpt-4o-mini-tts": TokenPricingModel(
|
||||
prompt_token_price=Decimal("0.60") / 1000000, # $0.60 per 1M tokens
|
||||
completion_token_price=Decimal("0")
|
||||
/ 1000000, # No completion tokens for TTS
|
||||
),
|
||||
},
|
||||
ServiceProviders.GROQ: {
|
||||
"llama-3.3-70b-versatile": TokenPricingModel(
|
||||
prompt_token_price=Decimal("0.00059") / 1000, # $0.00059 per 1K tokens
|
||||
completion_token_price=Decimal("0.00079") / 1000, # $0.00079 per 1K tokens
|
||||
),
|
||||
"deepseek-r1-distill-llama-70b": TokenPricingModel(
|
||||
prompt_token_price=Decimal("0.00059") / 1000, # Assuming similar pricing
|
||||
completion_token_price=Decimal("0.00079") / 1000,
|
||||
),
|
||||
},
|
||||
ServiceProviders.AZURE: {
|
||||
"gpt-4.1-mini": TokenPricingModel(
|
||||
prompt_token_price=Decimal("0.44") / 1000000, # $0.40 per 1M tokens
|
||||
completion_token_price=Decimal("8.80")
|
||||
/ 1000000, # $1.60 per 1M tokens if using data zone
|
||||
)
|
||||
},
|
||||
}
|
||||
|
|
@ -1,89 +0,0 @@
|
|||
"""
|
||||
Base pricing models for different service types.
|
||||
"""
|
||||
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
class CostType(Enum):
|
||||
LLM_TOKENS = "llm_tokens"
|
||||
TTS_CHARACTERS = "tts_characters"
|
||||
STT_SECONDS = "stt_seconds"
|
||||
|
||||
|
||||
class PricingModel:
|
||||
"""Base class for pricing models"""
|
||||
|
||||
def calculate_cost(self, usage: Any) -> Decimal:
|
||||
"""Calculate cost based on usage"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class TokenPricingModel(PricingModel):
|
||||
"""Pricing model for token-based services (LLM)"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prompt_token_price: Decimal,
|
||||
completion_token_price: Decimal,
|
||||
cache_read_discount: Decimal = Decimal("0.5"), # 50% discount for cache reads
|
||||
cache_creation_multiplier: Decimal = Decimal(
|
||||
"1.25"
|
||||
), # 25% premium for cache creation
|
||||
):
|
||||
self.prompt_token_price = prompt_token_price
|
||||
self.completion_token_price = completion_token_price
|
||||
self.cache_read_discount = cache_read_discount
|
||||
self.cache_creation_multiplier = cache_creation_multiplier
|
||||
|
||||
def calculate_cost(self, usage: Dict[str, int]) -> Decimal:
|
||||
"""Calculate cost for LLM token usage"""
|
||||
prompt_tokens = usage.get("prompt_tokens", 0)
|
||||
completion_tokens = usage.get("completion_tokens", 0)
|
||||
cache_read_tokens = usage.get("cache_read_input_tokens") or 0
|
||||
cache_creation_tokens = usage.get("cache_creation_input_tokens") or 0
|
||||
|
||||
# Base cost
|
||||
prompt_cost = Decimal(prompt_tokens) * self.prompt_token_price
|
||||
completion_cost = Decimal(completion_tokens) * self.completion_token_price
|
||||
|
||||
# Cache adjustments
|
||||
cache_read_savings = (
|
||||
Decimal(cache_read_tokens)
|
||||
* self.prompt_token_price
|
||||
* self.cache_read_discount
|
||||
)
|
||||
cache_creation_premium = (
|
||||
Decimal(cache_creation_tokens)
|
||||
* self.prompt_token_price
|
||||
* (self.cache_creation_multiplier - 1)
|
||||
)
|
||||
|
||||
total_cost = (
|
||||
prompt_cost + completion_cost - cache_read_savings + cache_creation_premium
|
||||
)
|
||||
return max(total_cost, Decimal("0")) # Ensure non-negative
|
||||
|
||||
|
||||
class CharacterPricingModel(PricingModel):
|
||||
"""Pricing model for character-based services (TTS)"""
|
||||
|
||||
def __init__(self, character_price: Decimal):
|
||||
self.character_price = character_price
|
||||
|
||||
def calculate_cost(self, character_count: int) -> Decimal:
|
||||
"""Calculate cost for TTS character usage"""
|
||||
return Decimal(character_count) * self.character_price
|
||||
|
||||
|
||||
class TimePricingModel(PricingModel):
|
||||
"""Pricing model for time-based services (STT)"""
|
||||
|
||||
def __init__(self, second_price: Decimal):
|
||||
self.second_price = second_price
|
||||
|
||||
def calculate_cost(self, seconds: float) -> Decimal:
|
||||
"""Calculate cost for STT time usage"""
|
||||
return Decimal(str(seconds)) * self.second_price
|
||||
|
|
@ -1,18 +0,0 @@
|
|||
"""
|
||||
Main pricing registry that combines all service type pricing models.
|
||||
"""
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from .embeddings import EMBEDDINGS_PRICING
|
||||
from .llm import LLM_PRICING
|
||||
from .stt import STT_PRICING
|
||||
from .tts import TTS_PRICING
|
||||
|
||||
# Combined pricing registry for all service types
|
||||
PRICING_REGISTRY: Dict = {
|
||||
"llm": LLM_PRICING,
|
||||
"tts": TTS_PRICING,
|
||||
"stt": STT_PRICING,
|
||||
"embeddings": EMBEDDINGS_PRICING,
|
||||
}
|
||||
|
|
@ -1,13 +0,0 @@
|
|||
"""Format workflow run usage for public API responses."""
|
||||
|
||||
|
||||
def format_public_usage_info(usage_info: dict | None) -> dict | None:
|
||||
if not usage_info:
|
||||
return None
|
||||
|
||||
return {
|
||||
"llm": usage_info.get("llm") or {},
|
||||
"tts": usage_info.get("tts") or {},
|
||||
"stt": usage_info.get("stt") or {},
|
||||
"call_duration_seconds": usage_info.get("call_duration_seconds"),
|
||||
}
|
||||
|
|
@ -1,26 +0,0 @@
|
|||
"""
|
||||
STT (Speech-to-Text) pricing models for different providers.
|
||||
|
||||
Prices are per second for STT services.
|
||||
"""
|
||||
|
||||
from decimal import Decimal
|
||||
from typing import Dict
|
||||
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
|
||||
from .models import TimePricingModel
|
||||
|
||||
# STT pricing registry
|
||||
STT_PRICING: Dict[str, Dict[str, TimePricingModel]] = {
|
||||
ServiceProviders.DEEPGRAM: {
|
||||
"nova-3-general": TimePricingModel(Decimal("0.0077") / 60),
|
||||
"nova-2": TimePricingModel(Decimal("0.0058") / 60),
|
||||
"default": TimePricingModel(Decimal("0.0077") / 60),
|
||||
},
|
||||
ServiceProviders.OPENAI: {
|
||||
"gpt-4o-transcribe": TimePricingModel(Decimal("0.015") / 60),
|
||||
"default": TimePricingModel(Decimal("0.015") / 60),
|
||||
},
|
||||
"default": {"default": TimePricingModel(Decimal("0.0077") / 60)},
|
||||
}
|
||||
|
|
@ -1,30 +0,0 @@
|
|||
"""
|
||||
TTS (Text-to-Speech) pricing models for different providers.
|
||||
|
||||
Prices are per character for TTS services.
|
||||
"""
|
||||
|
||||
from decimal import Decimal
|
||||
from typing import Dict
|
||||
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
|
||||
from .models import CharacterPricingModel
|
||||
|
||||
# TTS pricing registry
|
||||
TTS_PRICING: Dict[str, Dict[str, CharacterPricingModel]] = {
|
||||
ServiceProviders.OPENAI: {
|
||||
"gpt-4o-mini-tts": CharacterPricingModel(Decimal("0.6") / 1_00_00_000),
|
||||
"default": CharacterPricingModel(Decimal("0.6") / 1_00_00_000),
|
||||
},
|
||||
ServiceProviders.DEEPGRAM: {
|
||||
"aura-2": CharacterPricingModel(Decimal("0.030") / 1_000),
|
||||
"aura-1": CharacterPricingModel(Decimal("0.015") / 1_000),
|
||||
"default": CharacterPricingModel(Decimal("0.030") / 1_000),
|
||||
},
|
||||
ServiceProviders.ELEVENLABS: {
|
||||
# 6400 usd per 250*1e6 characters
|
||||
"default": CharacterPricingModel(Decimal("0.0256") / 1_000)
|
||||
},
|
||||
"default": {"default": CharacterPricingModel(Decimal("0.030") / 1_000)},
|
||||
}
|
||||
|
|
@ -1,230 +0,0 @@
|
|||
from decimal import Decimal
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from api.db import db_client
|
||||
from api.enums import WorkflowRunMode
|
||||
from api.services.pricing.cost_calculator import cost_calculator
|
||||
from api.services.telephony.factory import get_telephony_provider_for_run
|
||||
|
||||
|
||||
async def _fetch_telephony_cost(workflow_run) -> dict | None:
|
||||
"""Fetch telephony call cost. Returns a dict with cost_usd and provider_name, or None."""
|
||||
if (
|
||||
workflow_run.mode
|
||||
not in [WorkflowRunMode.TWILIO.value, WorkflowRunMode.VONAGE.value]
|
||||
or not workflow_run.cost_info
|
||||
):
|
||||
return None
|
||||
|
||||
call_id = workflow_run.cost_info.get("call_id")
|
||||
if not call_id:
|
||||
logger.warning(f"call_id not found in cost_info")
|
||||
return None
|
||||
|
||||
provider_name = workflow_run.mode.lower() if workflow_run.mode else ""
|
||||
|
||||
workflow = await db_client.get_workflow_by_id(workflow_run.workflow_id)
|
||||
if not workflow:
|
||||
logger.warning("Workflow not found for workflow run")
|
||||
raise Exception("Workflow not found")
|
||||
|
||||
provider = await get_telephony_provider_for_run(
|
||||
workflow_run, workflow.organization_id
|
||||
)
|
||||
call_cost_info = await provider.get_call_cost(call_id)
|
||||
|
||||
if call_cost_info.get("status") == "error":
|
||||
logger.error(
|
||||
f"Failed to fetch {provider_name} call cost: {call_cost_info.get('error')}"
|
||||
)
|
||||
return None
|
||||
|
||||
cost_usd = call_cost_info.get("cost_usd", 0.0)
|
||||
logger.info(
|
||||
f"{provider_name.title()} call cost: ${cost_usd:.6f} USD for call {call_id}"
|
||||
)
|
||||
return {"cost_usd": cost_usd, "provider_name": provider_name}
|
||||
|
||||
|
||||
async def _update_organization_usage(
|
||||
org, dograh_tokens: float, duration_seconds: float, charge_usd: float | None
|
||||
) -> None:
|
||||
"""Update organization usage after a workflow run."""
|
||||
org_id = org.id
|
||||
await db_client.update_usage_after_run(
|
||||
org_id, dograh_tokens, duration_seconds, charge_usd
|
||||
)
|
||||
if charge_usd is not None:
|
||||
logger.info(
|
||||
f"Updated organization usage with ${charge_usd:.2f} USD ({dograh_tokens} Dograh Tokens) and {duration_seconds}s duration for org {org_id}"
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"Updated organization usage with {dograh_tokens} Dograh Tokens and {duration_seconds}s duration for org {org_id}"
|
||||
)
|
||||
|
||||
|
||||
async def _get_pricing_organization(workflow_run):
|
||||
workflow = getattr(workflow_run, "workflow", None)
|
||||
organization_id = getattr(workflow, "organization_id", None)
|
||||
if organization_id is None and workflow and workflow.user:
|
||||
organization_id = workflow.user.selected_organization_id
|
||||
if organization_id is None:
|
||||
return None
|
||||
return await db_client.get_organization_by_id(organization_id)
|
||||
|
||||
|
||||
async def _build_usage_cost_snapshot(
|
||||
usage_info: dict | None,
|
||||
*,
|
||||
workflow_run=None,
|
||||
include_telephony_cost: bool = False,
|
||||
organization=None,
|
||||
calculated_at: str | None = None,
|
||||
) -> dict | None:
|
||||
if not usage_info:
|
||||
logger.warning("No usage info available for workflow run")
|
||||
return None
|
||||
|
||||
cost_breakdown = cost_calculator.calculate_total_cost(usage_info)
|
||||
|
||||
if include_telephony_cost and workflow_run is not None:
|
||||
try:
|
||||
telephony_cost = await _fetch_telephony_cost(workflow_run)
|
||||
if telephony_cost:
|
||||
telephony_cost_usd = telephony_cost["cost_usd"]
|
||||
provider_name = telephony_cost["provider_name"]
|
||||
cost_breakdown["telephony_call"] = telephony_cost_usd
|
||||
cost_breakdown[f"{provider_name}_call"] = telephony_cost_usd
|
||||
cost_breakdown["total"] = (
|
||||
float(cost_breakdown["total"]) + telephony_cost_usd
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to fetch telephony call cost: {e}")
|
||||
# Don't fail the whole cost calculation if telephony API fails
|
||||
|
||||
total_cost_usd = Decimal(str(cost_breakdown["total"]))
|
||||
dograh_tokens = float(total_cost_usd * Decimal("100"))
|
||||
|
||||
if organization is None and workflow_run is not None:
|
||||
organization = await _get_pricing_organization(workflow_run)
|
||||
|
||||
charge_usd = None
|
||||
if organization and organization.price_per_second_usd:
|
||||
duration_seconds = usage_info.get("call_duration_seconds", 0)
|
||||
charge_usd = float(
|
||||
Decimal(str(duration_seconds))
|
||||
* Decimal(str(organization.price_per_second_usd))
|
||||
)
|
||||
|
||||
cost_info = {
|
||||
"cost_breakdown": cost_breakdown,
|
||||
"total_cost_usd": float(total_cost_usd),
|
||||
"dograh_token_usage": dograh_tokens,
|
||||
"calculated_at": calculated_at
|
||||
or (workflow_run.created_at.isoformat() if workflow_run is not None else None),
|
||||
"call_duration_seconds": usage_info.get("call_duration_seconds", 0),
|
||||
}
|
||||
|
||||
if charge_usd is not None:
|
||||
cost_info["charge_usd"] = charge_usd
|
||||
cost_info["price_per_second_usd"] = organization.price_per_second_usd
|
||||
|
||||
return cost_info
|
||||
|
||||
|
||||
async def build_workflow_run_cost_info(workflow_run) -> dict | None:
|
||||
cost_info = await _build_usage_cost_snapshot(
|
||||
workflow_run.usage_info,
|
||||
workflow_run=workflow_run,
|
||||
include_telephony_cost=True,
|
||||
calculated_at=workflow_run.created_at.isoformat(),
|
||||
)
|
||||
if cost_info is None:
|
||||
return None
|
||||
return {
|
||||
**(workflow_run.cost_info or {}),
|
||||
**cost_info,
|
||||
}
|
||||
|
||||
|
||||
async def save_workflow_run_cost_info(
|
||||
workflow_run_id: int, cost_info: dict | None
|
||||
) -> None:
|
||||
if cost_info is None:
|
||||
return
|
||||
await db_client.update_workflow_run(run_id=workflow_run_id, cost_info=cost_info)
|
||||
|
||||
|
||||
async def apply_workflow_run_usage_to_organization(
|
||||
workflow_run, cost_info: dict | None
|
||||
) -> None:
|
||||
if cost_info is None:
|
||||
return
|
||||
|
||||
org = await _get_pricing_organization(workflow_run)
|
||||
if not org:
|
||||
return
|
||||
|
||||
await _update_organization_usage(
|
||||
org,
|
||||
float(cost_info.get("dograh_token_usage") or 0),
|
||||
float(cost_info.get("call_duration_seconds") or 0),
|
||||
cost_info.get("charge_usd"),
|
||||
)
|
||||
|
||||
|
||||
async def apply_usage_delta_to_organization(
|
||||
workflow_run, usage_info: dict | None
|
||||
) -> dict | None:
|
||||
org = await _get_pricing_organization(workflow_run)
|
||||
if not org:
|
||||
return None
|
||||
|
||||
cost_info = await _build_usage_cost_snapshot(usage_info, organization=org)
|
||||
if cost_info is None:
|
||||
return None
|
||||
|
||||
await _update_organization_usage(
|
||||
org,
|
||||
float(cost_info.get("dograh_token_usage") or 0),
|
||||
float(cost_info.get("call_duration_seconds") or 0),
|
||||
cost_info.get("charge_usd"),
|
||||
)
|
||||
return cost_info
|
||||
|
||||
|
||||
async def calculate_workflow_run_cost(workflow_run_id: int):
|
||||
logger.debug("Calculating cost for workflow run")
|
||||
|
||||
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
|
||||
if not workflow_run:
|
||||
logger.warning("Workflow run not found")
|
||||
return
|
||||
|
||||
try:
|
||||
cost_info = await build_workflow_run_cost_info(workflow_run)
|
||||
if cost_info is None:
|
||||
return
|
||||
|
||||
await save_workflow_run_cost_info(workflow_run_id, cost_info)
|
||||
|
||||
try:
|
||||
await apply_workflow_run_usage_to_organization(workflow_run, cost_info)
|
||||
except Exception as e:
|
||||
org = await _get_pricing_organization(workflow_run)
|
||||
if org:
|
||||
logger.error(
|
||||
f"Failed to update organization usage for org {org.id}: {e}"
|
||||
)
|
||||
else:
|
||||
logger.error(f"Failed to update organization usage: {e}")
|
||||
# Don't fail the whole cost calculation if usage update fails
|
||||
|
||||
logger.info(
|
||||
f"Calculated cost for workflow run: ${cost_info['total_cost_usd']:.6f} USD ({cost_info['dograh_token_usage']} Dograh Tokens)"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating cost for workflow run: {e}")
|
||||
raise
|
||||
|
|
@ -53,7 +53,7 @@ def build_run_report_csv(runs: List[Any]) -> io.StringIO:
|
|||
for run in runs:
|
||||
initial = run.initial_context or {}
|
||||
gathered = run.gathered_context or {}
|
||||
cost = run.cost_info or {}
|
||||
usage = run.usage_info or {}
|
||||
|
||||
call_tags = gathered.get("call_tags", [])
|
||||
if isinstance(call_tags, list):
|
||||
|
|
@ -67,7 +67,7 @@ def build_run_report_csv(runs: List[Any]) -> io.StringIO:
|
|||
run.created_at.isoformat() if run.created_at else "",
|
||||
initial.get("phone_number", ""),
|
||||
gathered.get("mapped_call_disposition", ""),
|
||||
cost.get("call_duration_seconds", ""),
|
||||
usage.get("call_duration_seconds", ""),
|
||||
]
|
||||
|
||||
extracted = gathered.get("extracted_variables", {})
|
||||
|
|
|
|||
|
|
@ -66,34 +66,6 @@ async def handle_vonage_events(
|
|||
logger.error(f"[run {workflow_run_id}] Workflow run not found")
|
||||
return {"status": "error", "message": "Workflow run not found"}
|
||||
|
||||
# For a completed call that includes cost info, capture it immediately
|
||||
if event_data.get("status") == "completed":
|
||||
# Vonage sometimes includes price info in the webhook
|
||||
if "price" in event_data or "rate" in event_data:
|
||||
try:
|
||||
if workflow_run.cost_info:
|
||||
# Store immediate cost info if available
|
||||
cost_info = workflow_run.cost_info.copy()
|
||||
if "price" in event_data:
|
||||
cost_info["vonage_webhook_price"] = float(event_data["price"])
|
||||
if "rate" in event_data:
|
||||
cost_info["vonage_webhook_rate"] = float(event_data["rate"])
|
||||
if "duration" in event_data:
|
||||
cost_info["vonage_webhook_duration"] = int(
|
||||
event_data["duration"]
|
||||
)
|
||||
|
||||
await db_client.update_workflow_run(
|
||||
run_id=workflow_run_id, cost_info=cost_info
|
||||
)
|
||||
logger.info(
|
||||
f"[run {workflow_run_id}] Captured Vonage cost info from webhook"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[run {workflow_run_id}] Failed to capture Vonage cost from webhook: {e}"
|
||||
)
|
||||
|
||||
# Get workflow and provider
|
||||
workflow = await db_client.get_workflow_by_id(workflow_run.workflow_id)
|
||||
if not workflow:
|
||||
|
|
|
|||
|
|
@ -35,6 +35,7 @@ import asyncio
|
|||
|
||||
from loguru import logger
|
||||
|
||||
from api.services.managed_model_services import MPS_CORRELATION_ID_CONTEXT_KEY
|
||||
from api.services.workflow import pipecat_engine_callbacks as engine_callbacks
|
||||
from api.services.workflow.mcp_tool_session import McpToolSession
|
||||
from api.services.workflow.pipecat_engine_context_composer import (
|
||||
|
|
@ -382,6 +383,9 @@ class PipecatEngine:
|
|||
embeddings_provider=self._embeddings_provider,
|
||||
embeddings_endpoint=self._embeddings_endpoint,
|
||||
embeddings_api_version=self._embeddings_api_version,
|
||||
correlation_id=self._call_context_vars.get(
|
||||
MPS_CORRELATION_ID_CONTEXT_KEY
|
||||
),
|
||||
tracing_context=self._get_otel_context(),
|
||||
)
|
||||
|
||||
|
|
|
|||
41
api/services/workflow/run_usage_response.py
Normal file
41
api/services/workflow/run_usage_response.py
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
"""Format workflow run usage for public API responses."""
|
||||
|
||||
|
||||
def format_public_usage_info(usage_info: dict | None) -> dict | None:
|
||||
if not usage_info:
|
||||
return None
|
||||
|
||||
return {
|
||||
"llm": usage_info.get("llm") or {},
|
||||
"tts": usage_info.get("tts") or {},
|
||||
"stt": usage_info.get("stt") or {},
|
||||
"call_duration_seconds": usage_info.get("call_duration_seconds"),
|
||||
}
|
||||
|
||||
|
||||
def format_public_cost_info(
|
||||
cost_info: dict | None, usage_info: dict | None
|
||||
) -> dict | None:
|
||||
"""Return the legacy response shape without doing local cost accounting."""
|
||||
duration = None
|
||||
if usage_info and usage_info.get("call_duration_seconds") is not None:
|
||||
duration = int(round(usage_info.get("call_duration_seconds") or 0))
|
||||
elif cost_info and cost_info.get("call_duration_seconds") is not None:
|
||||
duration = int(round(cost_info.get("call_duration_seconds") or 0))
|
||||
|
||||
dograh_token_usage = 0
|
||||
if cost_info:
|
||||
if "dograh_token_usage" in cost_info:
|
||||
dograh_token_usage = cost_info.get("dograh_token_usage") or 0
|
||||
elif "total_cost_usd" in cost_info:
|
||||
dograh_token_usage = round(
|
||||
float(cost_info.get("total_cost_usd", 0)) * 100, 2
|
||||
)
|
||||
|
||||
if duration is None and dograh_token_usage == 0:
|
||||
return None
|
||||
|
||||
return {
|
||||
"dograh_token_usage": dograh_token_usage,
|
||||
"call_duration_seconds": duration,
|
||||
}
|
||||
|
|
@ -421,7 +421,19 @@ async def execute_text_chat_pending_turn(
|
|||
if user_config.llm is None:
|
||||
raise ValueError("Text chat requires an LLM configuration")
|
||||
|
||||
llm = create_llm_service(user_config)
|
||||
from api.services.managed_model_services import (
|
||||
MPS_CORRELATION_ID_CONTEXT_KEY,
|
||||
ensure_mps_correlation_id,
|
||||
)
|
||||
|
||||
base_initial_context = dict(workflow_run.initial_context or {})
|
||||
mps_correlation_id = await ensure_mps_correlation_id(
|
||||
ai_model_config=user_config,
|
||||
workflow_run_id=workflow_run_id,
|
||||
initial_context=base_initial_context,
|
||||
)
|
||||
|
||||
llm = create_llm_service(user_config, correlation_id=mps_correlation_id)
|
||||
inference_llm = llm
|
||||
|
||||
runtime_configuration = {
|
||||
|
|
@ -429,9 +441,15 @@ async def execute_text_chat_pending_turn(
|
|||
"llm_model": user_config.llm.model,
|
||||
}
|
||||
initial_context = {
|
||||
**(workflow_run.initial_context or {}),
|
||||
**base_initial_context,
|
||||
"runtime_configuration": runtime_configuration,
|
||||
}
|
||||
if mps_correlation_id:
|
||||
initial_context[MPS_CORRELATION_ID_CONTEXT_KEY] = mps_correlation_id
|
||||
await db_client.update_workflow_run(
|
||||
workflow_run_id,
|
||||
initial_context=initial_context,
|
||||
)
|
||||
|
||||
workflow_graph = WorkflowGraph(
|
||||
ReactFlowDTO.model_validate(run_definition.workflow_json)
|
||||
|
|
|
|||
|
|
@ -4,17 +4,11 @@ from datetime import UTC, datetime
|
|||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from api.db import db_client
|
||||
from api.db.models import WorkflowRunTextSessionModel
|
||||
from api.db.workflow_run_text_session_client import (
|
||||
WorkflowRunTextSessionRevisionConflictError,
|
||||
)
|
||||
from api.services.pricing.workflow_run_cost import (
|
||||
apply_usage_delta_to_organization,
|
||||
build_workflow_run_cost_info,
|
||||
)
|
||||
from api.services.workflow.text_chat_logs import (
|
||||
build_text_chat_realtime_feedback_events,
|
||||
)
|
||||
|
|
@ -261,20 +255,6 @@ async def execute_pending_text_chat_turn(
|
|||
state=execution.state,
|
||||
is_completed=execution.is_completed,
|
||||
)
|
||||
workflow_run = await db_client.get_workflow_run_by_id(run_id)
|
||||
if workflow_run:
|
||||
try:
|
||||
# Apply the per-turn delta so org usage tracks cumulative run cost
|
||||
# without replaying the full session totals on every turn.
|
||||
await apply_usage_delta_to_organization(workflow_run, execution.usage)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to update organization usage for text chat run {run_id}: {e}"
|
||||
)
|
||||
|
||||
cost_info = await build_workflow_run_cost_info(workflow_run)
|
||||
if cost_info is not None:
|
||||
await db_client.update_workflow_run(run_id, cost_info=cost_info)
|
||||
|
||||
return await _reload_text_chat_session(run_id)
|
||||
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ async def retrieve_from_knowledge_base(
|
|||
embeddings_provider: Optional[str] = None,
|
||||
embeddings_endpoint: Optional[str] = None,
|
||||
embeddings_api_version: Optional[str] = None,
|
||||
correlation_id: Optional[str] = None,
|
||||
tracing_context=None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Retrieve relevant information from the knowledge base using vector similarity search.
|
||||
|
|
@ -75,6 +76,7 @@ async def retrieve_from_knowledge_base(
|
|||
embeddings_provider,
|
||||
embeddings_endpoint,
|
||||
embeddings_api_version,
|
||||
correlation_id,
|
||||
)
|
||||
|
||||
# Create span with parent context
|
||||
|
|
@ -115,6 +117,7 @@ async def retrieve_from_knowledge_base(
|
|||
embeddings_provider,
|
||||
embeddings_endpoint,
|
||||
embeddings_api_version,
|
||||
correlation_id,
|
||||
)
|
||||
|
||||
# Add result metadata to span
|
||||
|
|
@ -192,6 +195,7 @@ async def retrieve_from_knowledge_base(
|
|||
embeddings_provider,
|
||||
embeddings_endpoint,
|
||||
embeddings_api_version,
|
||||
correlation_id,
|
||||
)
|
||||
else:
|
||||
# Tracing is disabled - perform retrieval without tracing
|
||||
|
|
@ -206,6 +210,7 @@ async def retrieve_from_knowledge_base(
|
|||
embeddings_provider,
|
||||
embeddings_endpoint,
|
||||
embeddings_api_version,
|
||||
correlation_id,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -220,6 +225,7 @@ async def _perform_retrieval(
|
|||
embeddings_provider: Optional[str] = None,
|
||||
embeddings_endpoint: Optional[str] = None,
|
||||
embeddings_api_version: Optional[str] = None,
|
||||
correlation_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Internal function to perform the actual retrieval operation.
|
||||
|
||||
|
|
@ -272,11 +278,20 @@ async def _perform_retrieval(
|
|||
api_version=embeddings_api_version or "2024-02-15-preview",
|
||||
)
|
||||
else:
|
||||
default_headers = None
|
||||
if (
|
||||
embeddings_provider == ServiceProviders.DOGRAH.value
|
||||
and correlation_id
|
||||
):
|
||||
default_headers = {
|
||||
"X-Dograh-Correlation-Id": correlation_id,
|
||||
}
|
||||
embedding_service = OpenAIEmbeddingService(
|
||||
db_client=db_client,
|
||||
api_key=embeddings_api_key,
|
||||
model_id=embeddings_model or "text-embedding-3-small",
|
||||
base_url=embeddings_base_url,
|
||||
default_headers=default_headers,
|
||||
)
|
||||
|
||||
results = await embedding_service.search_similar_chunks(
|
||||
|
|
|
|||
111
api/services/workflow_run_billing.py
Normal file
111
api/services/workflow_run_billing.py
Normal file
|
|
@ -0,0 +1,111 @@
|
|||
"""Workflow-run billing hooks.
|
||||
|
||||
Dograh does not rate or deduct credits locally. MPS owns credit accounting.
|
||||
For hosted deployments, Dograh reports completed platform usage to MPS.
|
||||
When a server-minted MPS correlation id exists, MPS uses model-service usage
|
||||
as the canonical duration. Otherwise Dograh reports the completed run duration.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from api.constants import DEPLOYMENT_MODE
|
||||
from api.db import db_client
|
||||
from api.services.managed_model_services import get_mps_correlation_id
|
||||
from api.services.mps_service_key_client import mps_service_key_client
|
||||
|
||||
|
||||
def _workflow_run_organization_id(workflow_run) -> int | None:
|
||||
workflow = getattr(workflow_run, "workflow", None)
|
||||
return getattr(workflow, "organization_id", None)
|
||||
|
||||
|
||||
def _duration_seconds_from_usage_info(workflow_run) -> float | None:
|
||||
usage_info: dict[str, Any] = getattr(workflow_run, "usage_info", None) or {}
|
||||
duration = usage_info.get("call_duration_seconds")
|
||||
try:
|
||||
duration_seconds = float(duration)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
return duration_seconds if duration_seconds > 0 else None
|
||||
|
||||
|
||||
async def _organization_uses_mps_billing_v2(organization_id: int) -> bool:
|
||||
account = await mps_service_key_client.get_billing_account_status(
|
||||
organization_id=organization_id
|
||||
)
|
||||
return bool(account and account.get("billing_mode") == "v2")
|
||||
|
||||
|
||||
async def report_workflow_run_platform_usage(workflow_run) -> None:
|
||||
"""Report hosted platform usage for a completed workflow run to MPS."""
|
||||
if DEPLOYMENT_MODE == "oss":
|
||||
return
|
||||
|
||||
if not getattr(workflow_run, "is_completed", False):
|
||||
return
|
||||
|
||||
organization_id = _workflow_run_organization_id(workflow_run)
|
||||
if organization_id is None:
|
||||
logger.warning(
|
||||
"Skipping platform usage report for workflow run {}: no organization_id",
|
||||
workflow_run.id,
|
||||
)
|
||||
return
|
||||
|
||||
correlation_id = get_mps_correlation_id(
|
||||
getattr(workflow_run, "initial_context", None)
|
||||
)
|
||||
duration_seconds = (
|
||||
None if correlation_id else _duration_seconds_from_usage_info(workflow_run)
|
||||
)
|
||||
if not correlation_id and duration_seconds is None:
|
||||
logger.warning(
|
||||
"Skipping platform usage report for workflow run {}: no billable duration",
|
||||
workflow_run.id,
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
if not await _organization_uses_mps_billing_v2(organization_id):
|
||||
return
|
||||
|
||||
result = await mps_service_key_client.report_platform_usage(
|
||||
organization_id=organization_id,
|
||||
correlation_id=correlation_id,
|
||||
duration_seconds=duration_seconds,
|
||||
workflow_run_id=workflow_run.id,
|
||||
metadata={
|
||||
"source": "workflow_run_completion",
|
||||
"workflow_id": getattr(workflow_run, "workflow_id", None),
|
||||
"duration_source": (
|
||||
"mps_correlation" if correlation_id else "dograh_usage_info"
|
||||
),
|
||||
},
|
||||
)
|
||||
logger.info(
|
||||
"Reported platform usage for workflow run {} to MPS: {}",
|
||||
workflow_run.id,
|
||||
result,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to report platform usage for workflow run {}: {}",
|
||||
workflow_run.id,
|
||||
e,
|
||||
)
|
||||
|
||||
|
||||
async def report_completed_workflow_run_platform_usage(workflow_run_id: int) -> None:
|
||||
"""Load a completed workflow run and report platform usage to MPS."""
|
||||
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
|
||||
if not workflow_run:
|
||||
logger.warning(
|
||||
"Skipping platform usage report: workflow run {} not found",
|
||||
workflow_run_id,
|
||||
)
|
||||
return
|
||||
|
||||
await report_workflow_run_platform_usage(workflow_run)
|
||||
|
|
@ -45,10 +45,8 @@ from api.tasks.campaign_tasks import (
|
|||
)
|
||||
from api.tasks.knowledge_base_processing import process_knowledge_base_document
|
||||
from api.tasks.run_integrations import run_integrations_post_workflow_run
|
||||
from api.tasks.s3_upload import (
|
||||
process_workflow_completion,
|
||||
upload_voicemail_audio_to_s3,
|
||||
)
|
||||
from api.tasks.s3_upload import upload_voicemail_audio_to_s3
|
||||
from api.tasks.workflow_completion import process_workflow_completion
|
||||
|
||||
|
||||
class WorkerSettings:
|
||||
|
|
|
|||
|
|
@ -166,18 +166,22 @@ async def process_knowledge_base_document(
|
|||
user_id=document.created_by,
|
||||
organization_id=document.organization_id,
|
||||
)
|
||||
user_config = resolved_config.effective
|
||||
if user_config.embeddings:
|
||||
embeddings_provider = getattr(user_config.embeddings, "provider", None)
|
||||
embeddings_api_key = user_config.embeddings.api_key
|
||||
embeddings_model = user_config.embeddings.model
|
||||
effective_config = resolved_config.effective
|
||||
if effective_config.embeddings:
|
||||
embeddings_provider = getattr(
|
||||
effective_config.embeddings, "provider", None
|
||||
)
|
||||
embeddings_api_key = effective_config.embeddings.api_key
|
||||
embeddings_model = effective_config.embeddings.model
|
||||
embeddings_base_url = apply_managed_embeddings_base_url(
|
||||
provider=embeddings_provider,
|
||||
base_url=getattr(user_config.embeddings, "base_url", None),
|
||||
base_url=getattr(effective_config.embeddings, "base_url", None),
|
||||
)
|
||||
embeddings_endpoint = getattr(
|
||||
effective_config.embeddings, "endpoint", None
|
||||
)
|
||||
embeddings_endpoint = getattr(user_config.embeddings, "endpoint", None)
|
||||
embeddings_api_version = getattr(
|
||||
user_config.embeddings, "api_version", None
|
||||
effective_config.embeddings, "api_version", None
|
||||
)
|
||||
logger.info(
|
||||
f"Using user embeddings config: provider={embeddings_provider}, "
|
||||
|
|
|
|||
|
|
@ -1,13 +1,9 @@
|
|||
import os
|
||||
from typing import Optional
|
||||
|
||||
from loguru import logger
|
||||
from pipecat.utils.run_context import set_current_run_id
|
||||
|
||||
from api.db import db_client
|
||||
from api.services.pricing.workflow_run_cost import calculate_workflow_run_cost
|
||||
from api.services.storage import get_current_storage_backend, storage_fs
|
||||
from api.tasks.run_integrations import run_integrations_post_workflow_run
|
||||
from api.services.storage import storage_fs
|
||||
|
||||
|
||||
async def upload_voicemail_audio_to_s3(
|
||||
|
|
@ -69,110 +65,3 @@ async def upload_voicemail_audio_to_s3(
|
|||
logger.warning(
|
||||
f"Failed to clean up temp voicemail audio file {temp_file_path}: {e}"
|
||||
)
|
||||
|
||||
|
||||
async def process_workflow_completion(
|
||||
_ctx,
|
||||
workflow_run_id: int,
|
||||
audio_temp_path: Optional[str] = None,
|
||||
transcript_temp_path: Optional[str] = None,
|
||||
):
|
||||
"""Process workflow completion: upload artifacts and run integrations.
|
||||
|
||||
This task combines audio upload, transcript upload, and webhook integrations
|
||||
into a single sequential task to ensure integrations run after uploads complete.
|
||||
|
||||
Args:
|
||||
_ctx: ARQ context (unused)
|
||||
workflow_run_id: The workflow run ID
|
||||
audio_temp_path: Optional path to temp audio file
|
||||
transcript_temp_path: Optional path to temp transcript file
|
||||
"""
|
||||
run_id = str(workflow_run_id)
|
||||
set_current_run_id(run_id)
|
||||
|
||||
logger.info(f"Processing workflow completion for run {workflow_run_id}")
|
||||
|
||||
storage_backend = get_current_storage_backend()
|
||||
|
||||
# Step 1: Upload audio if provided
|
||||
if audio_temp_path:
|
||||
try:
|
||||
if os.path.exists(audio_temp_path):
|
||||
file_size = os.path.getsize(audio_temp_path)
|
||||
logger.debug(f"Audio file size: {file_size} bytes")
|
||||
|
||||
recording_url = f"recordings/{workflow_run_id}.wav"
|
||||
logger.info(
|
||||
f"Uploading audio to {storage_backend.name} - workflow_run_id: {workflow_run_id}"
|
||||
)
|
||||
|
||||
await storage_fs.aupload_file(audio_temp_path, recording_url)
|
||||
await db_client.update_workflow_run(
|
||||
run_id=workflow_run_id,
|
||||
recording_url=recording_url,
|
||||
storage_backend=storage_backend.value,
|
||||
)
|
||||
logger.info(f"Successfully uploaded audio: {recording_url}")
|
||||
else:
|
||||
logger.warning(f"Audio temp file not found: {audio_temp_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error uploading audio for workflow {workflow_run_id}: {e}")
|
||||
finally:
|
||||
if audio_temp_path and os.path.exists(audio_temp_path):
|
||||
try:
|
||||
os.remove(audio_temp_path)
|
||||
logger.debug(f"Cleaned up temp audio file: {audio_temp_path}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to clean up temp audio file: {e}")
|
||||
|
||||
# Step 2: Upload transcript if provided
|
||||
if transcript_temp_path:
|
||||
try:
|
||||
if os.path.exists(transcript_temp_path):
|
||||
file_size = os.path.getsize(transcript_temp_path)
|
||||
logger.debug(f"Transcript file size: {file_size} bytes")
|
||||
|
||||
transcript_url = f"transcripts/{workflow_run_id}.txt"
|
||||
logger.info(
|
||||
f"Uploading transcript to {storage_backend.name} - workflow_run_id: {workflow_run_id}"
|
||||
)
|
||||
|
||||
await storage_fs.aupload_file(transcript_temp_path, transcript_url)
|
||||
await db_client.update_workflow_run(
|
||||
run_id=workflow_run_id,
|
||||
transcript_url=transcript_url,
|
||||
storage_backend=storage_backend.value,
|
||||
)
|
||||
logger.info(f"Successfully uploaded transcript: {transcript_url}")
|
||||
else:
|
||||
logger.warning(
|
||||
f"Transcript temp file not found: {transcript_temp_path}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error uploading transcript for workflow {workflow_run_id}: {e}"
|
||||
)
|
||||
finally:
|
||||
if transcript_temp_path and os.path.exists(transcript_temp_path):
|
||||
try:
|
||||
os.remove(transcript_temp_path)
|
||||
logger.debug(
|
||||
f"Cleaned up temp transcript file: {transcript_temp_path}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to clean up temp transcript file: {e}")
|
||||
|
||||
# Step 3: Run integrations including QA analysis (after uploads are complete)
|
||||
try:
|
||||
await run_integrations_post_workflow_run(_ctx, workflow_run_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Error running integrations for workflow {workflow_run_id}: {e}")
|
||||
|
||||
# Step 4: Calculate cost after integrations (so QA token usage is included)
|
||||
try:
|
||||
await calculate_workflow_run_cost(workflow_run_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating cost for workflow {workflow_run_id}: {e}")
|
||||
|
||||
logger.info(f"Completed workflow completion processing for run {workflow_run_id}")
|
||||
|
|
|
|||
121
api/tasks/workflow_completion.py
Normal file
121
api/tasks/workflow_completion.py
Normal file
|
|
@ -0,0 +1,121 @@
|
|||
import os
|
||||
from typing import Optional
|
||||
|
||||
from loguru import logger
|
||||
from pipecat.utils.run_context import set_current_run_id
|
||||
|
||||
from api.db import db_client
|
||||
from api.services.storage import get_current_storage_backend, storage_fs
|
||||
from api.services.workflow_run_billing import (
|
||||
report_completed_workflow_run_platform_usage,
|
||||
)
|
||||
from api.tasks.run_integrations import run_integrations_post_workflow_run
|
||||
|
||||
|
||||
async def process_workflow_completion(
|
||||
_ctx,
|
||||
workflow_run_id: int,
|
||||
audio_temp_path: Optional[str] = None,
|
||||
transcript_temp_path: Optional[str] = None,
|
||||
):
|
||||
"""Process workflow completion: upload artifacts and run integrations.
|
||||
|
||||
This task combines audio upload, transcript upload, and webhook integrations
|
||||
into a single sequential task to ensure integrations run after uploads complete.
|
||||
|
||||
Args:
|
||||
_ctx: ARQ context (unused)
|
||||
workflow_run_id: The workflow run ID
|
||||
audio_temp_path: Optional path to temp audio file
|
||||
transcript_temp_path: Optional path to temp transcript file
|
||||
"""
|
||||
run_id = str(workflow_run_id)
|
||||
set_current_run_id(run_id)
|
||||
|
||||
logger.info(f"Processing workflow completion for run {workflow_run_id}")
|
||||
|
||||
storage_backend = get_current_storage_backend()
|
||||
|
||||
# Step 1: Upload audio if provided
|
||||
if audio_temp_path:
|
||||
try:
|
||||
if os.path.exists(audio_temp_path):
|
||||
file_size = os.path.getsize(audio_temp_path)
|
||||
logger.debug(f"Audio file size: {file_size} bytes")
|
||||
|
||||
recording_url = f"recordings/{workflow_run_id}.wav"
|
||||
logger.info(
|
||||
f"Uploading audio to {storage_backend.name} - workflow_run_id: {workflow_run_id}"
|
||||
)
|
||||
|
||||
await storage_fs.aupload_file(audio_temp_path, recording_url)
|
||||
await db_client.update_workflow_run(
|
||||
run_id=workflow_run_id,
|
||||
recording_url=recording_url,
|
||||
storage_backend=storage_backend.value,
|
||||
)
|
||||
logger.info(f"Successfully uploaded audio: {recording_url}")
|
||||
else:
|
||||
logger.warning(f"Audio temp file not found: {audio_temp_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error uploading audio for workflow {workflow_run_id}: {e}")
|
||||
finally:
|
||||
if audio_temp_path and os.path.exists(audio_temp_path):
|
||||
try:
|
||||
os.remove(audio_temp_path)
|
||||
logger.debug(f"Cleaned up temp audio file: {audio_temp_path}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to clean up temp audio file: {e}")
|
||||
|
||||
# Step 2: Upload transcript if provided
|
||||
if transcript_temp_path:
|
||||
try:
|
||||
if os.path.exists(transcript_temp_path):
|
||||
file_size = os.path.getsize(transcript_temp_path)
|
||||
logger.debug(f"Transcript file size: {file_size} bytes")
|
||||
|
||||
transcript_url = f"transcripts/{workflow_run_id}.txt"
|
||||
logger.info(
|
||||
f"Uploading transcript to {storage_backend.name} - workflow_run_id: {workflow_run_id}"
|
||||
)
|
||||
|
||||
await storage_fs.aupload_file(transcript_temp_path, transcript_url)
|
||||
await db_client.update_workflow_run(
|
||||
run_id=workflow_run_id,
|
||||
transcript_url=transcript_url,
|
||||
storage_backend=storage_backend.value,
|
||||
)
|
||||
logger.info(f"Successfully uploaded transcript: {transcript_url}")
|
||||
else:
|
||||
logger.warning(
|
||||
f"Transcript temp file not found: {transcript_temp_path}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error uploading transcript for workflow {workflow_run_id}: {e}"
|
||||
)
|
||||
finally:
|
||||
if transcript_temp_path and os.path.exists(transcript_temp_path):
|
||||
try:
|
||||
os.remove(transcript_temp_path)
|
||||
logger.debug(
|
||||
f"Cleaned up temp transcript file: {transcript_temp_path}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to clean up temp transcript file: {e}")
|
||||
|
||||
# Step 3: Run integrations including QA analysis (after uploads are complete)
|
||||
try:
|
||||
await run_integrations_post_workflow_run(_ctx, workflow_run_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Error running integrations for workflow {workflow_run_id}: {e}")
|
||||
|
||||
# Step 4: Notify MPS after completion. MPS owns credit accounting.
|
||||
try:
|
||||
await report_completed_workflow_run_platform_usage(workflow_run_id)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error reporting platform usage for workflow {workflow_run_id}: {e}"
|
||||
)
|
||||
|
||||
logger.info(f"Completed workflow completion processing for run {workflow_run_id}")
|
||||
|
|
@ -203,7 +203,7 @@ async def create_workflow_run_rows(
|
|||
Returns:
|
||||
Tuple of (workflow_run, user, workflow).
|
||||
"""
|
||||
from api.schemas.user_configuration import EffectiveAIModelConfiguration
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
|
||||
org = OrganizationModel(provider_id=f"test-org-{provider_id_suffix}")
|
||||
async_session.add(org)
|
||||
|
|
|
|||
|
|
@ -1,12 +1,16 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from api.schemas.ai_model_configuration import (
|
||||
DograhManagedAIModelConfiguration,
|
||||
EffectiveAIModelConfiguration,
|
||||
OrganizationAIModelConfigurationResponse,
|
||||
OrganizationAIModelConfigurationV2,
|
||||
compile_ai_model_configuration_v2,
|
||||
)
|
||||
from api.schemas.user_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY,
|
||||
check_for_masked_keys_in_ai_model_configuration_v2,
|
||||
|
|
@ -15,6 +19,7 @@ from api.services.configuration.ai_model_configuration import (
|
|||
merge_ai_model_configuration_v2_secrets,
|
||||
migrate_workflow_configuration_model_override_to_v2,
|
||||
)
|
||||
from api.services.configuration.check_validity import UserConfigurationValidator
|
||||
from api.services.configuration.masking import mask_key
|
||||
from api.services.configuration.registry import (
|
||||
DeepgramSTTConfiguration,
|
||||
|
|
@ -22,6 +27,8 @@ from api.services.configuration.registry import (
|
|||
DograhSTTService,
|
||||
DograhTTSService,
|
||||
ElevenlabsTTSConfiguration,
|
||||
GoogleLLMService,
|
||||
GoogleRealtimeLLMConfiguration,
|
||||
OpenAIEmbeddingsConfiguration,
|
||||
OpenAILLMService,
|
||||
)
|
||||
|
|
@ -49,6 +56,7 @@ def test_dograh_v2_compiles_to_effective_managed_pipeline_with_embeddings():
|
|||
assert effective.stt.language == "multi"
|
||||
assert effective.embeddings.provider == "dograh"
|
||||
assert effective.embeddings.model == "default"
|
||||
assert effective.managed_service_version == 2
|
||||
|
||||
|
||||
def test_dograh_v2_rejects_non_predefined_speed():
|
||||
|
|
@ -92,6 +100,67 @@ def test_byok_v2_rejects_dograh_provider():
|
|||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_byok_realtime_validator_does_not_require_stt_or_tts():
|
||||
config = OrganizationAIModelConfigurationV2.model_validate(
|
||||
{
|
||||
"mode": "byok",
|
||||
"byok": {
|
||||
"mode": "realtime",
|
||||
"realtime": {
|
||||
"realtime": {
|
||||
"provider": "google_realtime",
|
||||
"api_key": "google-realtime-key",
|
||||
"model": "gemini-3.1-flash-live-preview",
|
||||
"voice": "Puck",
|
||||
"language": "en",
|
||||
},
|
||||
"llm": {
|
||||
"provider": "google",
|
||||
"api_key": "google-llm-key",
|
||||
"model": "gemini-2.0-flash",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
effective = compile_ai_model_configuration_v2(config)
|
||||
|
||||
assert effective.is_realtime is True
|
||||
assert effective.stt is None
|
||||
assert effective.tts is None
|
||||
assert await UserConfigurationValidator().validate(effective) == {
|
||||
"status": [{"model": "all", "message": "ok"}]
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_validator_requires_stt_and_tts_when_not_realtime():
|
||||
effective = EffectiveAIModelConfiguration(
|
||||
llm=GoogleLLMService(
|
||||
provider="google",
|
||||
api_key="google-llm-key",
|
||||
model="gemini-2.0-flash",
|
||||
),
|
||||
realtime=GoogleRealtimeLLMConfiguration(
|
||||
provider="google_realtime",
|
||||
api_key="google-realtime-key",
|
||||
model="gemini-3.1-flash-live-preview",
|
||||
voice="Puck",
|
||||
language="en",
|
||||
),
|
||||
is_realtime=False,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await UserConfigurationValidator().validate(effective)
|
||||
|
||||
assert exc_info.value.args[0] == [
|
||||
{"model": "stt", "message": "API key is missing"},
|
||||
{"model": "tts", "message": "API key is missing"},
|
||||
]
|
||||
|
||||
|
||||
def test_masked_dograh_key_is_preserved_when_saving_same_mode():
|
||||
existing = OrganizationAIModelConfigurationV2(
|
||||
mode="dograh",
|
||||
|
|
@ -293,3 +362,98 @@ def test_workflow_model_override_migration_removes_invalid_v1_override_marker():
|
|||
assert changed is True
|
||||
assert "model_overrides" not in migrated
|
||||
assert migrated["ambient_noise_configuration"] == {"enabled": False}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_migrate_model_configuration_v2_initializes_hosted_mps_billing(
|
||||
monkeypatch,
|
||||
):
|
||||
from api.routes import organization as organization_routes
|
||||
|
||||
legacy = EffectiveAIModelConfiguration(
|
||||
llm=DograhLLMService(
|
||||
provider="dograh",
|
||||
api_key=["mps-secret"],
|
||||
model="default",
|
||||
),
|
||||
tts=DograhTTSService(
|
||||
provider="dograh",
|
||||
api_key=["mps-secret"],
|
||||
model="default",
|
||||
voice="default",
|
||||
),
|
||||
stt=DograhSTTService(
|
||||
provider="dograh",
|
||||
api_key=["mps-secret"],
|
||||
model="default",
|
||||
),
|
||||
)
|
||||
expected_response = OrganizationAIModelConfigurationResponse(
|
||||
configuration={"version": 2, "mode": "dograh"},
|
||||
effective_configuration={},
|
||||
source="organization_v2",
|
||||
)
|
||||
|
||||
class FakeValidator:
|
||||
async def validate(self, *args, **kwargs):
|
||||
return {"status": [{"model": "all", "message": "ok"}]}
|
||||
|
||||
ensure_billing = AsyncMock(return_value={"billing_mode": "v2"})
|
||||
upsert = AsyncMock()
|
||||
migrate_workflows = AsyncMock()
|
||||
|
||||
monkeypatch.setattr(organization_routes, "DEPLOYMENT_MODE", "saas")
|
||||
monkeypatch.setattr(
|
||||
organization_routes,
|
||||
"get_organization_ai_model_configuration_v2",
|
||||
AsyncMock(return_value=None),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
organization_routes.db_client,
|
||||
"get_user_configurations",
|
||||
AsyncMock(return_value=legacy),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
organization_routes,
|
||||
"UserConfigurationValidator",
|
||||
lambda: FakeValidator(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
organization_routes,
|
||||
"ensure_hosted_mps_billing_account_v2",
|
||||
ensure_billing,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
organization_routes,
|
||||
"upsert_organization_ai_model_configuration_v2",
|
||||
upsert,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
organization_routes,
|
||||
"migrate_workflow_model_configurations_to_v2",
|
||||
migrate_workflows,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
organization_routes,
|
||||
"_model_configuration_v2_response",
|
||||
AsyncMock(return_value=expected_response),
|
||||
)
|
||||
|
||||
user = SimpleNamespace(
|
||||
id=7,
|
||||
provider_id="provider-123",
|
||||
selected_organization_id=42,
|
||||
)
|
||||
|
||||
response = await organization_routes.migrate_model_configuration_v2(
|
||||
force=False,
|
||||
user=user,
|
||||
)
|
||||
|
||||
ensure_billing.assert_awaited_once_with(42, created_by="provider-123")
|
||||
upsert.assert_awaited_once()
|
||||
migrate_workflows.assert_awaited_once_with(
|
||||
organization_id=42,
|
||||
fallback_user_config=legacy,
|
||||
)
|
||||
assert response == expected_response
|
||||
|
|
|
|||
68
api/tests/test_auth_depends.py
Normal file
68
api/tests/test_auth_depends.py
Normal file
|
|
@ -0,0 +1,68 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from api.services.auth import depends as auth_depends
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_initializes_hosted_mps_billing_for_new_org(monkeypatch):
|
||||
stack_user = {
|
||||
"id": "stack-user-1",
|
||||
"selected_team_id": "team-1",
|
||||
"primary_email_verified": False,
|
||||
}
|
||||
user = SimpleNamespace(
|
||||
id=7,
|
||||
email=None,
|
||||
provider_id="stack-user-1",
|
||||
selected_organization_id=None,
|
||||
)
|
||||
organization = SimpleNamespace(id=42)
|
||||
existing_config = SimpleNamespace(llm=object(), tts=None, stt=None)
|
||||
|
||||
ensure_billing = AsyncMock(return_value={"billing_mode": "v2"})
|
||||
|
||||
monkeypatch.setattr(auth_depends, "AUTH_PROVIDER", "stack")
|
||||
monkeypatch.setattr(
|
||||
auth_depends.stackauth,
|
||||
"get_user",
|
||||
AsyncMock(return_value=stack_user),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
auth_depends.db_client,
|
||||
"get_or_create_user_by_provider_id",
|
||||
AsyncMock(return_value=(user, False)),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
auth_depends.db_client,
|
||||
"get_or_create_organization_by_provider_id",
|
||||
AsyncMock(return_value=(organization, True)),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
auth_depends.db_client,
|
||||
"add_user_to_organization",
|
||||
AsyncMock(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
auth_depends.db_client,
|
||||
"update_user_selected_organization",
|
||||
AsyncMock(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
auth_depends.db_client,
|
||||
"get_user_configurations",
|
||||
AsyncMock(return_value=existing_config),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
auth_depends,
|
||||
"ensure_hosted_mps_billing_account_v2",
|
||||
ensure_billing,
|
||||
)
|
||||
|
||||
result = await auth_depends.get_user(authorization="Bearer token")
|
||||
|
||||
assert result is user
|
||||
assert result.selected_organization_id == 42
|
||||
ensure_billing.assert_awaited_once_with(42, created_by="stack-user-1")
|
||||
|
|
@ -1,31 +0,0 @@
|
|||
from api.services.pricing.cost_calculator import cost_calculator
|
||||
|
||||
|
||||
def test_cost_calculator():
|
||||
"""Test function to verify cost calculation works"""
|
||||
sample_usage = {
|
||||
"llm": {
|
||||
"OpenAILLMService#0|||gpt-4.1-mini": {
|
||||
"prompt_tokens": 45380,
|
||||
"completion_tokens": 496,
|
||||
"total_tokens": 45876,
|
||||
"cache_read_input_tokens": 0,
|
||||
"cache_creation_input_tokens": 0,
|
||||
}
|
||||
},
|
||||
"tts": {"ElevenLabsTTSService#0|||eleven_flash_v2_5": 2399},
|
||||
"stt": {"DeepgramSTTService#0|||nova-3-general": 177.21536946296692},
|
||||
"call_duration_seconds": 179,
|
||||
}
|
||||
|
||||
result = cost_calculator.calculate_total_cost(sample_usage)
|
||||
assert result["llm_cost"] == 45380 * 0.40 / 1_000_000 + 496 * 1.60 / 1_000_000
|
||||
assert result["tts_cost"] == 2399 * 0.0256 / 1_000
|
||||
assert result["stt_cost"] == 177.21536946296692 / 60 * 0.0077
|
||||
assert (
|
||||
abs(
|
||||
result["total"]
|
||||
- (result["llm_cost"] + result["tts_cost"] + result["stt_cost"])
|
||||
)
|
||||
< 1e-10
|
||||
)
|
||||
110
api/tests/test_dograh_managed_correlation.py
Normal file
110
api/tests/test_dograh_managed_correlation.py
Normal file
|
|
@ -0,0 +1,110 @@
|
|||
import json
|
||||
|
||||
import pytest
|
||||
from openai._types import NOT_GIVEN as OPENAI_NOT_GIVEN
|
||||
from pipecat.frames.frames import TTSStartedFrame
|
||||
from pipecat.services.dograh.llm import DograhLLMService
|
||||
from pipecat.services.dograh.stt import DograhSTTService
|
||||
from pipecat.services.dograh.tts import DograhTTSService
|
||||
from pipecat.services.openai.base_llm import OpenAILLMSettings
|
||||
from websockets.protocol import State
|
||||
|
||||
|
||||
class _FakeWebSocket:
|
||||
def __init__(self):
|
||||
self.state = State.OPEN
|
||||
self.messages: list[dict] = []
|
||||
|
||||
async def send(self, message: str) -> None:
|
||||
self.messages.append(json.loads(message))
|
||||
|
||||
async def close(self, *args, **kwargs) -> None:
|
||||
self.state = State.CLOSED
|
||||
|
||||
|
||||
def test_dograh_llm_uses_explicit_mps_correlation_id():
|
||||
service = DograhLLMService(
|
||||
api_key="mps-secret",
|
||||
correlation_id="mps-corr-123",
|
||||
settings=OpenAILLMSettings(model="default"),
|
||||
)
|
||||
service._start_metadata = {"workflow_run_id": 99}
|
||||
|
||||
params = service.build_chat_completion_params(
|
||||
{
|
||||
"messages": [],
|
||||
"tools": OPENAI_NOT_GIVEN,
|
||||
"tool_choice": OPENAI_NOT_GIVEN,
|
||||
}
|
||||
)
|
||||
|
||||
assert params["metadata"]["correlation_id"] == "mps-corr-123"
|
||||
assert params["metadata"]["mps_billing_version"] == "2"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dograh_stt_config_uses_explicit_mps_correlation_id(monkeypatch):
|
||||
fake_ws = _FakeWebSocket()
|
||||
|
||||
async def fake_connect(url, additional_headers):
|
||||
return fake_ws
|
||||
|
||||
monkeypatch.setattr(
|
||||
"pipecat.services.dograh.stt.websocket_connect",
|
||||
fake_connect,
|
||||
)
|
||||
|
||||
service = DograhSTTService(
|
||||
api_key="mps-secret",
|
||||
correlation_id="mps-corr-123",
|
||||
sample_rate=16000,
|
||||
)
|
||||
service._start_metadata = {"workflow_run_id": 99}
|
||||
|
||||
await service._connect_websocket()
|
||||
|
||||
assert fake_ws.messages[0]["type"] == "config"
|
||||
assert fake_ws.messages[0]["correlation_id"] == "mps-corr-123"
|
||||
assert fake_ws.messages[0]["mps_billing_version"] == "2"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dograh_tts_messages_use_explicit_mps_correlation_id(monkeypatch):
|
||||
fake_ws = _FakeWebSocket()
|
||||
|
||||
async def fake_connect(url, additional_headers):
|
||||
return fake_ws
|
||||
|
||||
monkeypatch.setattr(
|
||||
"pipecat.services.dograh.tts.websocket_connect",
|
||||
fake_connect,
|
||||
)
|
||||
|
||||
service = DograhTTSService(
|
||||
api_key="mps-secret",
|
||||
correlation_id="mps-corr-123",
|
||||
sample_rate=24000,
|
||||
)
|
||||
service._start_metadata = {"workflow_run_id": 99}
|
||||
|
||||
await service._connect_websocket()
|
||||
assert fake_ws.messages[0]["type"] == "config"
|
||||
assert fake_ws.messages[0]["correlation_id"] == "mps-corr-123"
|
||||
assert fake_ws.messages[0]["mps_billing_version"] == "2"
|
||||
|
||||
async def _noop(*args, **kwargs):
|
||||
return None
|
||||
|
||||
service.audio_context_available = lambda context_id: False
|
||||
service.create_audio_context = _noop
|
||||
service.start_ttfb_metrics = _noop
|
||||
service.start_tts_usage_metrics = _noop
|
||||
|
||||
frames = []
|
||||
async for frame in service.run_tts("hello", "ctx-1"):
|
||||
frames.append(frame)
|
||||
|
||||
assert isinstance(frames[0], TTSStartedFrame)
|
||||
assert fake_ws.messages[1]["type"] == "create_context"
|
||||
assert fake_ws.messages[1]["correlation_id"] == "mps-corr-123"
|
||||
assert fake_ws.messages[1]["mps_billing_version"] == "2"
|
||||
|
|
@ -7,7 +7,7 @@ from pipecat.processors.aggregators.llm_context import LLMContext
|
|||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.xai.realtime import events
|
||||
|
||||
from api.schemas.user_configuration import EffectiveAIModelConfiguration
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.configuration.registry import GrokRealtimeLLMConfiguration
|
||||
from api.services.pipecat.realtime.grok_realtime import (
|
||||
DograhGrokRealtimeLLMService,
|
||||
|
|
@ -120,7 +120,7 @@ async def test_completed_input_transcription_is_broadcast_as_finalized():
|
|||
|
||||
|
||||
def test_factory_creates_dograh_grok_realtime_service():
|
||||
user_config = EffectiveAIModelConfiguration(
|
||||
effective_config = EffectiveAIModelConfiguration(
|
||||
is_realtime=True,
|
||||
realtime=GrokRealtimeLLMConfiguration(
|
||||
provider="grok_realtime",
|
||||
|
|
@ -131,7 +131,7 @@ def test_factory_creates_dograh_grok_realtime_service():
|
|||
)
|
||||
|
||||
service = create_realtime_llm_service(
|
||||
user_config,
|
||||
effective_config,
|
||||
audio_config=SimpleNamespace(),
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from fastapi import FastAPI
|
|||
from fastapi.testclient import TestClient
|
||||
|
||||
from api.routes.user import router
|
||||
from api.schemas.user_configuration import EffectiveAIModelConfiguration
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.configuration.masking import mask_key
|
||||
from api.services.configuration.registry import (
|
||||
|
|
|
|||
|
|
@ -87,3 +87,317 @@ async def test_check_service_key_usage_uses_bearer_self_usage(monkeypatch):
|
|||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_correlation_id_uses_bearer_auth(monkeypatch):
|
||||
calls = []
|
||||
|
||||
class FakeAsyncClient:
|
||||
def __init__(self, timeout):
|
||||
self.timeout = timeout
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return None
|
||||
|
||||
async def post(self, url, json, headers):
|
||||
calls.append(("POST", url, json, headers))
|
||||
return _Response(200, {"correlation_id": "mps-corr-123"})
|
||||
|
||||
monkeypatch.setattr(
|
||||
"api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient
|
||||
)
|
||||
|
||||
client = MPSServiceKeyClient()
|
||||
|
||||
assert await client.create_correlation_id(
|
||||
service_key="mps_sk_paid",
|
||||
workflow_run_id=42,
|
||||
) == {"correlation_id": "mps-corr-123"}
|
||||
assert calls == [
|
||||
(
|
||||
"POST",
|
||||
f"{client.base_url}/api/v1/service-keys/correlation-id/self",
|
||||
{"workflow_run_id": 42},
|
||||
{
|
||||
"Authorization": "Bearer mps_sk_paid",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_billing_account_status_uses_hosted_org_auth(monkeypatch):
|
||||
calls = []
|
||||
|
||||
class FakeAsyncClient:
|
||||
def __init__(self, timeout):
|
||||
self.timeout = timeout
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return None
|
||||
|
||||
async def get(self, url, headers):
|
||||
calls.append(("GET", url, headers))
|
||||
return _Response(200, {"organization_id": 42, "billing_mode": "v2"})
|
||||
|
||||
monkeypatch.setattr(
|
||||
"api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient
|
||||
)
|
||||
monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas")
|
||||
monkeypatch.setattr(
|
||||
"api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret"
|
||||
)
|
||||
|
||||
client = MPSServiceKeyClient()
|
||||
|
||||
assert await client.get_billing_account_status(organization_id=42) == {
|
||||
"organization_id": 42,
|
||||
"billing_mode": "v2",
|
||||
}
|
||||
assert calls == [
|
||||
(
|
||||
"GET",
|
||||
f"{client.base_url}/api/v1/billing/accounts/42/status",
|
||||
{
|
||||
"Content-Type": "application/json",
|
||||
"X-Secret-Key": "mps-secret",
|
||||
"X-Organization-Id": "42",
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ensure_billing_account_v2_uses_balance_endpoint(monkeypatch):
|
||||
calls = []
|
||||
|
||||
class FakeAsyncClient:
|
||||
def __init__(self, timeout):
|
||||
self.timeout = timeout
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return None
|
||||
|
||||
async def get(self, url, headers):
|
||||
calls.append(("GET", url, headers))
|
||||
return _Response(
|
||||
200,
|
||||
{
|
||||
"id": 7,
|
||||
"organization_id": 42,
|
||||
"billing_mode": "v2",
|
||||
"cached_balance_credits": "0.0000",
|
||||
"currency": "USD",
|
||||
},
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient
|
||||
)
|
||||
monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas")
|
||||
monkeypatch.setattr(
|
||||
"api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret"
|
||||
)
|
||||
|
||||
client = MPSServiceKeyClient()
|
||||
|
||||
assert await client.ensure_billing_account_v2(
|
||||
organization_id=42,
|
||||
created_by="provider-123",
|
||||
) == {
|
||||
"id": 7,
|
||||
"organization_id": 42,
|
||||
"billing_mode": "v2",
|
||||
"cached_balance_credits": "0.0000",
|
||||
"currency": "USD",
|
||||
}
|
||||
assert calls == [
|
||||
(
|
||||
"GET",
|
||||
f"{client.base_url}/api/v1/billing/accounts/42/balance",
|
||||
{
|
||||
"Content-Type": "application/json",
|
||||
"X-Secret-Key": "mps-secret",
|
||||
"X-Organization-Id": "42",
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_credit_ledger_sends_page_and_limit(monkeypatch):
|
||||
calls = []
|
||||
|
||||
class FakeAsyncClient:
|
||||
def __init__(self, timeout):
|
||||
self.timeout = timeout
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return None
|
||||
|
||||
async def get(self, url, params, headers):
|
||||
calls.append(("GET", url, params, headers))
|
||||
return _Response(
|
||||
200,
|
||||
{
|
||||
"account": {"organization_id": 42},
|
||||
"ledger_entries": [],
|
||||
"total_count": 0,
|
||||
"page": 3,
|
||||
"limit": 25,
|
||||
"total_pages": 0,
|
||||
},
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient
|
||||
)
|
||||
monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas")
|
||||
monkeypatch.setattr(
|
||||
"api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret"
|
||||
)
|
||||
|
||||
client = MPSServiceKeyClient()
|
||||
|
||||
assert await client.get_credit_ledger(
|
||||
organization_id=42,
|
||||
page=3,
|
||||
limit=25,
|
||||
) == {
|
||||
"account": {"organization_id": 42},
|
||||
"ledger_entries": [],
|
||||
"total_count": 0,
|
||||
"page": 3,
|
||||
"limit": 25,
|
||||
"total_pages": 0,
|
||||
}
|
||||
assert calls == [
|
||||
(
|
||||
"GET",
|
||||
f"{client.base_url}/api/v1/billing/accounts/42/ledger",
|
||||
{"page": 3, "limit": 25},
|
||||
{
|
||||
"Content-Type": "application/json",
|
||||
"X-Secret-Key": "mps-secret",
|
||||
"X-Organization-Id": "42",
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_report_platform_usage_uses_hosted_secret_auth(monkeypatch):
|
||||
calls = []
|
||||
|
||||
class FakeAsyncClient:
|
||||
def __init__(self, timeout):
|
||||
self.timeout = timeout
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return None
|
||||
|
||||
async def post(self, url, json, headers):
|
||||
calls.append(("POST", url, json, headers))
|
||||
return _Response(200, {"metered": True})
|
||||
|
||||
monkeypatch.setattr(
|
||||
"api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient
|
||||
)
|
||||
monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas")
|
||||
monkeypatch.setattr(
|
||||
"api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret"
|
||||
)
|
||||
|
||||
client = MPSServiceKeyClient()
|
||||
|
||||
assert await client.report_platform_usage(
|
||||
organization_id=42,
|
||||
correlation_id="mps-corr-123",
|
||||
workflow_run_id=123,
|
||||
metadata={"source": "workflow_run_completion"},
|
||||
) == {"metered": True}
|
||||
assert calls == [
|
||||
(
|
||||
"POST",
|
||||
f"{client.base_url}/api/v1/billing/accounts/42/platform-usage",
|
||||
{
|
||||
"correlation_id": "mps-corr-123",
|
||||
"workflow_run_id": 123,
|
||||
"metadata": {"source": "workflow_run_completion"},
|
||||
},
|
||||
{
|
||||
"Content-Type": "application/json",
|
||||
"X-Secret-Key": "mps-secret",
|
||||
"X-Organization-Id": "42",
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_report_platform_usage_sends_duration_without_correlation(monkeypatch):
|
||||
calls = []
|
||||
|
||||
class FakeAsyncClient:
|
||||
def __init__(self, timeout):
|
||||
self.timeout = timeout
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return None
|
||||
|
||||
async def post(self, url, json, headers):
|
||||
calls.append(("POST", url, json, headers))
|
||||
return _Response(200, {"metered": True})
|
||||
|
||||
monkeypatch.setattr(
|
||||
"api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient
|
||||
)
|
||||
monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas")
|
||||
monkeypatch.setattr(
|
||||
"api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret"
|
||||
)
|
||||
|
||||
client = MPSServiceKeyClient()
|
||||
|
||||
assert await client.report_platform_usage(
|
||||
organization_id=42,
|
||||
duration_seconds=87.0,
|
||||
workflow_run_id=123,
|
||||
metadata={"source": "workflow_run_completion"},
|
||||
) == {"metered": True}
|
||||
assert calls == [
|
||||
(
|
||||
"POST",
|
||||
f"{client.base_url}/api/v1/billing/accounts/42/platform-usage",
|
||||
{
|
||||
"duration_seconds": 87.0,
|
||||
"workflow_run_id": 123,
|
||||
"metadata": {"source": "workflow_run_completion"},
|
||||
},
|
||||
{
|
||||
"Content-Type": "application/json",
|
||||
"X-Secret-Key": "mps-secret",
|
||||
"X-Organization-Id": "42",
|
||||
},
|
||||
)
|
||||
]
|
||||
|
|
|
|||
99
api/tests/test_organization_usage_billing.py
Normal file
99
api/tests/test_organization_usage_billing.py
Normal file
|
|
@ -0,0 +1,99 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from api.routes import organization_usage
|
||||
|
||||
|
||||
def test_is_mps_billing_v2_depends_only_on_account_mode():
|
||||
assert organization_usage._is_mps_billing_v2({"billing_mode": "v2"}) is True
|
||||
assert organization_usage._is_mps_billing_v2({"billing_mode": "v1"}) is False
|
||||
assert organization_usage._is_mps_billing_v2({"billing_mode": "shadow"}) is False
|
||||
assert organization_usage._is_mps_billing_v2(None) is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_mps_billing_account_status_uses_user_provider_id(monkeypatch):
|
||||
get_status = AsyncMock(return_value={"billing_mode": "v2"})
|
||||
monkeypatch.setattr(
|
||||
organization_usage.mps_service_key_client,
|
||||
"get_billing_account_status",
|
||||
get_status,
|
||||
)
|
||||
|
||||
user = SimpleNamespace(provider_id="provider-123")
|
||||
|
||||
assert await organization_usage._get_mps_billing_account_status(user, 42) == {
|
||||
"billing_mode": "v2"
|
||||
}
|
||||
get_status.assert_awaited_once_with(
|
||||
organization_id=42,
|
||||
created_by="provider-123",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_billing_credits_pages_v2_ledger(monkeypatch):
|
||||
monkeypatch.setattr(organization_usage, "DEPLOYMENT_MODE", "saas")
|
||||
monkeypatch.setattr(
|
||||
organization_usage,
|
||||
"_get_mps_billing_account_status",
|
||||
AsyncMock(return_value={"billing_mode": "v2"}),
|
||||
)
|
||||
get_ledger = AsyncMock(
|
||||
return_value={
|
||||
"account": {
|
||||
"id": 7,
|
||||
"organization_id": 42,
|
||||
"billing_mode": "v2",
|
||||
"cached_balance_credits": 250,
|
||||
"currency": "USD",
|
||||
},
|
||||
"ledger_entries": [
|
||||
{
|
||||
"id": 99,
|
||||
"entry_type": "grant",
|
||||
"origin": "account_creation",
|
||||
"credits_delta": 250,
|
||||
"balance_after": 250,
|
||||
"created_at": "2026-06-12T00:00:00Z",
|
||||
}
|
||||
],
|
||||
"total_debits_credits": 75,
|
||||
"total_count": 101,
|
||||
"page": 3,
|
||||
"limit": 25,
|
||||
"total_pages": 5,
|
||||
}
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
organization_usage.mps_service_key_client,
|
||||
"get_credit_ledger",
|
||||
get_ledger,
|
||||
)
|
||||
|
||||
user = SimpleNamespace(
|
||||
provider_id="provider-123",
|
||||
selected_organization_id=42,
|
||||
)
|
||||
|
||||
response = await organization_usage.get_billing_credits(
|
||||
page=3,
|
||||
limit=25,
|
||||
user=user,
|
||||
)
|
||||
|
||||
get_ledger.assert_awaited_once_with(
|
||||
organization_id=42,
|
||||
page=3,
|
||||
limit=25,
|
||||
created_by="provider-123",
|
||||
)
|
||||
assert response.billing_version == "v2"
|
||||
assert response.total_credits_used == 75
|
||||
assert response.total_count == 101
|
||||
assert response.page == 3
|
||||
assert response.limit == 25
|
||||
assert response.total_pages == 5
|
||||
assert response.ledger_entries[0].id == 99
|
||||
|
|
@ -9,7 +9,7 @@ Module under test: api.services.configuration.resolve
|
|||
|
||||
import pytest
|
||||
|
||||
from api.schemas.user_configuration import EffectiveAIModelConfiguration
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.configuration.masking import (
|
||||
contains_masked_key,
|
||||
mask_workflow_configurations,
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from api.services.pricing.run_usage_response import format_public_usage_info
|
||||
from api.services.workflow.run_usage_response import format_public_usage_info
|
||||
|
||||
|
||||
def test_format_public_usage_info():
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from pipecat.processors.frame_processor import FrameDirection
|
|||
from websockets.exceptions import ConnectionClosedError
|
||||
from websockets.frames import Close
|
||||
|
||||
from api.schemas.user_configuration import EffectiveAIModelConfiguration
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.configuration.registry import UltravoxRealtimeLLMConfiguration
|
||||
from api.services.pipecat.realtime.ultravox_realtime import (
|
||||
_RESUMPTION_USER_MESSAGE,
|
||||
|
|
@ -430,7 +430,7 @@ async def test_receive_messages_reports_unexpected_websocket_close():
|
|||
|
||||
|
||||
def test_factory_creates_dograh_ultravox_realtime_service():
|
||||
user_config = EffectiveAIModelConfiguration(
|
||||
effective_config = EffectiveAIModelConfiguration(
|
||||
is_realtime=True,
|
||||
realtime=UltravoxRealtimeLLMConfiguration(
|
||||
provider="ultravox_realtime",
|
||||
|
|
@ -441,7 +441,7 @@ def test_factory_creates_dograh_ultravox_realtime_service():
|
|||
)
|
||||
|
||||
service = create_realtime_llm_service(
|
||||
user_config,
|
||||
effective_config,
|
||||
audio_config=SimpleNamespace(),
|
||||
)
|
||||
|
||||
|
|
|
|||
212
api/tests/test_workflow_run_billing.py
Normal file
212
api/tests/test_workflow_run_billing.py
Normal file
|
|
@ -0,0 +1,212 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from api.services import workflow_run_billing as workflow_run_billing_mod
|
||||
from api.services.workflow_run_billing import (
|
||||
report_completed_workflow_run_platform_usage,
|
||||
report_workflow_run_platform_usage,
|
||||
)
|
||||
|
||||
|
||||
def _make_workflow_run():
|
||||
return SimpleNamespace(
|
||||
id=123,
|
||||
workflow_id=456,
|
||||
is_completed=True,
|
||||
initial_context={"mps_correlation_id": "mps-corr-123"},
|
||||
usage_info={"call_duration_seconds": 87},
|
||||
workflow=SimpleNamespace(
|
||||
organization_id=42,
|
||||
user=SimpleNamespace(selected_organization_id=42),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_report_workflow_run_platform_usage_reports_hosted_completion(
|
||||
monkeypatch,
|
||||
):
|
||||
workflow_run = _make_workflow_run()
|
||||
get_status = AsyncMock(return_value={"billing_mode": "v2"})
|
||||
report_usage = AsyncMock(return_value={"metered": True})
|
||||
|
||||
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
|
||||
monkeypatch.setattr(
|
||||
workflow_run_billing_mod.mps_service_key_client,
|
||||
"get_billing_account_status",
|
||||
get_status,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_run_billing_mod.mps_service_key_client,
|
||||
"report_platform_usage",
|
||||
report_usage,
|
||||
)
|
||||
|
||||
await report_workflow_run_platform_usage(workflow_run)
|
||||
|
||||
report_usage.assert_awaited_once_with(
|
||||
organization_id=42,
|
||||
correlation_id="mps-corr-123",
|
||||
duration_seconds=None,
|
||||
workflow_run_id=workflow_run.id,
|
||||
metadata={
|
||||
"source": "workflow_run_completion",
|
||||
"workflow_id": workflow_run.workflow_id,
|
||||
"duration_source": "mps_correlation",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_report_workflow_run_platform_usage_reports_duration_without_correlation(
|
||||
monkeypatch,
|
||||
):
|
||||
workflow_run = _make_workflow_run()
|
||||
workflow_run.initial_context = {}
|
||||
get_status = AsyncMock(return_value={"billing_mode": "v2"})
|
||||
report_usage = AsyncMock(return_value={"metered": True})
|
||||
|
||||
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
|
||||
monkeypatch.setattr(
|
||||
workflow_run_billing_mod.mps_service_key_client,
|
||||
"get_billing_account_status",
|
||||
get_status,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_run_billing_mod.mps_service_key_client,
|
||||
"report_platform_usage",
|
||||
report_usage,
|
||||
)
|
||||
|
||||
await report_workflow_run_platform_usage(workflow_run)
|
||||
|
||||
report_usage.assert_awaited_once_with(
|
||||
organization_id=42,
|
||||
correlation_id=None,
|
||||
duration_seconds=87.0,
|
||||
workflow_run_id=workflow_run.id,
|
||||
metadata={
|
||||
"source": "workflow_run_completion",
|
||||
"workflow_id": workflow_run.workflow_id,
|
||||
"duration_source": "dograh_usage_info",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_report_workflow_run_platform_usage_skips_non_v2_account(monkeypatch):
|
||||
workflow_run = _make_workflow_run()
|
||||
get_status = AsyncMock(return_value={"billing_mode": "v1"})
|
||||
report_usage = AsyncMock()
|
||||
|
||||
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
|
||||
monkeypatch.setattr(
|
||||
workflow_run_billing_mod.mps_service_key_client,
|
||||
"get_billing_account_status",
|
||||
get_status,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_run_billing_mod.mps_service_key_client,
|
||||
"report_platform_usage",
|
||||
report_usage,
|
||||
)
|
||||
|
||||
await report_workflow_run_platform_usage(workflow_run)
|
||||
|
||||
get_status.assert_awaited_once_with(organization_id=42)
|
||||
report_usage.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_report_workflow_run_platform_usage_skips_missing_duration_without_correlation(
|
||||
monkeypatch,
|
||||
):
|
||||
workflow_run = _make_workflow_run()
|
||||
workflow_run.initial_context = {}
|
||||
workflow_run.usage_info = {}
|
||||
get_status = AsyncMock(return_value={"billing_mode": "v2"})
|
||||
report_usage = AsyncMock()
|
||||
|
||||
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
|
||||
monkeypatch.setattr(
|
||||
workflow_run_billing_mod.mps_service_key_client,
|
||||
"get_billing_account_status",
|
||||
get_status,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_run_billing_mod.mps_service_key_client,
|
||||
"report_platform_usage",
|
||||
report_usage,
|
||||
)
|
||||
|
||||
await report_workflow_run_platform_usage(workflow_run)
|
||||
|
||||
get_status.assert_not_awaited()
|
||||
report_usage.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_report_workflow_run_platform_usage_skips_oss(monkeypatch):
|
||||
workflow_run = _make_workflow_run()
|
||||
report_usage = AsyncMock()
|
||||
|
||||
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "oss")
|
||||
monkeypatch.setattr(
|
||||
workflow_run_billing_mod.mps_service_key_client,
|
||||
"report_platform_usage",
|
||||
report_usage,
|
||||
)
|
||||
|
||||
await report_workflow_run_platform_usage(workflow_run)
|
||||
|
||||
report_usage.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_report_workflow_run_platform_usage_skips_incomplete(monkeypatch):
|
||||
workflow_run = _make_workflow_run()
|
||||
workflow_run.is_completed = False
|
||||
report_usage = AsyncMock()
|
||||
|
||||
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
|
||||
monkeypatch.setattr(
|
||||
workflow_run_billing_mod.mps_service_key_client,
|
||||
"report_platform_usage",
|
||||
report_usage,
|
||||
)
|
||||
|
||||
await report_workflow_run_platform_usage(workflow_run)
|
||||
|
||||
report_usage.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_report_completed_workflow_run_platform_usage_loads_run(monkeypatch):
|
||||
workflow_run = _make_workflow_run()
|
||||
get_run = AsyncMock(return_value=workflow_run)
|
||||
get_status = AsyncMock(return_value={"billing_mode": "v2"})
|
||||
report_usage = AsyncMock(return_value={"metered": True})
|
||||
|
||||
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
|
||||
monkeypatch.setattr(
|
||||
workflow_run_billing_mod.db_client,
|
||||
"get_workflow_run_by_id",
|
||||
get_run,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_run_billing_mod.mps_service_key_client,
|
||||
"get_billing_account_status",
|
||||
get_status,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_run_billing_mod.mps_service_key_client,
|
||||
"report_platform_usage",
|
||||
report_usage,
|
||||
)
|
||||
|
||||
await report_completed_workflow_run_platform_usage(workflow_run.id)
|
||||
|
||||
get_run.assert_awaited_once_with(workflow_run.id)
|
||||
report_usage.assert_awaited_once()
|
||||
|
|
@ -1,181 +0,0 @@
|
|||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from api.services.pricing import workflow_run_cost as workflow_run_cost_mod
|
||||
from api.services.pricing.workflow_run_cost import (
|
||||
apply_usage_delta_to_organization,
|
||||
build_workflow_run_cost_info,
|
||||
calculate_workflow_run_cost,
|
||||
)
|
||||
|
||||
|
||||
def _make_workflow_run():
|
||||
return SimpleNamespace(
|
||||
id=123,
|
||||
workflow_id=456,
|
||||
mode="textchat",
|
||||
created_at=datetime.now(UTC),
|
||||
usage_info={
|
||||
"llm": {},
|
||||
"tts": {},
|
||||
"stt": {},
|
||||
"call_duration_seconds": 7,
|
||||
},
|
||||
cost_info={},
|
||||
workflow=SimpleNamespace(
|
||||
organization_id=42,
|
||||
user=SimpleNamespace(selected_organization_id=42),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_workflow_run_cost_info_does_not_update_org_usage(monkeypatch):
|
||||
workflow_run = _make_workflow_run()
|
||||
get_org = AsyncMock(return_value=SimpleNamespace(id=42, price_per_second_usd=1.5))
|
||||
update_usage = AsyncMock()
|
||||
|
||||
monkeypatch.setattr(
|
||||
workflow_run_cost_mod.db_client, "get_organization_by_id", get_org
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_run_cost_mod.db_client, "update_usage_after_run", update_usage
|
||||
)
|
||||
|
||||
cost_info = await build_workflow_run_cost_info(workflow_run)
|
||||
|
||||
assert cost_info is not None
|
||||
assert cost_info["call_duration_seconds"] == 7
|
||||
assert "cost_breakdown" in cost_info
|
||||
assert "dograh_token_usage" in cost_info
|
||||
assert cost_info["charge_usd"] == 10.5
|
||||
update_usage.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_workflow_run_cost_keeps_org_usage_side_effect_in_wrapper(
|
||||
monkeypatch,
|
||||
):
|
||||
workflow_run = _make_workflow_run()
|
||||
get_org = AsyncMock(return_value=SimpleNamespace(id=42, price_per_second_usd=None))
|
||||
update_run = AsyncMock()
|
||||
update_usage = AsyncMock()
|
||||
|
||||
monkeypatch.setattr(
|
||||
workflow_run_cost_mod.db_client,
|
||||
"get_workflow_run_by_id",
|
||||
AsyncMock(return_value=workflow_run),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_run_cost_mod.db_client, "get_organization_by_id", get_org
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_run_cost_mod.db_client, "update_workflow_run", update_run
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_run_cost_mod.db_client, "update_usage_after_run", update_usage
|
||||
)
|
||||
|
||||
await calculate_workflow_run_cost(workflow_run.id)
|
||||
|
||||
update_run.assert_awaited_once()
|
||||
saved_kwargs = update_run.await_args.kwargs
|
||||
assert saved_kwargs["run_id"] == workflow_run.id
|
||||
assert "cost_breakdown" in saved_kwargs["cost_info"]
|
||||
update_usage.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_apply_usage_delta_to_organization_uses_incremental_costs(
|
||||
monkeypatch,
|
||||
):
|
||||
workflow_run = _make_workflow_run()
|
||||
workflow_run.cost_info = {"call_id": "preserve-me"}
|
||||
|
||||
usage_delta_one = {
|
||||
"llm": {
|
||||
"OpenAILLMService#0|||gpt-4.1-mini": {
|
||||
"prompt_tokens": 1_000,
|
||||
"completion_tokens": 100,
|
||||
"total_tokens": 1_100,
|
||||
"cache_read_input_tokens": 0,
|
||||
"cache_creation_input_tokens": 0,
|
||||
}
|
||||
},
|
||||
"tts": {},
|
||||
"stt": {},
|
||||
"call_duration_seconds": 3,
|
||||
}
|
||||
usage_delta_two = {
|
||||
"llm": {
|
||||
"OpenAILLMService#0|||gpt-4.1-mini": {
|
||||
"prompt_tokens": 2_000,
|
||||
"completion_tokens": 50,
|
||||
"total_tokens": 2_050,
|
||||
"cache_read_input_tokens": 0,
|
||||
"cache_creation_input_tokens": 0,
|
||||
}
|
||||
},
|
||||
"tts": {},
|
||||
"stt": {},
|
||||
"call_duration_seconds": 4,
|
||||
}
|
||||
merged_usage = {
|
||||
"llm": {
|
||||
"OpenAILLMService#0|||gpt-4.1-mini": {
|
||||
"prompt_tokens": 3_000,
|
||||
"completion_tokens": 150,
|
||||
"total_tokens": 3_150,
|
||||
"cache_read_input_tokens": 0,
|
||||
"cache_creation_input_tokens": 0,
|
||||
}
|
||||
},
|
||||
"tts": {},
|
||||
"stt": {},
|
||||
"call_duration_seconds": 7,
|
||||
}
|
||||
|
||||
get_org = AsyncMock(return_value=SimpleNamespace(id=42, price_per_second_usd=1.5))
|
||||
update_usage = AsyncMock()
|
||||
|
||||
monkeypatch.setattr(
|
||||
workflow_run_cost_mod.db_client, "get_organization_by_id", get_org
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_run_cost_mod.db_client, "update_usage_after_run", update_usage
|
||||
)
|
||||
|
||||
first_delta = await apply_usage_delta_to_organization(workflow_run, usage_delta_one)
|
||||
second_delta = await apply_usage_delta_to_organization(
|
||||
workflow_run, usage_delta_two
|
||||
)
|
||||
total_workflow_run = SimpleNamespace(**workflow_run.__dict__)
|
||||
total_workflow_run.usage_info = merged_usage
|
||||
total_cost = await build_workflow_run_cost_info(total_workflow_run)
|
||||
|
||||
assert first_delta is not None
|
||||
assert second_delta is not None
|
||||
assert total_cost is not None
|
||||
assert update_usage.await_count == 2
|
||||
assert update_usage.await_args_list[0].args == (
|
||||
42,
|
||||
first_delta["dograh_token_usage"],
|
||||
3.0,
|
||||
first_delta["charge_usd"],
|
||||
)
|
||||
assert update_usage.await_args_list[1].args == (
|
||||
42,
|
||||
second_delta["dograh_token_usage"],
|
||||
4.0,
|
||||
second_delta["charge_usd"],
|
||||
)
|
||||
assert (
|
||||
first_delta["dograh_token_usage"] + second_delta["dograh_token_usage"]
|
||||
) == pytest.approx(total_cost["dograh_token_usage"])
|
||||
assert (
|
||||
first_delta["charge_usd"] + second_delta["charge_usd"]
|
||||
== total_cost["charge_usd"]
|
||||
)
|
||||
|
|
@ -4,7 +4,7 @@ from unittest.mock import AsyncMock, patch
|
|||
import pytest
|
||||
|
||||
from api.db.models import OrganizationModel, UserModel
|
||||
from api.schemas.user_configuration import EffectiveAIModelConfiguration
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
from api.tests.integrations._run_pipeline_helpers import USER_CONFIGURATION
|
||||
from pipecat.tests import MockLLMService
|
||||
|
||||
|
|
@ -176,11 +176,7 @@ async def test_text_chat_session_creation_executes_initial_assistant_turn(
|
|||
assert "Start" in (created["gathered_context"] or {}).get("nodes_visited", [])
|
||||
workflow_run = await db_session.get_workflow_run_by_id(created["workflow_run_id"])
|
||||
assert workflow_run is not None
|
||||
assert workflow_run.cost_info[
|
||||
"call_duration_seconds"
|
||||
] == workflow_run.usage_info.get("call_duration_seconds", 0)
|
||||
assert "cost_breakdown" in workflow_run.cost_info
|
||||
assert "dograh_token_usage" in workflow_run.cost_info
|
||||
assert "call_duration_seconds" in workflow_run.usage_info
|
||||
assert _log_texts(run_payload["logs"], "rtf-bot-text") == [
|
||||
"Hello from the workflow tester."
|
||||
]
|
||||
|
|
@ -296,11 +292,7 @@ async def test_text_chat_message_executes_assistant_turn(
|
|||
assert "Start" in (payload["gathered_context"] or {}).get("nodes_visited", [])
|
||||
workflow_run = await db_session.get_workflow_run_by_id(created["workflow_run_id"])
|
||||
assert workflow_run is not None
|
||||
assert workflow_run.cost_info[
|
||||
"call_duration_seconds"
|
||||
] == workflow_run.usage_info.get("call_duration_seconds", 0)
|
||||
assert "cost_breakdown" in workflow_run.cost_info
|
||||
assert "dograh_token_usage" in workflow_run.cost_info
|
||||
assert "call_duration_seconds" in workflow_run.usage_info
|
||||
assert _log_texts(run_payload["logs"], "rtf-user-transcription") == ["Hi there"]
|
||||
assert _log_texts(run_payload["logs"], "rtf-bot-text") == [
|
||||
"Welcome to the workflow tester.",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue