Merge remote-tracking branch 'origin/main' into feat/user-onboarding

This commit is contained in:
Abhishek Kumar 2026-06-12 18:54:48 +05:30
commit 093e888ce4
148 changed files with 10908 additions and 2815 deletions

View file

@ -1,3 +1,3 @@
{
".": "1.33.0"
".": "1.34.0"
}

View file

@ -1,5 +1,33 @@
# Changelog
## 1.34.0 (2026-06-03)
<!-- Release notes generated using configuration in .github/release.yml at main -->
## What's Changed
### Features
* feat: add mcp guides for various topic and stages for bot building by @a6kme in https://github.com/dograh-hq/dograh/pull/380
* feat: allow overriding base URL of OpenAI STT and TTS by @developer603 in https://github.com/dograh-hq/dograh/pull/377
* feat: add Azure AI multi-provider support (TTS, STT, Embeddings, Realtime) by @vishaldhateria in https://github.com/dograh-hq/dograh/pull/381
### Bug Fixes
* fix: support object and array parameters in custom HTTP tools by @mvanhorn in https://github.com/dograh-hq/dograh/pull/373
* fix(telephony): resolve transfer context via call-sid index instead of KEYS scan by @shiminshen in https://github.com/dograh-hq/dograh/pull/387
* fix(webrtc): enforce embed allowed-domain policy on public signaling websocket by @shiminshen in https://github.com/dograh-hq/dograh/pull/388
* fix: use runtime BACKEND_URL for proxying by @a6kme in https://github.com/dograh-hq/dograh/pull/411
* fix: add CORS preflight handler and ACAO header for embed config endpoint by @nuthalapativarun in https://github.com/dograh-hq/dograh/pull/403
### Other Changes
* Add Sarvam LLM, update Sarvam STT models, expose usage_info on run detail by @abhaybabbar in https://github.com/dograh-hq/dograh/pull/351
* fix: make email lookup case-insensitive in get_user_by_email by @developer603 in https://github.com/dograh-hq/dograh/pull/397
## New Contributors
* @abhaybabbar made their first contribution in https://github.com/dograh-hq/dograh/pull/351
* @mvanhorn made their first contribution in https://github.com/dograh-hq/dograh/pull/373
* @developer603 made their first contribution in https://github.com/dograh-hq/dograh/pull/377
* @vishaldhateria made their first contribution in https://github.com/dograh-hq/dograh/pull/381
* @shiminshen made their first contribution in https://github.com/dograh-hq/dograh/pull/387
**Full Changelog**: https://github.com/dograh-hq/dograh/compare/dograh-v1.33.0...dograh-v1.34.0
## [1.33.0](https://github.com/dograh-hq/dograh/compare/dograh-v1.32.0...dograh-v1.33.0) (2026-05-31)

View file

@ -75,13 +75,13 @@ An honest comparison on the axes that matter most to teams evaluating voice AI p
##### Download and setup Dograh on your Local Machine
> **Note**
> We collect anonymous usage data to improve the product. You can opt out by setting the `ENABLE_TELEMETRY` to `false` in the below command.
> We collect anonymous usage data to improve the product. You can opt out by setting `ENABLE_TELEMETRY=false` before running the startup script.
> **Note**
> If you wish to run the platform on a remote server instead, checkout our [Documentation](https://docs.dograh.com/deployment/docker#option-2:-remote-server-deployment)
```bash
curl -o docker-compose.yaml https://raw.githubusercontent.com/dograh-hq/dograh/main/docker-compose.yaml && REGISTRY=ghcr.io/dograh-hq ENABLE_TELEMETRY=true docker compose up --pull always
curl -o docker-compose.yaml https://raw.githubusercontent.com/dograh-hq/dograh/main/docker-compose.yaml && curl -o start_docker.sh https://raw.githubusercontent.com/dograh-hq/dograh/main/scripts/start_docker.sh && chmod +x start_docker.sh && ./start_docker.sh
```
> **Note**

View file

@ -18,6 +18,9 @@ branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
DEPRECATED_QUOTA_COMMENT = "Deprecated. MPS owns quota and credit ledger state."
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
# 1) Create the `quota_type` enum *before* we add the column that references it.
@ -34,7 +37,12 @@ def upgrade() -> None:
sa.Column("organization_id", sa.Integer(), nullable=False),
sa.Column("period_start", sa.DateTime(), nullable=False),
sa.Column("period_end", sa.DateTime(), nullable=False),
sa.Column("quota_dograh_tokens", sa.Integer(), nullable=False),
sa.Column(
"quota_dograh_tokens",
sa.Integer(),
nullable=False,
comment=DEPRECATED_QUOTA_COMMENT,
),
sa.Column("used_dograh_tokens", sa.Integer(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True),
@ -63,7 +71,11 @@ def upgrade() -> None:
op.add_column(
"organizations",
sa.Column(
"quota_type", quota_type_enum, nullable=False, server_default="monthly"
"quota_type",
quota_type_enum,
nullable=False,
server_default="monthly",
comment=DEPRECATED_QUOTA_COMMENT,
),
)
op.add_column(
@ -73,6 +85,7 @@ def upgrade() -> None:
sa.Integer(),
nullable=False,
server_default=sa.text("0"),
comment=DEPRECATED_QUOTA_COMMENT,
),
)
op.add_column(
@ -82,10 +95,17 @@ def upgrade() -> None:
sa.Integer(),
nullable=False,
server_default=sa.text("LEAST(EXTRACT(DAY FROM CURRENT_DATE)::int, 28)"),
comment=DEPRECATED_QUOTA_COMMENT,
),
)
op.add_column(
"organizations", sa.Column("quota_start_date", sa.DateTime(), nullable=True)
"organizations",
sa.Column(
"quota_start_date",
sa.DateTime(),
nullable=True,
comment=DEPRECATED_QUOTA_COMMENT,
),
)
op.add_column(
"organizations",
@ -94,6 +114,7 @@ def upgrade() -> None:
sa.Boolean(),
nullable=False,
server_default=sa.text("false"),
comment=DEPRECATED_QUOTA_COMMENT,
),
)
# ### end Alembic commands ###

View file

@ -5,28 +5,38 @@ Revises: 6bd9f67ec994
Create Date: 2026-06-02 07:58:00.002359
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '384be6596b36'
down_revision: Union[str, None] = '6bd9f67ec994'
revision: str = "384be6596b36"
down_revision: Union[str, None] = "6bd9f67ec994"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f('ix_users_email'), table_name='users')
op.create_index('ix_users_email_lower', 'users', [sa.literal_column('lower(email)')], unique=True, postgresql_where=sa.text('email IS NOT NULL'))
op.drop_index(op.f("ix_users_email"), table_name="users")
op.create_index(
"ix_users_email_lower",
"users",
[sa.literal_column("lower(email)")],
unique=True,
postgresql_where=sa.text("email IS NOT NULL"),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index('ix_users_email_lower', table_name='users', postgresql_where=sa.text('email IS NOT NULL'))
op.create_index(op.f('ix_users_email'), 'users', ['email'], unique=True)
op.drop_index(
"ix_users_email_lower",
table_name="users",
postgresql_where=sa.text("email IS NOT NULL"),
)
op.create_index(op.f("ix_users_email"), "users", ["email"], unique=True)
# ### end Alembic commands ###

View file

@ -18,6 +18,9 @@ branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
DEPRECATED_QUOTA_COMMENT = "Deprecated. MPS owns quota and credit ledger state."
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
@ -26,7 +29,12 @@ def upgrade() -> None:
)
op.add_column(
"organization_usage_cycles",
sa.Column("quota_amount_usd", sa.Float(), nullable=True),
sa.Column(
"quota_amount_usd",
sa.Float(),
nullable=True,
comment=DEPRECATED_QUOTA_COMMENT,
),
)
# ### end Alembic commands ###

View file

@ -117,6 +117,15 @@ app.add_middleware(
allow_headers=["*"],
)
def _add_public_embed_cors_middleware() -> None:
from api.routes.public_embed import PublicEmbedCORSMiddleware
app.add_middleware(PublicEmbedCORSMiddleware, api_prefix=API_PREFIX)
_add_public_embed_cors_middleware()
api_router = APIRouter()
# include subrouters here

View file

@ -9,6 +9,7 @@ from api.db.base_client import BaseDBClient
from api.db.filters import apply_workflow_run_filters, get_workflow_run_order_clause
from api.db.models import CampaignModel, QueuedRunModel, WorkflowRunModel
from api.schemas.workflow import WorkflowRunResponseSchema
from api.services.workflow.run_usage_response import format_public_cost_info
class CampaignClient(BaseDBClient):
@ -215,26 +216,9 @@ class CampaignClient(BaseDBClient):
"is_completed": run.is_completed,
"recording_url": run.recording_url,
"transcript_url": run.transcript_url,
"cost_info": {
"dograh_token_usage": (
run.cost_info.get("dograh_token_usage")
if run.cost_info
and "dograh_token_usage" in run.cost_info
else round(
float(run.cost_info.get("total_cost_usd", 0)) * 100,
2,
)
if run.cost_info and "total_cost_usd" in run.cost_info
else 0
),
"call_duration_seconds": int(
round(run.cost_info.get("call_duration_seconds") or 0)
)
if run.cost_info
else None,
}
if run.cost_info
else None,
"cost_info": format_public_cost_info(
run.cost_info, run.usage_info
),
"definition_id": run.definition_id,
"initial_context": run.initial_context,
"gathered_context": run.gathered_context,
@ -662,7 +646,7 @@ class CampaignClient(BaseDBClient):
async with self.async_session() as session:
conditions = [
WorkflowRunModel.is_completed.is_(True),
WorkflowRunModel.cost_info["call_duration_seconds"]
WorkflowRunModel.usage_info["call_duration_seconds"]
.as_string()
.isnot(None),
]
@ -685,6 +669,7 @@ class CampaignClient(BaseDBClient):
WorkflowRunModel.initial_context,
WorkflowRunModel.gathered_context,
WorkflowRunModel.cost_info,
WorkflowRunModel.usage_info,
WorkflowRunModel.public_access_token,
)
.where(*conditions)

View file

@ -53,7 +53,7 @@ class DBClient(
- UserClient: handles user and user configuration operations
- OrganizationClient: handles organization operations
- OrganizationConfigurationClient: handles organization configuration operations
- OrganizationUsageClient: handles organization usage and quota operations
- OrganizationUsageClient: handles organization usage reporting aggregates
- IntegrationClient: handles integration operations
- WorkflowTemplateClient: handles workflow template operations
- CampaignClient: handles campaign operations

View file

@ -25,7 +25,7 @@ def get_workflow_run_order_clause(
"""
# Determine sort column
if sort_by == "duration":
sort_column = WorkflowRunModel.cost_info.op("->>")(
sort_column = WorkflowRunModel.usage_info.op("->>")(
"call_duration_seconds"
).cast(Float)
else:
@ -43,7 +43,7 @@ def get_workflow_run_order_clause(
ATTRIBUTE_FIELD_MAPPING = {
"dateRange": "created_at",
"dispositionCode": "gathered_context.mapped_call_disposition",
"duration": "cost_info.call_duration_seconds",
"duration": "usage_info.call_duration_seconds",
"status": "is_completed",
"tokenUsage": "cost_info.total_cost_usd",
"runId": "id",
@ -208,7 +208,7 @@ def apply_workflow_run_filters(
min_val = value.get("min")
max_val = value.get("max")
if field == "cost_info.call_duration_seconds":
if field == "usage_info.call_duration_seconds":
# Use ->> operator for compatibility with all PostgreSQL versions
# (subscript [] only works in PostgreSQL 14+)
duration_text = cast(WorkflowRunModel.usage_info, JSONB).op("->>")(

View file

@ -97,22 +97,44 @@ class OrganizationModel(Base):
provider_id = Column(String, unique=True, index=True, nullable=False)
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
# Quota fields
# Deprecated: MPS owns quota and credit ledger state.
quota_type = Column(
Enum("monthly", "annual", name="quota_type"),
nullable=False,
default="monthly",
server_default=text("'monthly'::quota_type"),
comment="Deprecated. MPS owns quota and credit ledger state.",
info={"deprecated": True},
)
quota_dograh_tokens = Column(
Integer, nullable=False, default=0, server_default=text("0")
Integer,
nullable=False,
default=0,
server_default=text("0"),
comment="Deprecated. MPS owns quota and credit ledger state.",
info={"deprecated": True},
)
quota_reset_day = Column(
Integer, nullable=False, default=1, server_default=text("1")
) # 1-28, only for monthly
quota_start_date = Column(DateTime(timezone=True), nullable=True) # Only for annual
Integer,
nullable=False,
default=1,
server_default=text("1"),
comment="Deprecated. MPS owns quota and credit ledger state.",
info={"deprecated": True},
)
quota_start_date = Column(
DateTime(timezone=True),
nullable=True,
comment="Deprecated. MPS owns quota and credit ledger state.",
info={"deprecated": True},
)
quota_enabled = Column(
Boolean, nullable=False, default=False, server_default=text("false")
Boolean,
nullable=False,
default=False,
server_default=text("false"),
comment="Deprecated. MPS owns quota and credit ledger state.",
info={"deprecated": True},
)
price_per_second_usd = Column(Float, nullable=True)
@ -593,8 +615,9 @@ class WorkflowRunTextSessionModel(Base):
class OrganizationUsageCycleModel(Base):
"""
This model is used to track the usage of Dograh tokens for an organization for a given usage
cycle.
This model is used to track reporting aggregates for an organization for a given
usage cycle. Quota fields on this model are deprecated; MPS owns quota and
credit ledger state.
"""
__tablename__ = "organization_usage_cycles"
@ -603,14 +626,24 @@ class OrganizationUsageCycleModel(Base):
organization_id = Column(Integer, ForeignKey("organizations.id"), nullable=False)
period_start = Column(DateTime(timezone=True), nullable=False)
period_end = Column(DateTime(timezone=True), nullable=False)
quota_dograh_tokens = Column(Integer, nullable=False)
quota_dograh_tokens = Column(
Integer,
nullable=False,
comment="Deprecated. MPS owns quota and credit ledger state.",
info={"deprecated": True},
)
used_dograh_tokens = Column(Float, nullable=False, default=0)
total_duration_seconds = Column(
Integer, nullable=False, default=0, server_default=text("0")
)
# New USD tracking fields
used_amount_usd = Column(Float, nullable=True, default=0)
quota_amount_usd = Column(Float, nullable=True)
quota_amount_usd = Column(
Float,
nullable=True,
comment="Deprecated. MPS owns quota and credit ledger state.",
info={"deprecated": True},
)
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
updated_at = Column(
DateTime(timezone=True),

View file

@ -10,6 +10,7 @@ from sqlalchemy.orm import joinedload
from api.db.base_client import BaseDBClient
from api.db.filters import apply_workflow_run_filters
from api.db.models import (
OrganizationConfigurationModel,
OrganizationModel,
OrganizationUsageCycleModel,
UserConfigurationModel,
@ -17,11 +18,12 @@ from api.db.models import (
WorkflowModel,
WorkflowRunModel,
)
from api.schemas.user_configuration import UserConfiguration
from api.enums import OrganizationConfigurationKey
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
class OrganizationUsageClient(BaseDBClient):
"""Client for managing organization usage and quota operations."""
"""Client for managing organization usage reporting aggregates."""
async def get_or_create_current_cycle(
self, organization_id: int, session=None
@ -47,14 +49,7 @@ class OrganizationUsageClient(BaseDBClient):
self, organization_id: int, session, commit: bool
) -> OrganizationUsageCycleModel:
"""Internal implementation for get_or_create_current_cycle."""
# Get organization to determine quota type
org_result = await session.execute(
select(OrganizationModel).where(OrganizationModel.id == organization_id)
)
org = org_result.scalar_one()
# Calculate current period
period_start, period_end = self._calculate_current_period(org)
period_start, period_end = self._calculate_current_period()
# Try to get existing cycle
cycle_result = await session.execute(
@ -76,7 +71,8 @@ class OrganizationUsageClient(BaseDBClient):
organization_id=organization_id,
period_start=period_start,
period_end=period_end,
quota_dograh_tokens=org.quota_dograh_tokens,
# Deprecated non-null column retained for historical schema compatibility.
quota_dograh_tokens=0,
)
# Handle concurrent inserts gracefully
stmt = stmt.on_conflict_do_nothing(
@ -100,95 +96,9 @@ class OrganizationUsageClient(BaseDBClient):
)
return cycle_result.scalar_one()
async def check_and_reserve_quota(
self, organization_id: int, estimated_tokens: int = 0
) -> bool:
"""
Check if organization has sufficient quota and optionally reserve tokens.
Returns True if quota is available, False otherwise.
This method is fully atomic and safe for concurrent access from multiple processes.
"""
async with self.async_session() as session:
# Get organization
org_result = await session.execute(
select(OrganizationModel).where(OrganizationModel.id == organization_id)
)
org = org_result.scalar_one_or_none()
if not org or not org.quota_enabled:
# No quota enforcement if not enabled
return True
# Get or create current cycle within the same session/transaction
cycle = await self._get_or_create_current_cycle_impl(
organization_id, session, commit=False
)
# Atomic check and update with row-level lock
result = await session.execute(
select(OrganizationUsageCycleModel)
.where(
and_(
OrganizationUsageCycleModel.id == cycle.id,
OrganizationUsageCycleModel.used_dograh_tokens
+ estimated_tokens
<= OrganizationUsageCycleModel.quota_dograh_tokens,
)
)
.with_for_update(skip_locked=False)
)
cycle_locked = result.scalar_one_or_none()
if cycle_locked:
# Update the usage atomically
cycle_locked.used_dograh_tokens += estimated_tokens
await session.commit()
return True
return False
async def update_usage_after_run(
self,
organization_id: int,
actual_tokens: float,
duration_seconds: float = 0,
charge_usd: float | None = None,
) -> None:
"""Update usage after a workflow run completes with actual token count and duration.
This method is fully atomic and safe for concurrent access from multiple processes.
"""
async with self.async_session() as session:
# Get or create current cycle within the same session/transaction
cycle = await self._get_or_create_current_cycle_impl(
organization_id, session, commit=False
)
# Acquire a row-level lock for atomic update
result = await session.execute(
select(OrganizationUsageCycleModel)
.where(OrganizationUsageCycleModel.id == cycle.id)
.with_for_update(skip_locked=False)
)
cycle_locked = result.scalar_one()
# Update usage atomically
cycle_locked.used_dograh_tokens += actual_tokens
cycle_locked.total_duration_seconds += int(round(duration_seconds))
# Update USD amount if provided
if charge_usd is not None:
if cycle_locked.used_amount_usd is None:
cycle_locked.used_amount_usd = 0
cycle_locked.used_amount_usd += charge_usd
await session.commit()
async def get_current_usage(self, organization_id: int) -> dict:
"""Get current period usage information."""
"""Get current reporting-period usage information."""
async with self.async_session() as session:
# Get organization
org_result = await session.execute(
select(OrganizationModel).where(OrganizationModel.id == organization_id)
)
@ -199,42 +109,19 @@ class OrganizationUsageClient(BaseDBClient):
organization_id, session, commit=False
)
# Calculate next refresh date
if org.quota_type == "monthly":
next_refresh = cycle.period_end + relativedelta(days=1)
else: # annual
next_refresh = cycle.period_end + relativedelta(days=1)
result = {
"period_start": cycle.period_start.isoformat(),
"period_end": cycle.period_end.isoformat(),
"used_dograh_tokens": cycle.used_dograh_tokens,
"quota_dograh_tokens": cycle.quota_dograh_tokens,
"percentage_used": (
round(
(cycle.used_dograh_tokens / cycle.quota_dograh_tokens) * 100, 2
)
if cycle.quota_dograh_tokens > 0
else 0
),
"next_refresh_date": next_refresh.date().isoformat(),
"quota_enabled": org.quota_enabled,
"total_duration_seconds": cycle.total_duration_seconds,
}
# Add USD fields if organization has pricing
if org.price_per_second_usd is not None:
result["used_amount_usd"] = cycle.used_amount_usd or 0
result["quota_amount_usd"] = cycle.quota_amount_usd
result["currency"] = "USD"
result["price_per_second_usd"] = org.price_per_second_usd
# Calculate percentage based on USD if available
if cycle.quota_amount_usd and cycle.quota_amount_usd > 0:
result["percentage_used"] = round(
((cycle.used_amount_usd or 0) / cycle.quota_amount_usd) * 100, 2
)
return result
async def get_usage_history(
@ -254,7 +141,7 @@ class OrganizationUsageClient(BaseDBClient):
.join(UserModel, WorkflowModel.user_id == UserModel.id)
.where(
UserModel.selected_organization_id == organization_id,
WorkflowRunModel.cost_info.isnot(None),
WorkflowRunModel.usage_info.isnot(None),
)
.order_by(WorkflowRunModel.created_at.desc())
)
@ -307,19 +194,8 @@ class OrganizationUsageClient(BaseDBClient):
total_tokens = 0
total_duration_seconds = 0
for run in runs:
if run.cost_info:
# Try to get dograh_token_usage first (new format)
dograh_tokens = run.cost_info.get("dograh_token_usage", 0)
# If not present, calculate from total_cost_usd (old format)
if dograh_tokens == 0 and "total_cost_usd" in run.cost_info:
dograh_tokens = round(
float(run.cost_info["total_cost_usd"]) * 100, 2
)
# Get call duration
call_duration = run.cost_info.get("call_duration_seconds", 0)
else:
dograh_tokens = 0
call_duration = 0
dograh_tokens = 0
call_duration = (run.usage_info or {}).get("call_duration_seconds", 0)
total_tokens += dograh_tokens
total_duration_seconds += int(round(call_duration))
@ -393,13 +269,14 @@ class OrganizationUsageClient(BaseDBClient):
WorkflowRunModel.initial_context,
WorkflowRunModel.gathered_context,
WorkflowRunModel.cost_info,
WorkflowRunModel.usage_info,
WorkflowRunModel.public_access_token,
)
.join(WorkflowModel, WorkflowRunModel.workflow_id == WorkflowModel.id)
.join(UserModel, WorkflowModel.user_id == UserModel.id)
.where(
UserModel.selected_organization_id == organization_id,
WorkflowRunModel.cost_info.isnot(None),
WorkflowRunModel.usage_info.isnot(None),
)
.order_by(WorkflowRunModel.created_at.desc())
)
@ -440,8 +317,29 @@ class OrganizationUsageClient(BaseDBClient):
"""Get daily usage breakdown for an organization with pricing."""
async with self.async_session() as session:
# Get user timezone if user_id is provided
# Get org timezone preference first, then fall back to legacy user config.
user_timezone = "UTC" # Default timezone
pref_result = await session.execute(
select(OrganizationConfigurationModel).where(
OrganizationConfigurationModel.organization_id == organization_id,
OrganizationConfigurationModel.key.in_(
[
OrganizationConfigurationKey.ORGANIZATION_PREFERENCES.value,
OrganizationConfigurationKey.MODEL_CONFIGURATION_PREFERENCES.value,
]
),
)
)
pref_rows = pref_result.scalars().all()
pref_by_key = {pref.key: pref for pref in pref_rows}
pref_obj = pref_by_key.get(
OrganizationConfigurationKey.ORGANIZATION_PREFERENCES.value
) or pref_by_key.get(
OrganizationConfigurationKey.MODEL_CONFIGURATION_PREFERENCES.value
)
if pref_obj and pref_obj.value:
user_timezone = pref_obj.value.get("timezone") or user_timezone
if user_id:
config_result = await session.execute(
select(UserConfigurationModel).where(
@ -450,11 +348,11 @@ class OrganizationUsageClient(BaseDBClient):
)
config_obj = config_result.scalar_one_or_none()
if config_obj and config_obj.configuration:
user_config = UserConfiguration.model_validate(
effective_config = EffectiveAIModelConfiguration.model_validate(
config_obj.configuration
)
if user_config.timezone:
user_timezone = user_config.timezone
if effective_config.timezone and user_timezone == "UTC":
user_timezone = effective_config.timezone
# Validate timezone string
try:
@ -473,7 +371,7 @@ class OrganizationUsageClient(BaseDBClient):
select(
date_expr.label("date"),
func.sum(
WorkflowRunModel.cost_info["call_duration_seconds"].as_float()
WorkflowRunModel.usage_info["call_duration_seconds"].as_float()
).label("total_seconds"),
func.count(WorkflowRunModel.id).label("call_count"),
)
@ -522,83 +420,11 @@ class OrganizationUsageClient(BaseDBClient):
"currency": "USD",
}
async def update_organization_quota(
self,
organization_id: int,
quota_type: str,
quota_dograh_tokens: int,
quota_reset_day: Optional[int] = None,
quota_start_date: Optional[datetime] = None,
) -> OrganizationModel:
"""Update organization quota settings."""
async with self.async_session() as session:
result = await session.execute(
select(OrganizationModel).where(OrganizationModel.id == organization_id)
)
org = result.scalar_one()
org.quota_type = quota_type
org.quota_dograh_tokens = quota_dograh_tokens
org.quota_enabled = True
if quota_type == "monthly" and quota_reset_day:
org.quota_reset_day = quota_reset_day
elif quota_type == "annual" and quota_start_date:
org.quota_start_date = quota_start_date
await session.commit()
await session.refresh(org)
return org
def _calculate_current_period(
self, org: OrganizationModel
) -> tuple[datetime, datetime]:
"""Calculate the current billing period based on organization settings."""
def _calculate_current_period(self) -> tuple[datetime, datetime]:
"""Calculate the current calendar-month reporting period."""
now = datetime.now(timezone.utc)
if org.quota_type == "monthly":
# Find the start of the current billing month
reset_day = org.quota_reset_day
# Handle month boundaries
if now.day >= reset_day:
period_start = now.replace(
day=reset_day, hour=0, minute=0, second=0, microsecond=0
)
else:
# Previous month
period_start = (now - relativedelta(months=1)).replace(
day=reset_day, hour=0, minute=0, second=0, microsecond=0
)
# End is one month later minus 1 second
period_end = (
period_start + relativedelta(months=1) - relativedelta(seconds=1)
)
else: # annual
if not org.quota_start_date:
# Default to calendar year
period_start = now.replace(
month=1, day=1, hour=0, minute=0, second=0, microsecond=0
)
period_end = (
period_start + relativedelta(years=1) - relativedelta(seconds=1)
)
else:
# Find current annual period
start_date = org.quota_start_date.replace(tzinfo=timezone.utc)
years_diff = now.year - start_date.year
# Adjust for whether we've passed the anniversary
if now.month < start_date.month or (
now.month == start_date.month and now.day < start_date.day
):
years_diff -= 1
period_start = start_date + relativedelta(years=years_diff)
period_end = (
period_start + relativedelta(years=1) - relativedelta(seconds=1)
)
period_start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
period_end = period_start + relativedelta(months=1) - relativedelta(seconds=1)
return period_start, period_end

View file

@ -8,7 +8,7 @@ from sqlalchemy.future import select
from api.db.base_client import BaseDBClient
from api.db.models import UserConfigurationModel, UserModel
from api.schemas.user_configuration import UserConfiguration
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
class UserClient(BaseDBClient):
@ -65,7 +65,9 @@ class UserClient(BaseDBClient):
)
return result.scalars().first()
async def get_user_configurations(self, user_id: int) -> UserConfiguration:
async def get_user_configurations(
self, user_id: int
) -> EffectiveAIModelConfiguration:
async with self.async_session() as session:
result = await session.execute(
select(UserConfigurationModel).where(
@ -74,10 +76,10 @@ class UserClient(BaseDBClient):
)
configuration_obj = result.scalars().first()
if not configuration_obj:
return UserConfiguration()
return EffectiveAIModelConfiguration()
try:
return UserConfiguration.model_validate(
return EffectiveAIModelConfiguration.model_validate(
{
**configuration_obj.configuration,
"last_validated_at": configuration_obj.last_validated_at,
@ -90,11 +92,11 @@ class UserClient(BaseDBClient):
f"Failed to validate user configuration for user {user_id}: {e}. "
"Returning default configuration."
)
return UserConfiguration()
return EffectiveAIModelConfiguration()
async def update_user_configuration(
self, user_id: int, configuration: UserConfiguration
) -> UserConfiguration:
self, user_id: int, configuration: EffectiveAIModelConfiguration
) -> EffectiveAIModelConfiguration:
async with self.async_session() as session:
result = await session.execute(
select(UserConfigurationModel).where(
@ -115,7 +117,9 @@ class UserClient(BaseDBClient):
await session.rollback()
raise e
await session.refresh(configuration_obj)
return UserConfiguration.model_validate(configuration_obj.configuration)
return EffectiveAIModelConfiguration.model_validate(
configuration_obj.configuration
)
async def update_user_configuration_last_validated_at(self, user_id: int) -> None:
async with self.async_session() as session:

View file

@ -16,6 +16,7 @@ from api.db.models import (
)
from api.enums import CallType, StorageBackend
from api.schemas.workflow import WorkflowRunResponseSchema
from api.services.workflow.run_usage_response import format_public_cost_info
class WorkflowRunClient(BaseDBClient):
@ -312,26 +313,9 @@ class WorkflowRunClient(BaseDBClient):
"is_completed": run.is_completed,
"recording_url": run.recording_url,
"transcript_url": run.transcript_url,
"cost_info": {
"dograh_token_usage": (
run.cost_info.get("dograh_token_usage")
if run.cost_info
and "dograh_token_usage" in run.cost_info
else round(
float(run.cost_info.get("total_cost_usd", 0)) * 100,
2,
)
if run.cost_info and "total_cost_usd" in run.cost_info
else 0
),
"call_duration_seconds": int(
round(run.cost_info.get("call_duration_seconds") or 0)
)
if run.cost_info
else None,
}
if run.cost_info
else None,
"cost_info": format_public_cost_info(
run.cost_info, run.usage_info
),
"definition_id": run.definition_id,
"initial_context": run.initial_context,
"gathered_context": run.gathered_context,

View file

@ -89,6 +89,11 @@ class OrganizationConfigurationKey(Enum):
LANGFUSE_CREDENTIALS = (
"LANGFUSE_CREDENTIALS" # Org-level Langfuse tracing credentials
)
MODEL_CONFIGURATION_V2 = (
"MODEL_CONFIGURATION_V2" # Org-level v2 AI model configuration
)
ORGANIZATION_PREFERENCES = "ORGANIZATION_PREFERENCES" # Org-level defaults such as timezone/test call number
MODEL_CONFIGURATION_PREFERENCES = "MODEL_CONFIGURATION_PREFERENCES" # Deprecated; read fallback for old org preferences
class WorkflowStatus(Enum):

View file

@ -1,5 +1,5 @@
[project]
name = "dograh-api"
version = "1.33.0"
version = "1.34.0"
description = "Backend API for Dograh voice AI platform"
requires-python = ">=3.13,<3.14"

View file

@ -22,7 +22,7 @@ from starlette.websockets import WebSocketDisconnect
from api.db import db_client
from api.enums import CallType, WorkflowRunState
from api.services.quota_service import check_dograh_quota_by_user_id
from api.services.quota_service import authorize_workflow_run_start
from api.services.telephony import registry as telephony_registry
router = APIRouter(prefix="/agent-stream")
@ -67,19 +67,6 @@ async def agent_stream_websocket(
await websocket.close(code=1008, reason="Workflow not found")
return
quota_result = await check_dograh_quota_by_user_id(
workflow.user_id, workflow_id=workflow.id
)
if not quota_result.has_quota:
logger.warning(
f"agent-stream quota exceeded for user {workflow.user_id}: "
f"{quota_result.error_message}"
)
await websocket.close(
code=1008, reason=quota_result.error_message or "Quota exceeded"
)
return
numeric_suffix = int(str(uuid.uuid4()).replace("-", "")[:8], 16) % 100000000
workflow_run_name = f"WR-AGS-{numeric_suffix:08d}"
call_id = params.get("callId") or params.get("CallSid")
@ -108,6 +95,20 @@ async def agent_stream_websocket(
set_current_run_id(workflow_run.id)
set_current_org_id(workflow.organization_id)
quota_result = await authorize_workflow_run_start(
workflow_id=workflow.id,
workflow_run_id=workflow_run.id,
)
if not quota_result.has_quota:
logger.warning(
f"agent-stream quota exceeded for user {workflow.user_id}: "
f"{quota_result.error_message}"
)
await websocket.close(
code=1008, reason=quota_result.error_message or "Quota exceeded"
)
return
await db_client.update_workflow_run(
run_id=workflow_run.id, state=WorkflowRunState.RUNNING.value
)

View file

@ -3,9 +3,12 @@ from loguru import logger
from api.db import db_client
from api.db.models import UserModel
from api.enums import PostHogEvent
from api.enums import OrganizationConfigurationKey, PostHogEvent
from api.schemas.auth import AuthResponse, LoginRequest, SignupRequest, UserResponse
from api.services.auth.depends import create_user_configuration_with_mps_key, get_user
from api.services.configuration.ai_model_configuration import (
convert_legacy_ai_model_configuration_to_v2,
)
from api.services.posthog_client import capture_event
from api.utils.auth import create_jwt_token, hash_password, verify_password
@ -47,6 +50,12 @@ async def signup(request: SignupRequest):
)
if mps_config:
await db_client.update_user_configuration(user.id, mps_config)
model_config_v2 = convert_legacy_ai_model_configuration_to_v2(mps_config)
await db_client.upsert_configuration(
organization.id,
OrganizationConfigurationKey.MODEL_CONFIGURATION_V2.value,
model_config_v2.model_dump(mode="json", exclude_none=True),
)
except Exception:
logger.warning(
"Failed to create default configuration for OSS user", exc_info=True

View file

@ -18,7 +18,7 @@ from api.services.auth.depends import get_user
from api.services.campaign.runner import campaign_runner_service
from api.services.campaign.source_sync import CampaignSourceSyncService
from api.services.campaign.source_sync_factory import get_sync_service
from api.services.quota_service import check_dograh_quota
from api.services.quota_service import authorize_workflow_run_start
from api.services.reports import generate_campaign_report_csv
from api.services.storage import storage_fs
@ -550,7 +550,10 @@ async def start_campaign(
# Check Dograh quota before starting campaign (apply per-workflow
# model_overrides so we evaluate the keys this campaign will use).
quota_result = await check_dograh_quota(user, workflow_id=campaign.workflow_id)
quota_result = await authorize_workflow_run_start(
workflow_id=campaign.workflow_id,
actor_user=user,
)
if not quota_result.has_quota:
raise HTTPException(status_code=402, detail=quota_result.error_message)
@ -872,7 +875,10 @@ async def resume_campaign(
# Check Dograh quota before resuming campaign (apply per-workflow
# model_overrides so we evaluate the keys this campaign will use).
quota_result = await check_dograh_quota(user, workflow_id=campaign.workflow_id)
quota_result = await authorize_workflow_run_start(
workflow_id=campaign.workflow_id,
actor_user=user,
)
if not quota_result.has_quota:
raise HTTPException(status_code=402, detail=quota_result.error_message)

View file

@ -369,6 +369,10 @@ async def search_chunks(
try:
# Import here to avoid circular dependency
from api.services.configuration.ai_model_configuration import (
apply_managed_embeddings_base_url,
get_resolved_ai_model_configuration,
)
from api.services.configuration.registry import ServiceProviders
from api.services.gen_ai import (
AzureOpenAIEmbeddingService,
@ -376,20 +380,29 @@ async def search_chunks(
)
# Try to get user's embeddings configuration
user_config = await db_client.get_user_configurations(user.id)
resolved_config = await get_resolved_ai_model_configuration(
user_id=user.id,
organization_id=user.selected_organization_id,
)
effective_config = resolved_config.effective
embeddings_api_key = None
embeddings_model = None
embeddings_provider = None
embeddings_base_url = None
embeddings_endpoint = None
embeddings_api_version = None
if user_config.embeddings:
embeddings_api_key = user_config.embeddings.api_key
embeddings_model = user_config.embeddings.model
embeddings_provider = getattr(user_config.embeddings, "provider", None)
embeddings_endpoint = getattr(user_config.embeddings, "endpoint", None)
if effective_config.embeddings:
embeddings_api_key = effective_config.embeddings.api_key
embeddings_model = effective_config.embeddings.model
embeddings_provider = getattr(effective_config.embeddings, "provider", None)
embeddings_endpoint = getattr(effective_config.embeddings, "endpoint", None)
embeddings_base_url = apply_managed_embeddings_base_url(
provider=embeddings_provider,
base_url=getattr(effective_config.embeddings, "base_url", None),
)
embeddings_api_version = getattr(
user_config.embeddings, "api_version", None
effective_config.embeddings, "api_version", None
)
# Initialize embedding service based on provider
@ -406,9 +419,7 @@ async def search_chunks(
db_client=db_client,
api_key=embeddings_api_key,
model_id=embeddings_model or "text-embedding-3-small",
base_url=getattr(user_config.embeddings, "base_url", None)
if user_config.embeddings
else None,
base_url=embeddings_base_url,
)
# Perform search

View file

@ -1,15 +1,27 @@
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends, HTTPException, Query
from loguru import logger
from pydantic import BaseModel
from sqlalchemy.exc import IntegrityError
from api.constants import DEFAULT_CAMPAIGN_RETRY_CONFIG, DEFAULT_ORG_CONCURRENCY_LIMIT
from api.constants import (
DEFAULT_CAMPAIGN_RETRY_CONFIG,
DEFAULT_ORG_CONCURRENCY_LIMIT,
DEPLOYMENT_MODE,
)
from api.db import db_client
from api.db.models import UserModel
from api.db.telephony_configuration_client import TelephonyConfigurationInUseError
from api.enums import OrganizationConfigurationKey, PostHogEvent
from api.schemas.ai_model_configuration import (
DOGRAH_DEFAULT_LANGUAGE,
DOGRAH_DEFAULT_VOICE,
DOGRAH_SPEED_OPTIONS,
OrganizationAIModelConfigurationResponse,
OrganizationAIModelConfigurationV2,
)
from api.schemas.organization_preferences import OrganizationPreferences
from api.schemas.telephony_config import (
TelephonyConfigRequest,
TelephonyConfigurationCreateRequest,
@ -26,8 +38,36 @@ from api.schemas.telephony_phone_number import (
PhoneNumberUpdateRequest,
ProviderSyncStatus,
)
from api.services.auth.depends import get_user
from api.services.configuration.masking import is_mask_of, mask_key
from api.services.auth.depends import get_user, get_user_with_selected_organization
from api.services.configuration.ai_model_configuration import (
check_for_masked_keys_in_ai_model_configuration_v2,
compile_ai_model_configuration_v2,
convert_legacy_ai_model_configuration_to_v2,
get_organization_ai_model_configuration_v2,
get_resolved_ai_model_configuration,
mask_ai_model_configuration_v2,
merge_ai_model_configuration_v2_secrets,
migrate_workflow_model_configurations_to_v2,
upsert_organization_ai_model_configuration_v2,
)
from api.services.configuration.check_validity import UserConfigurationValidator
from api.services.configuration.defaults import DEFAULT_SERVICE_PROVIDERS
from api.services.configuration.masking import is_mask_of, mask_key, mask_user_config
from api.services.configuration.registry import (
DOGRAH_STT_LANGUAGES,
REGISTRY,
ServiceProviders,
ServiceType,
)
from api.services.mps_billing import ensure_hosted_mps_billing_account_v2
from api.services.organization_context import (
OrganizationContextResponse,
get_organization_context,
)
from api.services.organization_preferences import (
get_organization_preferences,
upsert_organization_preferences,
)
from api.services.posthog_client import capture_event
from api.services.telephony import registry as telephony_registry
from api.services.telephony.factory import get_telephony_provider_by_id
@ -98,6 +138,12 @@ class TelephonyConfigWarningsResponse(BaseModel):
telnyx_missing_webhook_public_key_count: int
@router.get("/context", response_model=OrganizationContextResponse)
async def get_current_organization_context(user: UserModel = Depends(get_user)):
"""Return organization-scoped configuration signals owned by Dograh."""
return await get_organization_context(user)
@router.get(
"/telephony-providers/metadata",
response_model=TelephonyProvidersMetadataResponse,
@ -159,6 +205,239 @@ async def get_telephony_config_warnings(user: UserModel = Depends(get_user)):
)
# ---------------------------------------------------------------------------
# AI model configurations v2
# ---------------------------------------------------------------------------
def _byok_provider_schemas(service_type: ServiceType) -> dict[str, dict]:
return {
provider: model_cls.model_json_schema()
for provider, model_cls in REGISTRY[service_type].items()
if provider != ServiceProviders.DOGRAH.value
}
async def _model_configuration_v2_response(
*,
user: UserModel,
configuration: OrganizationAIModelConfigurationV2 | None = None,
) -> OrganizationAIModelConfigurationResponse:
resolved = await get_resolved_ai_model_configuration(
user_id=user.id,
organization_id=user.selected_organization_id,
)
raw_configuration = (
configuration
if configuration is not None
else resolved.organization_configuration
)
return OrganizationAIModelConfigurationResponse(
configuration=mask_ai_model_configuration_v2(raw_configuration),
effective_configuration=mask_user_config(resolved.effective),
source=resolved.source,
)
@router.get("/model-configurations/v2/defaults")
async def get_model_configuration_v2_defaults(
user: UserModel = Depends(get_user_with_selected_organization),
):
byok_default_providers = {
service: provider
for service, provider in DEFAULT_SERVICE_PROVIDERS.items()
if provider != ServiceProviders.DOGRAH.value
}
return {
"dograh": {
"voices": [DOGRAH_DEFAULT_VOICE],
"speeds": list(DOGRAH_SPEED_OPTIONS),
"languages": DOGRAH_STT_LANGUAGES,
"defaults": {
"voice": DOGRAH_DEFAULT_VOICE,
"speed": 1.0,
"language": DOGRAH_DEFAULT_LANGUAGE,
},
},
"byok": {
"pipeline": {
"llm": _byok_provider_schemas(ServiceType.LLM),
"tts": _byok_provider_schemas(ServiceType.TTS),
"stt": _byok_provider_schemas(ServiceType.STT),
"embeddings": _byok_provider_schemas(ServiceType.EMBEDDINGS),
"default_providers": byok_default_providers,
},
"realtime": {
"realtime": _byok_provider_schemas(ServiceType.REALTIME),
"llm": _byok_provider_schemas(ServiceType.LLM),
"embeddings": _byok_provider_schemas(ServiceType.EMBEDDINGS),
"default_providers": byok_default_providers,
},
},
}
@router.get(
"/model-configurations/v2",
response_model=OrganizationAIModelConfigurationResponse,
)
async def get_model_configuration_v2(
user: UserModel = Depends(get_user_with_selected_organization),
):
return await _model_configuration_v2_response(user=user)
@router.put(
"/model-configurations/v2",
response_model=OrganizationAIModelConfigurationResponse,
)
async def save_model_configuration_v2(
request: OrganizationAIModelConfigurationV2,
user: UserModel = Depends(get_user_with_selected_organization),
):
organization_id = user.selected_organization_id
existing = await get_organization_ai_model_configuration_v2(organization_id)
configuration = merge_ai_model_configuration_v2_secrets(request, existing)
try:
check_for_masked_keys_in_ai_model_configuration_v2(configuration)
effective = compile_ai_model_configuration_v2(configuration)
await UserConfigurationValidator().validate(
effective,
organization_id=organization_id,
created_by=user.provider_id,
)
except ValueError as exc:
raise HTTPException(status_code=422, detail=exc.args[0])
await upsert_organization_ai_model_configuration_v2(
organization_id,
configuration,
)
return await _model_configuration_v2_response(
user=user,
configuration=configuration,
)
@router.get("/model-configurations/v2/migration-preview")
async def preview_model_configuration_v2_migration(
user: UserModel = Depends(get_user_with_selected_organization),
):
legacy = await db_client.get_user_configurations(user.id)
try:
configuration = convert_legacy_ai_model_configuration_to_v2(legacy)
except ValueError as exc:
raise HTTPException(status_code=422, detail=str(exc))
return {
"configuration": mask_ai_model_configuration_v2(configuration),
"effective_configuration": mask_user_config(
compile_ai_model_configuration_v2(configuration)
),
}
@router.post(
"/model-configurations/v2/migrate",
response_model=OrganizationAIModelConfigurationResponse,
)
async def migrate_model_configuration_v2(
force: bool = Query(default=False),
user: UserModel = Depends(get_user_with_selected_organization),
):
organization_id = user.selected_organization_id
existing = await get_organization_ai_model_configuration_v2(organization_id)
if existing is not None and not force:
raise HTTPException(
status_code=409,
detail="Organization already has a v2 model configuration",
)
legacy = await db_client.get_user_configurations(user.id)
try:
configuration = convert_legacy_ai_model_configuration_to_v2(legacy)
effective = compile_ai_model_configuration_v2(configuration)
await UserConfigurationValidator().validate(
effective,
organization_id=organization_id,
created_by=user.provider_id,
)
except ValueError as exc:
raise HTTPException(status_code=422, detail=exc.args[0])
if DEPLOYMENT_MODE != "oss":
try:
await ensure_hosted_mps_billing_account_v2(
organization_id,
created_by=str(user.provider_id),
)
except Exception as exc:
logger.error(
"Failed to initialize MPS billing v2 account for organization {}: {}",
organization_id,
exc,
)
raise HTTPException(
status_code=502,
detail="Failed to initialize MPS billing v2 account",
)
await upsert_organization_ai_model_configuration_v2(
organization_id,
configuration,
)
await migrate_workflow_model_configurations_to_v2(
organization_id=organization_id,
fallback_user_config=legacy,
)
return await _model_configuration_v2_response(
user=user,
configuration=configuration,
)
@router.get("/preferences", response_model=OrganizationPreferences)
async def get_preferences(
user: UserModel = Depends(get_user_with_selected_organization),
):
organization_id = user.selected_organization_id
return await get_organization_preferences(organization_id)
@router.put("/preferences", response_model=OrganizationPreferences)
async def save_preferences(
request: OrganizationPreferences,
user: UserModel = Depends(get_user_with_selected_organization),
):
organization_id = user.selected_organization_id
return await upsert_organization_preferences(
organization_id,
request,
)
@router.get(
"/model-configurations/preferences",
response_model=OrganizationPreferences,
include_in_schema=False,
)
async def get_model_configuration_preferences_legacy(
user: UserModel = Depends(get_user_with_selected_organization),
):
return await get_preferences(user=user)
@router.put(
"/model-configurations/preferences",
response_model=OrganizationPreferences,
include_in_schema=False,
)
async def save_model_configuration_preferences_legacy(
request: OrganizationPreferences,
user: UserModel = Depends(get_user_with_selected_organization),
):
return await save_preferences(request=request, user=user)
def preserve_masked_fields(provider: str, request_dict: dict, existing: dict):
"""If the client re-submitted a masked sensitive field, restore the original."""
for field_name in _sensitive_fields(provider):

View file

@ -1,16 +1,16 @@
import json
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Literal, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import StreamingResponse
from loguru import logger
from pydantic import BaseModel, Field
from api.constants import DEPLOYMENT_MODE
from api.constants import DEPLOYMENT_MODE, UI_APP_URL
from api.db import db_client
from api.db.models import UserModel
from api.services.auth.depends import get_user
from api.services.auth.depends import get_user, get_user_with_selected_organization
from api.services.mps_service_key_client import mps_service_key_client
from api.services.reports import generate_usage_runs_report_csv
from api.utils.artifacts import artifact_url
@ -22,14 +22,8 @@ class CurrentUsageResponse(BaseModel):
period_start: str
period_end: str
used_dograh_tokens: float
quota_dograh_tokens: int
percentage_used: float
next_refresh_date: str
quota_enabled: bool
total_duration_seconds: int
# New USD fields
used_amount_usd: Optional[float] = None
quota_amount_usd: Optional[float] = None
currency: Optional[str] = None
price_per_second_usd: Optional[float] = None
@ -40,6 +34,61 @@ class MPSCreditsResponse(BaseModel):
total_quota: float
class MPSCreditPurchaseUrlResponse(BaseModel):
checkout_url: str
class MPSBillingAccountResponse(BaseModel):
id: int
organization_id: int
billing_mode: str
cached_balance_credits: float
currency: str
class MPSCreditLedgerEntryResponse(BaseModel):
id: int
entry_type: str
origin: Optional[str] = None
credits_delta: float
balance_after: float
amount_minor: Optional[int] = None
amount_currency: Optional[str] = None
payment_order_id: Optional[int] = None
metric_code: Optional[str] = None
correlation_id: Optional[str] = None
aggregation_key: Optional[str] = None
usage_event_id: Optional[int] = None
workflow_run_id: Optional[int] = None
workflow_id: Optional[int] = None
billable_quantity: Optional[float] = None
quantity_unit: Optional[str] = None
metadata: Dict[str, Any] = Field(default_factory=dict)
created_at: str
class MPSBillingCreditsResponse(BaseModel):
billing_version: Literal["legacy", "v2"]
total_credits_used: float = 0.0
remaining_credits: float = 0.0
total_quota: float = 0.0
account: Optional[MPSBillingAccountResponse] = None
ledger_entries: List[MPSCreditLedgerEntryResponse] = Field(default_factory=list)
total_count: int = 0
page: int = 1
limit: int = 50
total_pages: int = 0
def _optional_int(value: Any) -> Optional[int]:
if value is None:
return None
try:
return int(value)
except (TypeError, ValueError):
return None
class WorkflowRunUsageResponse(BaseModel):
id: int
workflow_id: int
@ -97,7 +146,7 @@ class DailyUsageBreakdownResponse(BaseModel):
@router.get("/usage/current-period", response_model=CurrentUsageResponse)
async def get_current_period_usage(user: UserModel = Depends(get_user)):
"""Get current billing period usage for the user's organization."""
"""Get current reporting-period usage for the user's organization."""
if not user.selected_organization_id:
raise HTTPException(status_code=400, detail="No organization selected")
@ -142,6 +191,202 @@ async def get_mps_credits(user: UserModel = Depends(get_user)):
raise HTTPException(status_code=500, detail=str(e))
async def _get_mps_billing_account_status(
user: UserModel, organization_id: int
) -> Optional[dict]:
return await mps_service_key_client.get_billing_account_status(
organization_id=organization_id,
created_by=str(user.provider_id),
)
def _is_mps_billing_v2(account: Optional[dict]) -> bool:
return bool(account and account.get("billing_mode") == "v2")
async def _legacy_mps_credits_response(user: UserModel) -> MPSBillingCreditsResponse:
if DEPLOYMENT_MODE == "oss":
usage = await mps_service_key_client.get_usage_by_created_by(
str(user.provider_id)
)
else:
if not user.selected_organization_id:
raise HTTPException(status_code=400, detail="No organization selected")
usage = await mps_service_key_client.get_usage_by_organization(
user.selected_organization_id
)
total_used = float(usage.get("total_credits_used", 0.0))
total_remaining = float(usage.get("remaining_credits", 0.0))
return MPSBillingCreditsResponse(
billing_version="legacy",
total_credits_used=total_used,
remaining_credits=total_remaining,
total_quota=total_used + total_remaining,
)
@router.get("/billing/credits", response_model=MPSBillingCreditsResponse)
async def get_billing_credits(
page: int = Query(1, ge=1),
limit: int = Query(50, ge=1, le=100),
user: UserModel = Depends(get_user),
):
"""Return legacy MPS credits or paginated v2 billing ledger details for the org."""
try:
if DEPLOYMENT_MODE == "oss" or not user.selected_organization_id:
return await _legacy_mps_credits_response(user)
organization_id = user.selected_organization_id
account_status = await _get_mps_billing_account_status(user, organization_id)
if not _is_mps_billing_v2(account_status):
return await _legacy_mps_credits_response(user)
ledger = await mps_service_key_client.get_credit_ledger(
organization_id=organization_id,
page=page,
limit=limit,
created_by=str(user.provider_id),
)
account = ledger.get("account") or {}
ledger_entries = ledger.get("ledger_entries") or []
total_count = int(ledger.get("total_count") or len(ledger_entries))
response_limit = int(ledger.get("limit") or limit)
total_pages = int(
ledger.get("total_pages")
or ((total_count + response_limit - 1) // response_limit)
)
workflow_ids_by_run_id: dict[int, int] = {}
workflow_run_ids = {
workflow_run_id
for entry in ledger_entries
if (workflow_run_id := _optional_int(entry.get("workflow_run_id")))
is not None
}
for workflow_run_id in workflow_run_ids:
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
if (
workflow_run
and workflow_run.workflow
and workflow_run.workflow.organization_id == organization_id
):
workflow_ids_by_run_id[workflow_run_id] = workflow_run.workflow_id
balance = float(account.get("cached_balance_credits") or 0.0)
total_debits = sum(
abs(float(entry.get("credits_delta") or 0.0))
for entry in ledger_entries
if float(entry.get("credits_delta") or 0.0) < 0
)
if ledger.get("total_debits_credits") is not None:
total_debits = float(ledger["total_debits_credits"])
return MPSBillingCreditsResponse(
billing_version="v2",
total_credits_used=total_debits,
remaining_credits=balance,
total_quota=balance + total_debits,
account=MPSBillingAccountResponse(
id=int(account["id"]),
organization_id=int(account["organization_id"]),
billing_mode=str(account["billing_mode"]),
cached_balance_credits=balance,
currency=str(account.get("currency") or "USD"),
),
ledger_entries=[
MPSCreditLedgerEntryResponse(
id=int(entry["id"]),
entry_type=str(entry["entry_type"]),
origin=entry.get("origin"),
credits_delta=float(entry.get("credits_delta") or 0.0),
balance_after=float(entry.get("balance_after") or 0.0),
amount_minor=entry.get("amount_minor"),
amount_currency=entry.get("amount_currency"),
payment_order_id=entry.get("payment_order_id"),
metric_code=entry.get("metric_code"),
correlation_id=entry.get("correlation_id"),
aggregation_key=entry.get("aggregation_key"),
usage_event_id=_optional_int(entry.get("usage_event_id")),
workflow_run_id=_optional_int(entry.get("workflow_run_id")),
workflow_id=workflow_ids_by_run_id.get(
_optional_int(entry.get("workflow_run_id"))
)
if entry.get("workflow_run_id") is not None
else None,
billable_quantity=float(entry["billable_quantity"])
if entry.get("billable_quantity") is not None
else None,
quantity_unit=entry.get("quantity_unit"),
metadata=entry.get("metadata") or {},
created_at=str(entry["created_at"]),
)
for entry in ledger_entries
],
total_count=total_count,
page=int(ledger.get("page") or page),
limit=response_limit,
total_pages=total_pages,
)
except HTTPException:
raise
except Exception as exc:
logger.error(f"Failed to fetch billing credits: {exc}")
raise HTTPException(status_code=500, detail=str(exc))
@router.post(
"/usage/mps-credits/purchase-url",
response_model=MPSCreditPurchaseUrlResponse,
)
async def create_mps_credit_purchase_url(
user: UserModel = Depends(get_user_with_selected_organization),
):
"""Create a checkout URL for organizations using Dograh-managed MPS v2."""
if DEPLOYMENT_MODE == "oss":
raise HTTPException(
status_code=404,
detail="Credit purchases are not available in OSS mode",
)
organization_id = user.selected_organization_id
assert organization_id is not None
account_status = await _get_mps_billing_account_status(user, organization_id)
if not _is_mps_billing_v2(account_status):
raise HTTPException(
status_code=403,
detail=(
"Credit purchases are available only for organizations using billing v2"
),
)
try:
session = await mps_service_key_client.create_credit_purchase_url(
organization_id=organization_id,
created_by=str(user.provider_id),
return_url=f"{UI_APP_URL.rstrip('/')}/billing",
billing_details={
"source": "dograh_billing",
"dograh_user_id": str(user.id),
"dograh_provider_id": str(user.provider_id),
},
)
except Exception as exc:
logger.error(f"Failed to create MPS credit purchase URL: {exc}")
raise HTTPException(
status_code=502,
detail="Failed to create credit purchase URL",
)
checkout_url = session.get("checkout_url")
if not checkout_url:
logger.error(f"MPS checkout session response missing checkout_url: {session}")
raise HTTPException(
status_code=502,
detail="MPS checkout session response missing checkout_url",
)
return MPSCreditPurchaseUrlResponse(checkout_url=checkout_url)
FILTERS_DESCRIPTION = """\
JSON-encoded array of filter objects. Each object has the shape:

View file

@ -14,7 +14,7 @@ from pydantic import BaseModel
from api.db import db_client
from api.enums import TriggerState, WorkflowStatus
from api.services.quota_service import check_dograh_quota_by_user_id
from api.services.quota_service import authorize_workflow_run_start
from api.services.telephony.factory import (
get_default_telephony_provider,
get_telephony_provider_by_id,
@ -179,14 +179,6 @@ async def _execute_resolved_target(
"""Shared execution path once the target workflow has been resolved."""
execution_user_id = _get_execution_user_id(target.workflow)
# Check Dograh quota using the workflow owner's config and model overrides.
quota_result = await check_dograh_quota_by_user_id(
execution_user_id,
workflow_id=target.workflow.id,
)
if not quota_result.has_quota:
raise HTTPException(status_code=402, detail=quota_result.error_message)
# Get telephony provider — either the caller-specified config (validated
# against the workflow's org) or the org's default config.
if request.telephony_configuration_id is not None:
@ -268,6 +260,15 @@ async def _execute_resolved_target(
f"to phone number {request.phone_number}"
)
# Check Dograh quota after the run exists so hosted v2 can mint and store
# the MPS correlation id before the provider starts the call.
quota_result = await authorize_workflow_run_start(
workflow_id=target.workflow.id,
workflow_run_id=workflow_run.id,
)
if not quota_result.has_quota:
raise HTTPException(status_code=402, detail=quota_result.error_message)
# 9. Construct webhook URL for telephony provider callback
backend_endpoint, _ = await get_backend_endpoints()
webhook_endpoint = provider.WEBHOOK_ENDPOINT

View file

@ -7,6 +7,7 @@ They handle CORS, domain validation, and session management for embedded workflo
import secrets
from datetime import UTC, datetime, timedelta
from typing import Optional
from urllib.parse import urlsplit
from fastapi import (
APIRouter,
@ -16,6 +17,8 @@ from fastapi import (
)
from loguru import logger
from pydantic import BaseModel
from starlette.datastructures import Headers
from starlette.types import ASGIApp, Receive, Scope, Send
from api.db import db_client
from api.enums import WorkflowRunMode
@ -27,6 +30,9 @@ from api.routes.turn_credentials import (
router = APIRouter(prefix="/public/embed")
EMBED_CORS_ALLOW_HEADERS = "Content-Type, Origin"
EMBED_CORS_MAX_AGE = "86400"
class InitEmbedRequest(BaseModel):
"""Request model for initializing an embed session"""
@ -70,11 +76,9 @@ def validate_origin(origin: str, allowed_domains: list) -> bool:
# If no domains specified, allow all origins
return True
# Extract domain from origin (remove protocol)
if "://" in origin:
domain = origin.split("://")[1].split("/")[0].split(":")[0]
else:
domain = origin
domain, origin_port = _parse_origin_host_port(origin)
if not domain:
return False
# Normalize domain for www matching
def normalize_www(d: str) -> tuple[str, str]:
@ -87,16 +91,23 @@ def validate_origin(origin: str, allowed_domains: list) -> bool:
domain_variants = normalize_www(domain)
for allowed in allowed_domains:
allowed = str(allowed).strip().lower()
if allowed == "*":
return True
elif allowed.startswith("*."):
allowed_domain, allowed_port = _parse_origin_host_port(allowed)
if not allowed_domain:
continue
if allowed_port is not None and allowed_port != origin_port:
continue
if allowed_domain.startswith("*."):
# Wildcard subdomain matching
base_domain = allowed[2:]
base_domain = allowed_domain[2:]
if domain == base_domain or domain.endswith("." + base_domain):
return True
else:
# Check both www and non-www versions
allowed_variants = normalize_www(allowed)
allowed_variants = normalize_www(allowed_domain)
# If any variant of domain matches any variant of allowed, it's valid
if any(
dv in allowed_variants or av in domain_variants
@ -108,6 +119,24 @@ def validate_origin(origin: str, allowed_domains: list) -> bool:
return False
def _parse_origin_host_port(value: str) -> tuple[str, str | None]:
candidate = value.strip().lower()
if not candidate:
return "", None
if "://" not in candidate and not candidate.startswith("//"):
candidate = f"//{candidate}"
parsed = urlsplit(candidate)
try:
parsed_port = parsed.port
except ValueError:
parsed_port = None
port = str(parsed_port) if parsed_port is not None else None
return (parsed.hostname or "").rstrip("."), port
def generate_session_token() -> str:
"""Generate a cryptographically secure session token"""
return f"emb_session_{secrets.token_urlsafe(32)}"
@ -121,8 +150,120 @@ def get_request_origin(request: Request) -> str:
return origin
def _cors_response(origin: str, methods: str) -> Response:
return Response(
headers={
"Access-Control-Allow-Origin": origin,
"Access-Control-Allow-Methods": methods,
"Access-Control-Allow-Headers": EMBED_CORS_ALLOW_HEADERS,
"Access-Control-Max-Age": EMBED_CORS_MAX_AGE,
"Vary": "Origin",
}
)
def _allow_embed_origin(response: Response, origin: str) -> None:
response.headers["Access-Control-Allow-Origin"] = origin
vary = response.headers.get("Vary")
if not vary:
response.headers["Vary"] = "Origin"
return
vary_values = {value.strip().lower() for value in vary.split(",")}
if "origin" not in vary_values:
response.headers["Vary"] = f"{vary}, Origin"
async def _config_preflight_response(token: str, origin: str) -> Response:
embed_token = await db_client.get_embed_token_by_token(token)
if not embed_token or not embed_token.is_active:
return Response(status_code=403)
if not validate_origin(origin, embed_token.allowed_domains or []):
return Response(status_code=403)
return _cors_response(origin, "GET, OPTIONS")
async def _turn_credentials_preflight_response(
session_token: str, origin: str
) -> Response:
embed_session = await db_client.get_embed_session_by_token(session_token)
if not embed_session:
return Response(status_code=403)
if embed_session.expires_at and embed_session.expires_at < datetime.now(UTC):
return Response(status_code=403)
embed_token = await db_client.get_embed_token_by_id(embed_session.embed_token_id)
if not embed_token:
return Response(status_code=403)
if not validate_origin(origin, embed_token.allowed_domains or []):
return Response(status_code=403)
return _cors_response(origin, "GET, OPTIONS")
async def build_public_embed_preflight_response(
path: str, origin: str, requested_method: str, api_prefix: str = "/api/v1"
) -> Response | None:
"""Handle embed preflights before global CORSMiddleware rejects external sites."""
public_embed_prefix = f"{api_prefix.rstrip('/')}/public/embed"
if path == f"{public_embed_prefix}/init":
if requested_method.upper() != "POST":
return Response(status_code=405)
return _cors_response(origin, "POST, OPTIONS")
config_prefix = f"{public_embed_prefix}/config/"
if path.startswith(config_prefix):
if requested_method.upper() != "GET":
return Response(status_code=405)
token = path[len(config_prefix) :].split("/", 1)[0]
return await _config_preflight_response(token, origin)
turn_credentials_prefix = f"{public_embed_prefix}/turn-credentials/"
if path.startswith(turn_credentials_prefix):
if requested_method.upper() != "GET":
return Response(status_code=405)
session_token = path[len(turn_credentials_prefix) :].split("/", 1)[0]
return await _turn_credentials_preflight_response(session_token, origin)
return None
class PublicEmbedCORSMiddleware:
"""Allow token-gated embed CORS before global SaaS CORS rejects preflights."""
def __init__(self, app: ASGIApp, api_prefix: str = "/api/v1"):
self.app = app
self.api_prefix = api_prefix
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http" or scope.get("method") != "OPTIONS":
await self.app(scope, receive, send)
return
headers = Headers(scope=scope)
origin = headers.get("origin")
requested_method = headers.get("access-control-request-method")
if origin and requested_method:
response = await build_public_embed_preflight_response(
scope.get("path", ""), origin, requested_method, self.api_prefix
)
if response is not None:
await response(scope, receive, send)
return
await self.app(scope, receive, send)
@router.post("/init", response_model=InitEmbedResponse)
async def initialize_embed_session(request: Request, init_request: InitEmbedRequest):
async def initialize_embed_session(
request: Request, init_request: InitEmbedRequest, response: Response
):
"""Initialize an embed session with token validation and domain checking.
This endpoint:
@ -158,6 +299,9 @@ async def initialize_embed_session(request: Request, init_request: InitEmbedRequ
)
raise HTTPException(status_code=403, detail=f"Domain not allowed: {origin}")
if origin:
_allow_embed_origin(response, origin)
# Create workflow run
try:
workflow_run = await db_client.create_workflow_run(
@ -204,8 +348,19 @@ async def initialize_embed_session(request: Request, init_request: InitEmbedRequ
)
@router.options("/config/{token}")
async def options_embed_config(token: str, request: Request):
"""Fallback OPTIONS handler for the embed config endpoint.
Browser preflights include Access-Control-Request-Method and are handled by
PublicEmbedCORSMiddleware before global CORS. This keeps non-conformant
OPTIONS requests on the same validation path.
"""
return await _config_preflight_response(token, request.headers.get("origin", ""))
@router.get("/config/{token}", response_model=EmbedConfigResponse)
async def get_embed_config(token: str, request: Request):
async def get_embed_config(token: str, request: Request, response: Response):
"""Get embed configuration without creating a session.
This endpoint is used to fetch widget configuration for display purposes
@ -226,6 +381,11 @@ async def get_embed_config(token: str, request: Request):
if not validate_origin(origin, embed_token.allowed_domains or []):
raise HTTPException(status_code=403, detail=f"Domain not allowed: {origin}")
# Set CORS header explicitly; the global CORSMiddleware covers only
# first-party origins; this endpoint is fetched by external embed sites.
if origin:
_allow_embed_origin(response, origin)
# Extract settings with defaults
settings = embed_token.settings or {}
@ -243,24 +403,20 @@ async def get_embed_config(token: str, request: Request):
@router.options("/init")
async def options_init(request: Request):
"""Handle CORS preflight for init endpoint"""
"""Fallback OPTIONS handler for init endpoint."""
# Browser preflights are handled by PublicEmbedCORSMiddleware before global CORS.
# For init endpoint, we need to check the token in the request body
# But OPTIONS requests don't have body, so we'll be permissive
# The actual validation happens in the POST request
origin = request.headers.get("origin", "*")
return Response(
headers={
"Access-Control-Allow-Origin": origin,
"Access-Control-Allow-Methods": "POST, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type, Origin",
"Access-Control-Max-Age": "86400",
}
)
return _cors_response(origin, "POST, OPTIONS")
@router.get("/turn-credentials/{session_token}", response_model=TurnCredentialsResponse)
async def get_public_turn_credentials(session_token: str, request: Request):
async def get_public_turn_credentials(
session_token: str, request: Request, response: Response
):
"""Get TURN credentials for an embed session.
This endpoint allows embedded widgets to obtain TURN server credentials
@ -295,6 +451,9 @@ async def get_public_turn_credentials(session_token: str, request: Request):
)
raise HTTPException(status_code=403, detail=f"Domain not allowed: {origin}")
if origin:
_allow_embed_origin(response, origin)
# Check if TURN is configured
if not TURN_SECRET:
raise HTTPException(
@ -316,63 +475,8 @@ async def get_public_turn_credentials(session_token: str, request: Request):
@router.options("/turn-credentials/{session_token}")
async def options_turn_credentials(request: Request, session_token: str):
"""Handle CORS preflight for TURN credentials endpoint"""
origin = request.headers.get("origin", "*")
# Try to validate the session token and get allowed domains
allowed_origin = origin
try:
embed_session = await db_client.get_embed_session_by_token(session_token)
if embed_session:
embed_token = await db_client.get_embed_token_by_id(
embed_session.embed_token_id
)
if embed_token:
# Check if origin is in allowed domains (empty means allow all)
if validate_origin(origin, embed_token.allowed_domains or []):
allowed_origin = origin
else:
allowed_origin = ""
except Exception:
# On error, be permissive for OPTIONS
pass
return Response(
headers={
"Access-Control-Allow-Origin": allowed_origin,
"Access-Control-Allow-Methods": "GET, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type",
"Access-Control-Max-Age": "86400",
}
)
@router.options("/config/{token}")
async def options_config(request: Request, token: str):
"""Handle CORS preflight for config endpoint"""
# Get origin header
origin = request.headers.get("origin", "*")
# Try to validate the token and get allowed domains
allowed_origin = origin
try:
embed_token = await db_client.get_embed_token_by_token(token)
if embed_token and embed_token.is_active:
# Check if origin is in allowed domains
if validate_origin(origin, embed_token.allowed_domains or []):
allowed_origin = origin
else:
# If not allowed, don't include the origin
allowed_origin = ""
except Exception:
# On error, be permissive for OPTIONS
pass
return Response(
headers={
"Access-Control-Allow-Origin": allowed_origin,
"Access-Control-Allow-Methods": "GET, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type",
"Access-Control-Max-Age": "86400",
}
"""Fallback OPTIONS handler for TURN credentials endpoint."""
# Browser preflights are handled by PublicEmbedCORSMiddleware before global CORS.
return await _turn_credentials_preflight_response(
session_token, request.headers.get("origin", "")
)

View file

@ -25,7 +25,7 @@ from api.enums import CallType, WorkflowRunState
from api.errors.telephony_errors import TelephonyError
from api.sdk_expose import sdk_expose
from api.services.auth.depends import get_user
from api.services.quota_service import check_dograh_quota_by_user_id
from api.services.quota_service import authorize_workflow_run_start
from api.services.telephony.call_transfer_manager import get_call_transfer_manager
from api.services.telephony.factory import (
get_all_telephony_providers,
@ -53,7 +53,7 @@ class InitiateCallRequest(BaseModel):
workflow_run_id: int | None = None
phone_number: str | None = None
# Optional explicit telephony config to use for the test call. If omitted,
# falls back to the user's per-user default (when set), then the org default.
# falls back to the org default.
telephony_configuration_id: int | None = None
# Optional caller-ID phone number to dial out from. Must belong to the
# resolved telephony configuration; otherwise the provider picks one.
@ -82,7 +82,12 @@ async def initiate_call(
"""Initiate a call using the configured telephony provider from web browser. This is
supposed to be a test call method for the draft version of the agent."""
user_configuration = await db_client.get_user_configurations(user.id)
from api.services.organization_preferences import get_organization_preferences
preferences = await get_organization_preferences(
user.selected_organization_id,
db=db_client,
)
# Resolve which telephony config to use: explicit request value, otherwise
# the org's default outbound config.
@ -116,13 +121,12 @@ async def initiate_call(
detail="telephony_not_configured",
)
phone_number = request.phone_number or user_configuration.test_phone_number
phone_number = request.phone_number or preferences.test_phone_number
if not phone_number:
raise HTTPException(
status_code=400,
detail="Phone number must be provided in request or set in user "
"configuration",
detail="Phone number must be provided in request or set in organization preferences",
)
workflow = await db_client.get_workflow(
@ -132,14 +136,6 @@ async def initiate_call(
raise HTTPException(status_code=404, detail="Workflow not found")
execution_user_id = _get_execution_user_id(workflow)
# Check Dograh quota before initiating the call (apply per-workflow
# model_overrides so the keys we will actually use are the ones checked).
quota_result = await check_dograh_quota_by_user_id(
execution_user_id, workflow_id=workflow.id
)
if not quota_result.has_quota:
raise HTTPException(status_code=402, detail=quota_result.error_message)
# Determine the workflow run mode based on provider type
workflow_run_mode = provider.PROVIDER_NAME
@ -182,6 +178,16 @@ async def initiate_call(
)
workflow_run_name = workflow_run.name
# Check Dograh quota after the run exists so hosted v2 can mint and store
# the MPS correlation id before initiating the call.
quota_result = await authorize_workflow_run_start(
workflow_id=workflow.id,
workflow_run_id=workflow_run_id,
actor_user=user,
)
if not quota_result.has_quota:
raise HTTPException(status_code=402, detail=quota_result.error_message)
# Construct webhook URL based on provider type
backend_endpoint, _ = await get_backend_endpoints()
@ -735,19 +741,8 @@ async def handle_inbound_run(request: Request):
TelephonyError.SIGNATURE_VALIDATION_FAILED
)
# 4. Quota check (use the workflow's model_overrides if set).
quota_result = await check_dograh_quota_by_user_id(
user_id, workflow_id=workflow_id
)
if not quota_result.has_quota:
logger.warning(
f"User {user_id} has exceeded quota: {quota_result.error_message}"
)
return provider_class.generate_validation_error_response(
TelephonyError.QUOTA_EXCEEDED
)
# 5. Create workflow run + return provider-shaped response.
# 5. Create workflow run + authorize quota before returning provider
# stream instructions.
workflow_run_id = await _create_inbound_workflow_run(
workflow_id,
user_id,
@ -756,6 +751,17 @@ async def handle_inbound_run(request: Request):
telephony_configuration_id=telephony_configuration_id,
from_phone_number_id=phone_row.id,
)
quota_result = await authorize_workflow_run_start(
workflow_id=workflow_id,
workflow_run_id=workflow_run_id,
)
if not quota_result.has_quota:
logger.warning(
f"User {user_id} has exceeded quota: {quota_result.error_message}"
)
return provider_class.generate_validation_error_response(
TelephonyError.QUOTA_EXCEEDED
)
backend_endpoint, wss_backend_endpoint = await get_backend_endpoints()
websocket_url = (
@ -870,20 +876,8 @@ async def handle_inbound_telephony(
logger.error(f"Request validation failed: {error_type}")
return provider_class.generate_validation_error_response(error_type)
# Check quota before processing (apply per-workflow model_overrides).
# Create workflow run.
user_id = workflow_context["user_id"]
quota_result = await check_dograh_quota_by_user_id(
user_id, workflow_id=workflow_id
)
if not quota_result.has_quota:
logger.warning(
f"User {user_id} has exceeded quota for inbound calls: {quota_result.error_message}"
)
return provider_class.generate_validation_error_response(
TelephonyError.QUOTA_EXCEEDED
)
# Create workflow run
workflow_run_id = await _create_inbound_workflow_run(
workflow_id,
workflow_context["user_id"],
@ -892,6 +886,17 @@ async def handle_inbound_telephony(
telephony_configuration_id=workflow_context["telephony_configuration_id"],
from_phone_number_id=workflow_context.get("from_phone_number_id"),
)
quota_result = await authorize_workflow_run_start(
workflow_id=workflow_id,
workflow_run_id=workflow_run_id,
)
if not quota_result.has_quota:
logger.warning(
f"User {user_id} has exceeded quota for inbound calls: {quota_result.error_message}"
)
return provider_class.generate_validation_error_response(
TelephonyError.QUOTA_EXCEEDED
)
# Generate response URLs
backend_endpoint, wss_backend_endpoint = await get_backend_endpoints()

View file

@ -10,6 +10,9 @@ from api.db.models import (
UserModel,
)
from api.services.auth.depends import get_user
from api.services.configuration.ai_model_configuration import (
get_resolved_ai_model_configuration,
)
from api.services.configuration.check_validity import (
APIKeyStatusResponse,
UserConfigurationValidator,
@ -19,6 +22,10 @@ from api.services.configuration.masking import check_for_masked_keys, mask_user_
from api.services.configuration.merge import merge_user_configurations
from api.services.configuration.registry import REGISTRY, ServiceType
from api.services.mps_service_key_client import mps_service_key_client
from api.services.organization_preferences import (
get_organization_preferences,
upsert_organization_preferences,
)
router = APIRouter(prefix="/user")
@ -94,8 +101,17 @@ class UserConfigurationRequestResponseSchema(BaseModel):
async def get_user_configurations(
user: UserModel = Depends(get_user),
) -> UserConfigurationRequestResponseSchema:
user_configurations = await db_client.get_user_configurations(user.id)
masked_config = mask_user_config(user_configurations)
resolved_config = await get_resolved_ai_model_configuration(
user_id=user.id,
organization_id=user.selected_organization_id,
)
masked_config = mask_user_config(resolved_config.effective)
if user.selected_organization_id:
preferences = await get_organization_preferences(user.selected_organization_id)
if preferences.test_phone_number is not None:
masked_config["test_phone_number"] = preferences.test_phone_number
if preferences.timezone is not None:
masked_config["timezone"] = preferences.timezone
# Add organization pricing info if available
if user.selected_organization_id:
@ -121,34 +137,61 @@ async def update_user_configurations(
# Remove organization_pricing from incoming dict as it's read-only
incoming_dict.pop("organization_pricing", None)
preferences_update = {
key: incoming_dict.pop(key)
for key in ("test_phone_number", "timezone")
if key in incoming_dict
}
# Merge via helper
try:
user_configurations = merge_user_configurations(existing_config, incoming_dict)
except ValidationError as e:
raise HTTPException(status_code=422, detail=str(e))
if incoming_dict:
# Merge via helper
try:
user_configurations = merge_user_configurations(
existing_config, incoming_dict
)
except ValidationError as e:
raise HTTPException(status_code=422, detail=str(e))
try:
check_for_masked_keys(user_configurations)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
try:
check_for_masked_keys(user_configurations)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
try:
validator = UserConfigurationValidator()
await validator.validate(
user_configurations,
organization_id=user.selected_organization_id,
created_by=user.provider_id,
try:
validator = UserConfigurationValidator()
await validator.validate(
user_configurations,
organization_id=user.selected_organization_id,
created_by=user.provider_id,
)
except ValueError as e:
raise HTTPException(status_code=422, detail=e.args[0])
user_configurations = await db_client.update_user_configuration(
user.id, user_configurations
)
except ValueError as e:
raise HTTPException(status_code=422, detail=e.args[0])
else:
user_configurations = existing_config
user_configurations = await db_client.update_user_configuration(
user.id, user_configurations
)
if user.selected_organization_id and preferences_update:
preferences = await get_organization_preferences(user.selected_organization_id)
if "test_phone_number" in preferences_update:
preferences.test_phone_number = preferences_update["test_phone_number"]
if "timezone" in preferences_update:
preferences.timezone = preferences_update["timezone"]
await upsert_organization_preferences(
user.selected_organization_id,
preferences,
)
# Return masked version of updated config
masked_config = mask_user_config(user_configurations)
if user.selected_organization_id:
preferences = await get_organization_preferences(user.selected_organization_id)
if preferences.test_phone_number is not None:
masked_config["test_phone_number"] = preferences.test_phone_number
if preferences.timezone is not None:
masked_config["timezone"] = preferences.timezone
# Add organization pricing info if available
if user.selected_organization_id:
@ -168,7 +211,11 @@ async def validate_user_configurations(
validity_ttl_seconds: int = Query(default=60, ge=0, le=86400),
user: UserModel = Depends(get_user),
) -> APIKeyStatusResponse:
configurations = await db_client.get_user_configurations(user.id)
resolved_config = await get_resolved_ai_model_configuration(
user_id=user.id,
organization_id=user.selected_organization_id,
)
configurations = resolved_config.effective
if (
configurations.last_validated_at

View file

@ -45,7 +45,7 @@ from api.services.pipecat.ws_sender_registry import (
register_ws_sender,
unregister_ws_sender,
)
from api.services.quota_service import check_dograh_quota
from api.services.quota_service import authorize_workflow_run_start
router = APIRouter(prefix="/ws")
@ -329,7 +329,11 @@ class SignalingManager:
# Check Dograh quota before initiating the call (apply per-workflow
# model_overrides so we evaluate the keys this workflow will use).
quota_result = await check_dograh_quota(user, workflow_id=workflow_id)
quota_result = await authorize_workflow_run_start(
workflow_id=workflow_id,
workflow_run_id=workflow_run_id,
actor_user=user,
)
if not quota_result.has_quota:
# Send error response for quota issues
await ws.send_json(

View file

@ -16,9 +16,18 @@ from api.db.agent_trigger_client import TriggerPathConflictError
from api.db.models import UserModel
from api.db.workflow_template_client import WorkflowTemplateClient
from api.enums import CallType, PostHogEvent, StorageBackend
from api.schemas.ai_model_configuration import OrganizationAIModelConfigurationV2
from api.schemas.workflow import WorkflowRunResponseSchema
from api.sdk_expose import sdk_expose
from api.services.auth.depends import get_user
from api.services.configuration.ai_model_configuration import (
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY,
check_for_masked_keys_in_ai_model_configuration_v2,
compile_ai_model_configuration_v2,
convert_legacy_ai_model_configuration_to_v2,
get_resolved_ai_model_configuration,
merge_ai_model_configuration_v2_secrets,
)
from api.services.configuration.check_validity import UserConfigurationValidator
from api.services.configuration.masking import (
mask_workflow_configurations,
@ -32,12 +41,15 @@ from api.services.configuration.resolve import (
)
from api.services.mps_service_key_client import mps_service_key_client
from api.services.posthog_client import capture_event
from api.services.pricing.run_usage_response import format_public_usage_info
from api.services.reports import generate_workflow_report_csv
from api.services.storage import storage_fs
from api.services.workflow.dto import ReactFlowDTO, sanitize_workflow_definition
from api.services.workflow.duplicate import duplicate_workflow
from api.services.workflow.errors import ItemKind, WorkflowError
from api.services.workflow.run_usage_response import (
format_public_cost_info,
format_public_usage_info,
)
from api.services.workflow.trigger_paths import (
TriggerPathIssue,
ensure_trigger_paths,
@ -955,12 +967,74 @@ async def update_workflow(
existing_def,
)
# Validate model_overrides: resolve onto global config, then
# run the same validator used by the user-configurations endpoint.
# Also stamp the current global API key into the override so the override
# remains functional if the global config later switches to a different provider.
# Validate model overrides. v2 uses a complete workflow-level model
# configuration; legacy v1 uses partial service overlays.
workflow_configurations = request.workflow_configurations
if workflow_configurations and workflow_configurations.get("model_overrides"):
if workflow_configurations and workflow_configurations.get(
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY
):
existing_workflow = await db_client.get_workflow(
workflow_id, organization_id=user.selected_organization_id
)
if existing_workflow is None:
raise HTTPException(
status_code=404, detail=f"Workflow with id {workflow_id} not found"
)
existing_draft = await db_client.get_draft_version(workflow_id)
existing_configs = (
existing_draft.workflow_configurations
if existing_draft
else existing_workflow.released_definition.workflow_configurations
)
existing_v2_override = (existing_configs or {}).get(
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY
)
try:
incoming_v2_override = (
OrganizationAIModelConfigurationV2.model_validate(
workflow_configurations[
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY
]
)
)
existing_v2_override_config = (
OrganizationAIModelConfigurationV2.model_validate(
existing_v2_override
)
if existing_v2_override
else None
)
v2_override = merge_ai_model_configuration_v2_secrets(
incoming_v2_override,
existing_v2_override_config,
)
if existing_v2_override_config is None:
resolved_config = await get_resolved_ai_model_configuration(
user_id=user.id,
organization_id=user.selected_organization_id,
)
v2_override = merge_ai_model_configuration_v2_secrets(
v2_override,
resolved_config.organization_configuration,
)
check_for_masked_keys_in_ai_model_configuration_v2(v2_override)
effective = compile_ai_model_configuration_v2(v2_override)
await UserConfigurationValidator().validate(
effective,
organization_id=user.selected_organization_id,
created_by=user.provider_id,
)
except (ValidationError, ValueError) as e:
raise HTTPException(status_code=422, detail=str(e))
workflow_configurations = {
**workflow_configurations,
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY: v2_override.model_dump(
mode="json",
exclude_none=True,
),
}
workflow_configurations.pop("model_overrides", None)
elif workflow_configurations and workflow_configurations.get("model_overrides"):
existing_workflow = await db_client.get_workflow(
workflow_id, organization_id=user.selected_organization_id
)
@ -978,24 +1052,48 @@ async def update_workflow(
workflow_configurations,
existing_configs,
)
user_config = await db_client.get_user_configurations(user.id)
resolved_config = await get_resolved_ai_model_configuration(
user_id=user.id,
organization_id=user.selected_organization_id,
)
effective_config = resolved_config.effective
try:
enriched_overrides = enrich_overrides_with_api_keys(
workflow_configurations["model_overrides"],
user_config,
effective_config,
)
effective = resolve_effective_config(user_config, enriched_overrides)
await UserConfigurationValidator().validate(
effective,
organization_id=user.selected_organization_id,
created_by=user.provider_id,
effective = resolve_effective_config(
effective_config, enriched_overrides
)
if resolved_config.source == "organization_v2":
v2_override = convert_legacy_ai_model_configuration_to_v2(effective)
await UserConfigurationValidator().validate(
compile_ai_model_configuration_v2(v2_override),
organization_id=user.selected_organization_id,
created_by=user.provider_id,
)
else:
await UserConfigurationValidator().validate(
effective,
organization_id=user.selected_organization_id,
created_by=user.provider_id,
)
except ValueError as e:
raise HTTPException(status_code=422, detail=str(e))
workflow_configurations = {
**workflow_configurations,
"model_overrides": enriched_overrides,
}
if resolved_config.source == "organization_v2":
workflow_configurations = {
**workflow_configurations,
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY: v2_override.model_dump(
mode="json",
exclude_none=True,
),
}
workflow_configurations.pop("model_overrides", None)
else:
workflow_configurations = {
**workflow_configurations,
"model_overrides": enriched_overrides,
}
# Reject upfront if any new trigger path collides with another
# workflow's trigger — keeps the workflow record from
@ -1171,22 +1269,7 @@ async def get_workflow_run(
"transcript_public_url": artifact_url(public_access_token, "transcript"),
"recording_public_url": artifact_url(public_access_token, "recording"),
"public_access_token": public_access_token,
"cost_info": {
"dograh_token_usage": (
run.cost_info.get("dograh_token_usage")
if run.cost_info and "dograh_token_usage" in run.cost_info
else round(float(run.cost_info.get("total_cost_usd", 0)) * 100, 2)
if run.cost_info and "total_cost_usd" in run.cost_info
else 0
),
"call_duration_seconds": int(
round(run.cost_info.get("call_duration_seconds"))
)
if run.cost_info and run.cost_info.get("call_duration_seconds") is not None
else None,
}
if run.cost_info
else None,
"cost_info": format_public_cost_info(run.cost_info, run.usage_info),
"usage_info": format_public_usage_info(run.usage_info),
"created_at": run.created_at,
"definition_id": run.definition_id,

View file

@ -9,8 +9,8 @@ from pydantic import BaseModel, Field
from api.db import db_client
from api.db.models import UserModel, WorkflowRunTextSessionModel
from api.enums import WorkflowRunMode
from api.services.auth.depends import get_user
from api.services.quota_service import check_dograh_quota
from api.services.auth.depends import get_user_with_selected_organization
from api.services.quota_service import authorize_workflow_run_start
from api.services.workflow.text_chat_session_service import (
TextChatPendingTurnLostError,
TextChatSessionExecutionError,
@ -96,14 +96,16 @@ def _revision_conflict_detail(e: Any) -> dict[str, Any]:
}
def _require_selected_organization_id(user: UserModel) -> int:
if user.selected_organization_id is None:
raise HTTPException(status_code=403, detail="Organization context is required")
return user.selected_organization_id
async def _ensure_text_chat_quota(user: UserModel, workflow_id: int) -> None:
quota_result = await check_dograh_quota(user, workflow_id=workflow_id)
async def _ensure_text_chat_quota(
user: UserModel,
workflow_id: int,
workflow_run_id: int,
) -> None:
quota_result = await authorize_workflow_run_start(
workflow_id=workflow_id,
workflow_run_id=workflow_run_id,
actor_user=user,
)
if not quota_result.has_quota:
raise HTTPException(status_code=402, detail=quota_result.error_message)
@ -114,9 +116,8 @@ async def _load_text_session_or_404(
user: UserModel,
) -> WorkflowRunTextSessionModel:
set_current_run_id(run_id)
organization_id = _require_selected_organization_id(user)
text_session = await db_client.get_workflow_run_text_session(
run_id, organization_id=organization_id
run_id, organization_id=user.selected_organization_id
)
if not text_session or not text_session.workflow_run:
raise HTTPException(status_code=404, detail="Text chat session not found")
@ -158,11 +159,8 @@ async def _execute_pending_turn_response(
async def create_text_chat_session(
workflow_id: int,
request: CreateTextChatSessionRequest,
user: UserModel = Depends(get_user),
user: UserModel = Depends(get_user_with_selected_organization),
) -> WorkflowRunTextSessionResponse:
organization_id = _require_selected_organization_id(user)
await _ensure_text_chat_quota(user, workflow_id)
session_name = request.name or f"WR-TEXT-{uuid4().hex[:6].upper()}"
try:
workflow_run = await db_client.create_workflow_run(
@ -172,12 +170,13 @@ async def create_text_chat_session(
user_id=user.id,
initial_context=request.initial_context,
use_draft=True,
organization_id=organization_id,
organization_id=user.selected_organization_id,
)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
set_current_run_id(workflow_run.id)
await _ensure_text_chat_quota(user, workflow_id, workflow_run.id)
annotations = {
"tester": {
@ -220,7 +219,7 @@ async def create_text_chat_session(
async def get_text_chat_session(
workflow_id: int,
run_id: int,
user: UserModel = Depends(get_user),
user: UserModel = Depends(get_user_with_selected_organization),
) -> WorkflowRunTextSessionResponse:
text_session = await _load_text_session_or_404(workflow_id, run_id, user)
return _build_response(text_session)
@ -234,10 +233,10 @@ async def append_text_chat_message(
workflow_id: int,
run_id: int,
request: AppendTextChatMessageRequest,
user: UserModel = Depends(get_user),
user: UserModel = Depends(get_user_with_selected_organization),
) -> WorkflowRunTextSessionResponse:
text_session = await _load_text_session_or_404(workflow_id, run_id, user)
await _ensure_text_chat_quota(user, workflow_id)
await _ensure_text_chat_quota(user, workflow_id, run_id)
try:
text_session = await append_text_chat_user_message(
@ -264,7 +263,7 @@ async def rewind_text_chat_session(
workflow_id: int,
run_id: int,
request: RewindTextChatSessionRequest,
user: UserModel = Depends(get_user),
user: UserModel = Depends(get_user_with_selected_organization),
) -> WorkflowRunTextSessionResponse:
text_session = await _load_text_session_or_404(workflow_id, run_id, user)
try:

View file

@ -0,0 +1,198 @@
from __future__ import annotations
from datetime import datetime
from typing import Literal
from pydantic import BaseModel, Field, model_validator
from api.services.configuration.registry import (
DograhEmbeddingsConfiguration,
DograhLLMService,
DograhSTTService,
DograhTTSService,
EmbeddingsConfig,
LLMConfig,
RealtimeConfig,
ServiceProviders,
STTConfig,
TTSConfig,
)
DOGRAH_SPEED_OPTIONS: tuple[float, ...] = (0.8, 1.0, 1.2)
DOGRAH_DEFAULT_VOICE = "default"
DOGRAH_DEFAULT_LANGUAGE = "multi"
class EffectiveAIModelConfiguration(BaseModel):
llm: LLMConfig | None = None
stt: STTConfig | None = None
tts: TTSConfig | None = None
embeddings: EmbeddingsConfig | None = None
realtime: RealtimeConfig | None = None
is_realtime: bool = False
managed_service_version: int | None = None
test_phone_number: str | None = None
timezone: str | None = None
last_validated_at: datetime | None = None
# Post-signup onboarding gate: set once the user submits or skips the
# onboarding form, so it shows only once per user.
onboarding_completed_at: datetime | None = None
onboarding_skipped: bool = False
@model_validator(mode="before")
@classmethod
def strip_incomplete_realtime_when_disabled(cls, data):
"""Skip realtime validation when is_realtime is False and api_key is missing."""
if isinstance(data, dict) and not data.get("is_realtime", False):
realtime = data.get("realtime")
if isinstance(realtime, dict) and not realtime.get("api_key"):
data.pop("realtime", None)
return data
class DograhManagedAIModelConfiguration(BaseModel):
api_key: str
voice: str = DOGRAH_DEFAULT_VOICE
speed: float = Field(default=1.0)
language: str = DOGRAH_DEFAULT_LANGUAGE
@model_validator(mode="after")
def validate_speed(self):
if self.speed not in DOGRAH_SPEED_OPTIONS:
allowed = ", ".join(str(speed) for speed in DOGRAH_SPEED_OPTIONS)
raise ValueError(f"Dograh speed must be one of: {allowed}")
return self
class BYOKPipelineAIModelConfiguration(BaseModel):
llm: LLMConfig
tts: TTSConfig
stt: STTConfig
embeddings: EmbeddingsConfig | None = None
@model_validator(mode="after")
def reject_dograh_providers(self):
_reject_dograh_provider("llm", self.llm)
_reject_dograh_provider("tts", self.tts)
_reject_dograh_provider("stt", self.stt)
_reject_dograh_provider("embeddings", self.embeddings)
return self
class BYOKRealtimeAIModelConfiguration(BaseModel):
realtime: RealtimeConfig
llm: LLMConfig
embeddings: EmbeddingsConfig | None = None
@model_validator(mode="after")
def reject_dograh_providers(self):
_reject_dograh_provider("llm", self.llm)
_reject_dograh_provider("embeddings", self.embeddings)
return self
class BYOKAIModelConfiguration(BaseModel):
mode: Literal["pipeline", "realtime"]
pipeline: BYOKPipelineAIModelConfiguration | None = None
realtime: BYOKRealtimeAIModelConfiguration | None = None
@model_validator(mode="after")
def validate_selected_mode(self):
if self.mode == "pipeline" and self.pipeline is None:
raise ValueError("byok.pipeline is required when byok.mode is pipeline")
if self.mode == "realtime" and self.realtime is None:
raise ValueError("byok.realtime is required when byok.mode is realtime")
return self
class OrganizationAIModelConfigurationV2(BaseModel):
version: Literal[2] = 2
mode: Literal["dograh", "byok"]
dograh: DograhManagedAIModelConfiguration | None = None
byok: BYOKAIModelConfiguration | None = None
@model_validator(mode="after")
def validate_selected_mode(self):
if self.mode == "dograh" and self.dograh is None:
raise ValueError("dograh configuration is required when mode is dograh")
if self.mode == "byok" and self.byok is None:
raise ValueError("byok configuration is required when mode is byok")
return self
class OrganizationAIModelConfigurationResponse(BaseModel):
configuration: dict | None
effective_configuration: dict
source: Literal["organization_v2", "legacy_user_v1", "empty"]
def compile_ai_model_configuration_v2(
configuration: OrganizationAIModelConfigurationV2,
) -> EffectiveAIModelConfiguration:
if configuration.mode == "dograh":
if configuration.dograh is None:
raise ValueError("dograh configuration is required")
return _compile_dograh_configuration(configuration.dograh)
if configuration.byok is None:
raise ValueError("byok configuration is required")
if configuration.byok.mode == "pipeline":
if configuration.byok.pipeline is None:
raise ValueError("byok.pipeline is required")
pipeline = configuration.byok.pipeline
return EffectiveAIModelConfiguration(
llm=pipeline.llm,
tts=pipeline.tts,
stt=pipeline.stt,
embeddings=pipeline.embeddings,
is_realtime=False,
)
if configuration.byok.realtime is None:
raise ValueError("byok.realtime is required")
realtime = configuration.byok.realtime
return EffectiveAIModelConfiguration(
llm=realtime.llm,
realtime=realtime.realtime,
embeddings=realtime.embeddings,
is_realtime=True,
)
def _compile_dograh_configuration(
configuration: DograhManagedAIModelConfiguration,
) -> EffectiveAIModelConfiguration:
return EffectiveAIModelConfiguration(
llm=DograhLLMService(
provider=ServiceProviders.DOGRAH,
api_key=configuration.api_key,
model="default",
),
tts=DograhTTSService(
provider=ServiceProviders.DOGRAH,
api_key=configuration.api_key,
model="default",
voice=configuration.voice,
speed=configuration.speed,
),
stt=DograhSTTService(
provider=ServiceProviders.DOGRAH,
api_key=configuration.api_key,
model="default",
language=configuration.language,
),
embeddings=DograhEmbeddingsConfiguration(
provider=ServiceProviders.DOGRAH,
api_key=configuration.api_key,
model="default",
),
is_realtime=False,
managed_service_version=2,
)
def _reject_dograh_provider(section: str, service) -> None:
if service is None:
return
if getattr(service, "provider", None) == ServiceProviders.DOGRAH:
raise ValueError(f"BYOK {section} cannot use Dograh provider")

View file

@ -0,0 +1,6 @@
from pydantic import BaseModel
class OrganizationPreferences(BaseModel):
test_phone_number: str | None = None
timezone: str | None = None

View file

@ -1,37 +0,0 @@
from datetime import datetime
from pydantic import BaseModel, model_validator
from api.services.configuration.registry import (
EmbeddingsConfig,
LLMConfig,
RealtimeConfig,
STTConfig,
TTSConfig,
)
class UserConfiguration(BaseModel):
llm: LLMConfig | None = None
stt: STTConfig | None = None
tts: TTSConfig | None = None
embeddings: EmbeddingsConfig | None = None
realtime: RealtimeConfig | None = None
is_realtime: bool = False
test_phone_number: str | None = None
timezone: str | None = None
last_validated_at: datetime | None = None
# Post-signup onboarding gate: set once the user submits or skips the
# onboarding form, so it shows only once per user (server-side, cross-device).
onboarding_completed_at: datetime | None = None
onboarding_skipped: bool = False
@model_validator(mode="before")
@classmethod
def strip_incomplete_realtime_when_disabled(cls, data):
"""Skip realtime validation when is_realtime is False and api_key is missing."""
if isinstance(data, dict) and not data.get("is_realtime", False):
realtime = data.get("realtime")
if isinstance(realtime, dict) and not realtime.get("api_key"):
data.pop("realtime", None)
return data

View file

@ -1,7 +1,7 @@
from typing import Annotated, Optional
import httpx
from fastapi import Header, HTTPException, Query, WebSocket
from fastapi import Depends, Header, HTTPException, Query, WebSocket
from loguru import logger
from pydantic import ValidationError
@ -9,9 +9,10 @@ from api.constants import AUTH_PROVIDER, DOGRAH_MPS_SECRET_KEY, MPS_API_URL
from api.db import db_client
from api.db.models import UserModel
from api.enums import PostHogEvent
from api.schemas.user_configuration import UserConfiguration
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
from api.services.auth.stack_auth import stackauth
from api.services.configuration.registry import ServiceProviders
from api.services.mps_billing import ensure_hosted_mps_billing_account_v2
from api.services.posthog_client import capture_event
from api.utils.auth import decode_jwt_token
@ -110,6 +111,19 @@ async def get_user(
# This prevents race conditions where multiple concurrent requests
# might try to create configurations
if org_was_created:
try:
await ensure_hosted_mps_billing_account_v2(
organization.id,
created_by=str(stack_user["id"]),
)
except Exception:
logger.warning(
"Failed to initialize hosted MPS billing account for "
"organization {}",
organization.id,
exc_info=True,
)
existing_cfg = await db_client.get_user_configurations(user_model.id)
if not (existing_cfg.llm or existing_cfg.tts or existing_cfg.stt):
mps_config = await create_user_configuration_with_mps_key(
@ -119,6 +133,19 @@ async def get_user(
await db_client.update_user_configuration(
user_model.id, mps_config
)
from api.enums import OrganizationConfigurationKey
from api.services.configuration.ai_model_configuration import (
convert_legacy_ai_model_configuration_to_v2,
)
model_config_v2 = convert_legacy_ai_model_configuration_to_v2(
mps_config
)
await db_client.upsert_configuration(
organization.id,
OrganizationConfigurationKey.MODEL_CONFIGURATION_V2.value,
model_config_v2.model_dump(mode="json", exclude_none=True),
)
except Exception as exc:
raise HTTPException(
@ -129,6 +156,14 @@ async def get_user(
return user_model
async def get_user_with_selected_organization(
user: Annotated[UserModel, Depends(get_user)],
) -> UserModel:
if not user.selected_organization_id:
raise HTTPException(status_code=400, detail="No organization selected")
return user
async def _handle_oss_auth(authorization: str | None) -> UserModel:
"""
Handle authentication for OSS deployment mode.
@ -192,7 +227,7 @@ async def _handle_api_key_auth(api_key: str) -> UserModel:
async def create_user_configuration_with_mps_key(
user_id: int, organization_id: int, user_provider_id: str
) -> Optional[UserConfiguration]:
) -> Optional[EffectiveAIModelConfiguration]:
"""Create user configuration using MPS service key.
Args:
@ -201,7 +236,7 @@ async def create_user_configuration_with_mps_key(
user_provider_id: The user's provider ID (for created_by field)
Returns:
UserConfiguration with MPS-provided API keys or None if failed
EffectiveAIModelConfiguration with MPS-provided API keys or None if failed
"""
async with httpx.AsyncClient() as client:
@ -211,7 +246,7 @@ async def create_user_configuration_with_mps_key(
response = await client.post(
f"{MPS_API_URL}/api/v1/service-keys/",
json={
"name": f"Default Dograh Model Service Key",
"name": "Default Dograh Model Service Key",
"description": "Auto-generated key for OSS user",
"expires_in_days": 7, # Short-lived for OSS
"created_by": user_provider_id,
@ -229,7 +264,7 @@ async def create_user_configuration_with_mps_key(
response = await client.post(
f"{MPS_API_URL}/api/v1/service-keys/",
json={
"name": f"Default Dograh Model Service Key",
"name": "Default Dograh Model Service Key",
"description": f"Auto-generated key for organization {organization_id}",
"organization_id": organization_id,
"expires_in_days": 90, # Longer-lived for authenticated users
@ -264,8 +299,8 @@ async def create_user_configuration_with_mps_key(
"model": "default",
},
}
user_config = UserConfiguration(**configuration)
return user_config
effective_config = EffectiveAIModelConfiguration(**configuration)
return effective_config
else:
logger.warning(
f"Failed to get MPS service key: {response.status_code} - {response.text}"

View file

@ -15,6 +15,7 @@ from api.services.campaign.errors import (
PhoneNumberPoolExhaustedError,
)
from api.services.campaign.rate_limiter import rate_limiter
from api.services.quota_service import authorize_workflow_run_start
from api.utils.common import get_backend_endpoints
if TYPE_CHECKING:
@ -339,6 +340,41 @@ class CampaignCallDispatcher:
},
)
quota_result = await authorize_workflow_run_start(
workflow_id=campaign.workflow_id,
workflow_run_id=workflow_run.id,
)
if not quota_result.has_quota:
error_message = quota_result.error_message or "Quota exceeded"
logger.warning(
f"Campaign {campaign.id} quota check failed for workflow run "
f"{workflow_run.id}: {error_message}"
)
await db_client.update_workflow_run(
run_id=workflow_run.id,
is_completed=True,
state=WorkflowRunState.COMPLETED.value,
gathered_context={"error": error_message},
)
mapping = await rate_limiter.get_workflow_slot_mapping(workflow_run.id)
if mapping:
org_id, mapped_slot_id = mapping
await rate_limiter.release_concurrent_slot(org_id, mapped_slot_id)
await rate_limiter.delete_workflow_slot_mapping(workflow_run.id)
from_number_mapping = await rate_limiter.get_workflow_from_number_mapping(
workflow_run.id
)
if from_number_mapping:
fn_org_id, fn_number, fn_tcid = from_number_mapping
await rate_limiter.release_from_number(
fn_org_id, fn_number, telephony_configuration_id=fn_tcid
)
await rate_limiter.delete_workflow_from_number_mapping(workflow_run.id)
raise ValueError(error_message)
# Initiate call via telephony provider
try:
# Construct webhook URL with parameters

View file

@ -0,0 +1,484 @@
from __future__ import annotations
import copy
from dataclasses import dataclass
from typing import Literal
from loguru import logger
from pydantic import ValidationError
from sqlalchemy import select, update
from sqlalchemy.orm import selectinload
from api.constants import MPS_API_URL
from api.db import db_client
from api.db.models import WorkflowDefinitionModel, WorkflowModel
from api.enums import OrganizationConfigurationKey
from api.schemas.ai_model_configuration import (
DOGRAH_DEFAULT_LANGUAGE,
DOGRAH_DEFAULT_VOICE,
DOGRAH_SPEED_OPTIONS,
BYOKAIModelConfiguration,
BYOKPipelineAIModelConfiguration,
BYOKRealtimeAIModelConfiguration,
DograhManagedAIModelConfiguration,
EffectiveAIModelConfiguration,
OrganizationAIModelConfigurationV2,
compile_ai_model_configuration_v2,
)
from api.services.configuration.masking import (
SERVICE_SECRET_FIELDS,
contains_masked_key,
mask_key,
resolve_masked_api_keys,
)
from api.services.configuration.registry import ServiceProviders
from api.services.configuration.resolve import resolve_effective_config
AIModelConfigurationSource = Literal["organization_v2", "legacy_user_v1", "empty"]
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY = "model_configuration_v2_override"
@dataclass
class ResolvedAIModelConfiguration:
effective: EffectiveAIModelConfiguration
source: AIModelConfigurationSource
organization_configuration: OrganizationAIModelConfigurationV2 | None = None
@dataclass
class WorkflowAIModelConfigurationMigrationResult:
workflow_count: int = 0
definition_count: int = 0
workflow_ids: list[int] | None = None
async def get_resolved_ai_model_configuration(
*,
user_id: int | None,
organization_id: int | None,
) -> ResolvedAIModelConfiguration:
organization_configuration = await get_organization_ai_model_configuration_v2(
organization_id
)
if organization_configuration is not None:
return ResolvedAIModelConfiguration(
effective=compile_ai_model_configuration_v2(organization_configuration),
source="organization_v2",
organization_configuration=organization_configuration,
)
if user_id is None:
return ResolvedAIModelConfiguration(
effective=EffectiveAIModelConfiguration(),
source="empty",
)
legacy = await db_client.get_user_configurations(user_id)
return ResolvedAIModelConfiguration(
effective=legacy,
source="legacy_user_v1" if _has_model_services(legacy) else "empty",
)
async def get_effective_ai_model_configuration_for_workflow(
*,
user_id: int | None,
organization_id: int | None,
workflow_configurations: dict | None,
) -> EffectiveAIModelConfiguration:
workflow_configurations = workflow_configurations or {}
v2_override = workflow_configurations.get(
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY
)
if v2_override:
return compile_ai_model_configuration_v2(
OrganizationAIModelConfigurationV2.model_validate(v2_override)
)
resolved_config = await get_resolved_ai_model_configuration(
user_id=user_id,
organization_id=organization_id,
)
return resolve_effective_config(
resolved_config.effective,
workflow_configurations.get("model_overrides"),
)
async def get_organization_ai_model_configuration_v2(
organization_id: int | None,
) -> OrganizationAIModelConfigurationV2 | None:
if organization_id is None:
return None
row = await db_client.get_configuration(
organization_id,
OrganizationConfigurationKey.MODEL_CONFIGURATION_V2.value,
)
if row is None or not row.value:
return None
try:
return OrganizationAIModelConfigurationV2.model_validate(row.value)
except ValidationError as exc:
logger.warning(
"Invalid org AI model configuration v2 for organization "
f"{organization_id}: {exc}. Falling back to legacy configuration."
)
return None
async def upsert_organization_ai_model_configuration_v2(
organization_id: int,
configuration: OrganizationAIModelConfigurationV2,
) -> OrganizationAIModelConfigurationV2:
await db_client.upsert_configuration(
organization_id,
OrganizationConfigurationKey.MODEL_CONFIGURATION_V2.value,
configuration.model_dump(mode="json", exclude_none=True),
)
return configuration
async def migrate_workflow_model_configurations_to_v2(
*,
organization_id: int,
fallback_user_config: EffectiveAIModelConfiguration,
) -> WorkflowAIModelConfigurationMigrationResult:
workflows = await _list_workflows_for_model_configuration_migration(organization_id)
owner_configs: dict[int, EffectiveAIModelConfiguration] = {}
workflow_updates: list[tuple[int, dict]] = []
definition_updates: list[tuple[int, dict]] = []
migrated_workflow_ids: set[int] = set()
for workflow in workflows:
base_config = fallback_user_config
if workflow.user_id is not None:
if workflow.user_id not in owner_configs:
owner_configs[
workflow.user_id
] = await db_client.get_user_configurations(workflow.user_id)
base_config = owner_configs[workflow.user_id]
workflow_configs, workflow_changed = (
migrate_workflow_configuration_model_override_to_v2(
workflow.workflow_configurations,
base_config,
)
)
if workflow_changed:
workflow_updates.append((workflow.id, workflow_configs))
migrated_workflow_ids.add(workflow.id)
for definition in workflow.definitions:
definition_configs, definition_changed = (
migrate_workflow_configuration_model_override_to_v2(
definition.workflow_configurations,
base_config,
)
)
if definition_changed:
definition_updates.append((definition.id, definition_configs))
migrated_workflow_ids.add(workflow.id)
if workflow_updates or definition_updates:
async with db_client.async_session() as session:
for workflow_id, workflow_configs in workflow_updates:
await session.execute(
update(WorkflowModel)
.where(WorkflowModel.id == workflow_id)
.values(workflow_configurations=workflow_configs)
)
for definition_id, definition_configs in definition_updates:
await session.execute(
update(WorkflowDefinitionModel)
.where(WorkflowDefinitionModel.id == definition_id)
.values(workflow_configurations=definition_configs)
)
await session.commit()
return WorkflowAIModelConfigurationMigrationResult(
workflow_count=len(migrated_workflow_ids),
definition_count=len(definition_updates),
workflow_ids=sorted(migrated_workflow_ids),
)
def migrate_workflow_configuration_model_override_to_v2(
workflow_configurations: dict | None,
base_config: EffectiveAIModelConfiguration,
) -> tuple[dict, bool]:
if not isinstance(workflow_configurations, dict):
return {}, False
migrated = copy.deepcopy(workflow_configurations)
model_overrides = migrated.get("model_overrides")
existing_v2_override = migrated.get(WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY)
if not isinstance(model_overrides, dict):
if "model_overrides" in migrated:
migrated.pop("model_overrides", None)
return migrated, True
return migrated, False
if not existing_v2_override:
effective = resolve_effective_config(base_config, model_overrides)
v2_override = convert_legacy_ai_model_configuration_to_v2(effective)
migrated[WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY] = v2_override.model_dump(
mode="json", exclude_none=True
)
migrated.pop("model_overrides", None)
return migrated, True
def merge_ai_model_configuration_v2_secrets(
incoming: OrganizationAIModelConfigurationV2,
existing: OrganizationAIModelConfigurationV2 | None,
) -> OrganizationAIModelConfigurationV2:
if existing is None:
return incoming
incoming_dict = incoming.model_dump(mode="json", exclude_none=True)
existing_dict = existing.model_dump(mode="json", exclude_none=True)
if incoming_dict.get("mode") == "dograh" and existing_dict.get("mode") == "dograh":
incoming_dograh = incoming_dict.get("dograh") or {}
existing_dograh = existing_dict.get("dograh") or {}
incoming_key = incoming_dograh.get("api_key")
existing_key = existing_dograh.get("api_key")
if incoming_key and existing_key and contains_masked_key(incoming_key):
incoming_dograh["api_key"] = resolve_masked_api_keys(
incoming_key,
existing_key,
)
if incoming_dict.get("mode") == "byok" and existing_dict.get("mode") == "byok":
_merge_byok_secret_fields(incoming_dict.get("byok"), existing_dict.get("byok"))
return OrganizationAIModelConfigurationV2.model_validate(incoming_dict)
def check_for_masked_keys_in_ai_model_configuration_v2(
configuration: OrganizationAIModelConfigurationV2,
) -> None:
data = configuration.model_dump(mode="json", exclude_none=True)
_raise_if_masked_secret(data)
def mask_ai_model_configuration_v2(
configuration: OrganizationAIModelConfigurationV2 | None,
) -> dict | None:
if configuration is None:
return None
data = configuration.model_dump(mode="json", exclude_none=True)
_mask_secret_fields(data)
return data
def convert_legacy_ai_model_configuration_to_v2(
configuration: EffectiveAIModelConfiguration,
) -> OrganizationAIModelConfigurationV2:
dograh_key = _first_dograh_api_key(configuration)
if dograh_key:
return _convert_any_dograh_legacy_configuration(configuration, dograh_key)
if configuration.is_realtime:
if configuration.realtime is None or configuration.llm is None:
raise ValueError("Realtime legacy configuration is incomplete")
return OrganizationAIModelConfigurationV2(
mode="byok",
byok=BYOKAIModelConfiguration(
mode="realtime",
realtime=BYOKRealtimeAIModelConfiguration(
realtime=configuration.realtime,
llm=configuration.llm,
embeddings=configuration.embeddings,
),
),
)
if (
configuration.llm is None
or configuration.tts is None
or configuration.stt is None
):
raise ValueError("Pipeline legacy configuration is incomplete")
return OrganizationAIModelConfigurationV2(
mode="byok",
byok=BYOKAIModelConfiguration(
mode="pipeline",
pipeline=BYOKPipelineAIModelConfiguration(
llm=configuration.llm,
tts=configuration.tts,
stt=configuration.stt,
embeddings=configuration.embeddings,
),
),
)
def dograh_embeddings_base_url() -> str:
return f"{MPS_API_URL}/api/v1/llm"
def apply_managed_embeddings_base_url(
*,
provider: str | None,
base_url: str | None,
) -> str | None:
if provider == ServiceProviders.DOGRAH.value or provider == ServiceProviders.DOGRAH:
return dograh_embeddings_base_url()
return base_url
def _merge_byok_secret_fields(incoming_byok: dict | None, existing_byok: dict | None):
if not isinstance(incoming_byok, dict) or not isinstance(existing_byok, dict):
return
incoming_mode = incoming_byok.get("mode")
existing_mode = existing_byok.get("mode")
if incoming_mode != existing_mode:
return
section_names = (
("llm", "tts", "stt", "embeddings")
if incoming_mode == "pipeline"
else ("realtime", "llm", "embeddings")
)
incoming_container = incoming_byok.get(incoming_mode)
existing_container = existing_byok.get(existing_mode)
if not isinstance(incoming_container, dict) or not isinstance(
existing_container, dict
):
return
for section_name in section_names:
incoming_section = incoming_container.get(section_name)
existing_section = existing_container.get(section_name)
if isinstance(incoming_section, dict) and isinstance(existing_section, dict):
_merge_service_secret_fields(incoming_section, existing_section)
async def _list_workflows_for_model_configuration_migration(
organization_id: int,
) -> list[WorkflowModel]:
async with db_client.async_session() as session:
result = await session.execute(
select(WorkflowModel)
.options(selectinload(WorkflowModel.definitions))
.where(WorkflowModel.organization_id == organization_id)
)
return list(result.scalars().unique().all())
def _merge_service_secret_fields(incoming: dict, existing: dict):
if (
incoming.get("provider") is not None
and existing.get("provider") is not None
and incoming.get("provider") != existing.get("provider")
):
return
for secret_field in SERVICE_SECRET_FIELDS:
if secret_field not in existing:
continue
incoming_secret = incoming.get(secret_field)
existing_secret = existing[secret_field]
if incoming_secret is None:
incoming[secret_field] = existing_secret
elif contains_masked_key(incoming_secret):
incoming[secret_field] = resolve_masked_api_keys(
incoming_secret,
existing_secret,
)
def _raise_if_masked_secret(value):
if isinstance(value, dict):
for key, nested in value.items():
if key in SERVICE_SECRET_FIELDS and contains_masked_key(nested):
raise ValueError(
f"The {key} appears to be masked. Please provide the actual "
"value, not the masked value."
)
_raise_if_masked_secret(nested)
elif isinstance(value, list):
for item in value:
_raise_if_masked_secret(item)
def _mask_secret_fields(value):
if isinstance(value, dict):
for key, nested in list(value.items()):
if key in SERVICE_SECRET_FIELDS and nested:
value[key] = _mask_secret_value(nested)
else:
_mask_secret_fields(nested)
elif isinstance(value, list):
for item in value:
_mask_secret_fields(item)
def _mask_secret_value(value):
if isinstance(value, list):
return [mask_key(item) for item in value]
return mask_key(value)
def _has_model_services(configuration: EffectiveAIModelConfiguration) -> bool:
return any(
service is not None
for service in (
configuration.llm,
configuration.tts,
configuration.stt,
configuration.embeddings,
configuration.realtime,
)
)
def _convert_any_dograh_legacy_configuration(
configuration: EffectiveAIModelConfiguration,
dograh_key: str,
) -> OrganizationAIModelConfigurationV2:
speed = getattr(configuration.tts, "speed", 1.0)
if speed not in DOGRAH_SPEED_OPTIONS:
speed = 1.0
return OrganizationAIModelConfigurationV2(
mode="dograh",
dograh=DograhManagedAIModelConfiguration(
api_key=dograh_key,
voice=getattr(configuration.tts, "voice", DOGRAH_DEFAULT_VOICE)
or DOGRAH_DEFAULT_VOICE,
speed=speed,
language=getattr(configuration.stt, "language", DOGRAH_DEFAULT_LANGUAGE)
or DOGRAH_DEFAULT_LANGUAGE,
),
)
def _first_dograh_api_key(configuration: EffectiveAIModelConfiguration) -> str | None:
for service in (
configuration.llm,
configuration.tts,
configuration.stt,
configuration.embeddings,
configuration.realtime,
):
if service is None or _provider(service) != ServiceProviders.DOGRAH:
continue
try:
return _single_api_key(service)
except ValueError:
continue
return None
def _provider(service):
return getattr(service, "provider", None)
def _single_api_key(service) -> str:
if hasattr(service, "get_all_api_keys"):
keys = service.get_all_api_keys()
if len(keys) != 1:
raise ValueError("Expected exactly one API key")
return keys[0]
key = getattr(service, "api_key", None)
if not key:
raise ValueError("Expected an API key")
return key

View file

@ -8,8 +8,8 @@ from groq import Groq
# from pyneuphonic import Neuphonic
# except ImportError:
# Neuphonic = None
from api.schemas.user_configuration import (
UserConfiguration,
from api.schemas.ai_model_configuration import (
EffectiveAIModelConfiguration,
)
from api.services.configuration.registry import ServiceConfig, ServiceProviders
from api.services.mps_service_key_client import mps_service_key_client
@ -64,7 +64,7 @@ class UserConfigurationValidator:
async def validate(
self,
configuration: UserConfiguration,
configuration: EffectiveAIModelConfiguration,
organization_id: Optional[int] = None,
created_by: Optional[str] = None,
) -> APIKeyStatusResponse:
@ -75,21 +75,21 @@ class UserConfigurationValidator:
status_list = []
status_list.extend(self._validate_service(configuration.llm, "llm"))
status_list.extend(self._validate_service(configuration.stt, "stt"))
status_list.extend(self._validate_service(configuration.tts, "tts"))
# Embeddings is optional - only validate if configured
status_list.extend(
self._validate_service(
configuration.embeddings, "embeddings", required=False
)
)
# Realtime is optional - only validate if is_realtime is enabled
if configuration.is_realtime:
status_list.extend(
self._validate_service(
configuration.realtime, "realtime", required=True
)
)
else:
status_list.extend(self._validate_service(configuration.stt, "stt"))
status_list.extend(self._validate_service(configuration.tts, "tts"))
# Embeddings is optional - only validate if configured
status_list.extend(
self._validate_service(
configuration.embeddings, "embeddings", required=False
)
)
if status_list:
raise ValueError(status_list)

View file

@ -12,7 +12,7 @@ The rules are simple:
import copy
from typing import Any, Dict, Optional
from api.schemas.user_configuration import UserConfiguration
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
from api.services.configuration.registry import ServiceConfig
from api.services.integrations import get_node_secret_fields
@ -31,7 +31,7 @@ def contains_masked_key(value: str | list[str] | None) -> bool:
return any(MASK_MARKER in k for k in keys)
def check_for_masked_keys(config: "UserConfiguration") -> None:
def check_for_masked_keys(config: "EffectiveAIModelConfiguration") -> None:
"""Raise ValueError if any service in *config* still has a masked secret."""
for field in ("llm", "tts", "stt", "embeddings", "realtime"):
service = getattr(config, field, None)
@ -111,7 +111,7 @@ def resolve_masked_api_keys(
# ---------------------------------------------------------------------------
# High-level helpers for UserConfiguration objects
# High-level helpers for EffectiveAIModelConfiguration objects
# ---------------------------------------------------------------------------
@ -129,7 +129,7 @@ def _mask_service(service_cfg: Optional[ServiceConfig]) -> Optional[Dict[str, An
return data
def mask_user_config(config: UserConfiguration) -> Dict[str, Any]:
def mask_user_config(config: EffectiveAIModelConfiguration) -> Dict[str, Any]:
"""Return a JSON-serialisable dict of *config* with every api_key masked."""
return {
@ -155,21 +155,35 @@ def mask_workflow_configurations(config: Optional[Dict]) -> Optional[Dict]:
masked = copy.deepcopy(config)
model_overrides = masked.get("model_overrides")
if not isinstance(model_overrides, dict):
return masked
if isinstance(model_overrides, dict):
for section in MODEL_OVERRIDE_FIELDS:
override = model_overrides.get(section)
if not isinstance(override, dict):
continue
for secret_field in SERVICE_SECRET_FIELDS:
raw = override.get(secret_field)
if raw:
override[secret_field] = _mask_secret_value(raw)
for section in MODEL_OVERRIDE_FIELDS:
override = model_overrides.get(section)
if not isinstance(override, dict):
continue
for secret_field in SERVICE_SECRET_FIELDS:
raw = override.get(secret_field)
if raw:
override[secret_field] = _mask_secret_value(raw)
v2_override = masked.get("model_configuration_v2_override")
if isinstance(v2_override, dict):
_mask_nested_service_secrets(v2_override)
return masked
def _mask_nested_service_secrets(value):
if isinstance(value, dict):
for key, nested in list(value.items()):
if key in SERVICE_SECRET_FIELDS and nested:
value[key] = _mask_secret_value(nested)
else:
_mask_nested_service_secrets(nested)
elif isinstance(value, list):
for item in value:
_mask_nested_service_secrets(item)
# ---------------------------------------------------------------------------
# Workflow definition helpers mask / merge node API keys
# ---------------------------------------------------------------------------

View file

@ -7,7 +7,7 @@ stored, while honouring masked API keys.
import copy
from typing import Dict
from api.schemas.user_configuration import UserConfiguration
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
from api.services.configuration.masking import (
MODEL_OVERRIDE_FIELDS,
SERVICE_SECRET_FIELDS,
@ -66,9 +66,9 @@ def _merge_service_secret_fields(
def merge_user_configurations(
existing: UserConfiguration, incoming_partial: Dict[str, dict]
) -> UserConfiguration:
"""Merge *incoming_partial* onto *existing* and return a new UserConfiguration.
existing: EffectiveAIModelConfiguration, incoming_partial: Dict[str, dict]
) -> EffectiveAIModelConfiguration:
"""Merge *incoming_partial* onto *existing* and return a new EffectiveAIModelConfiguration.
*incoming_partial* is the body of the PUT request (already `model_dump()`ed or
extracted via Pydantic `model_dump`).
@ -113,14 +113,14 @@ def merge_user_configurations(
if "timezone" in incoming_partial:
merged["timezone"] = incoming_partial["timezone"]
# Onboarding gate flags — overwrite only when supplied (set once on submit/skip).
# Onboarding gate flags: overwrite only when supplied.
if "onboarding_completed_at" in incoming_partial:
merged["onboarding_completed_at"] = incoming_partial["onboarding_completed_at"]
if "onboarding_skipped" in incoming_partial:
merged["onboarding_skipped"] = incoming_partial["onboarding_skipped"]
return UserConfiguration.model_validate(merged)
return EffectiveAIModelConfiguration.model_validate(merged)
def merge_workflow_configuration_secrets(

View file

@ -911,7 +911,7 @@ class DograhTTSService(BaseTTSConfiguration):
speed: float = Field(default=1.0, ge=0.5, le=2.0, description="Speed of the voice.")
CARTESIA_TTS_MODELS = ["sonic-3"]
CARTESIA_TTS_MODELS = ["sonic-3.5", "sonic-3"]
@register_tts
@ -919,7 +919,7 @@ class CartesiaTTSConfiguration(BaseTTSConfiguration):
model_config = CARTESIA_PROVIDER_MODEL_CONFIG
provider: Literal[ServiceProviders.CARTESIA] = ServiceProviders.CARTESIA
model: str = Field(
default="sonic-3",
default="sonic-3.5",
description="Cartesia TTS model.",
json_schema_extra={"examples": CARTESIA_TTS_MODELS},
)
@ -1472,11 +1472,26 @@ class AzureOpenAIEmbeddingsConfiguration(BaseEmbeddingsConfiguration):
)
DOGRAH_EMBEDDING_MODELS = ["default"]
@register_embeddings
class DograhEmbeddingsConfiguration(BaseEmbeddingsConfiguration):
model_config = DOGRAH_PROVIDER_MODEL_CONFIG
provider: Literal[ServiceProviders.DOGRAH] = ServiceProviders.DOGRAH
model: str = Field(
default="default",
description="Dograh-managed embedding model.",
json_schema_extra={"examples": DOGRAH_EMBEDDING_MODELS},
)
EmbeddingsConfig = Annotated[
Union[
OpenAIEmbeddingsConfiguration,
OpenRouterEmbeddingsConfiguration,
AzureOpenAIEmbeddingsConfiguration,
DograhEmbeddingsConfiguration,
],
Field(discriminator="provider"),
]

View file

@ -4,13 +4,13 @@ from __future__ import annotations
import copy
from api.schemas.user_configuration import UserConfiguration
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
from api.services.configuration.registry import (
REGISTRY,
ServiceType,
)
# Maps override key → (UserConfiguration field, ServiceType for registry lookup)
# Maps override key → (EffectiveAIModelConfiguration field, ServiceType for registry lookup)
_SECTION_MAP: dict[str, ServiceType] = {
"llm": ServiceType.LLM,
"tts": ServiceType.TTS,
@ -36,7 +36,7 @@ _SECRET_FIELDS = ("api_key", "credentials", "aws_access_key", "aws_secret_key")
def enrich_overrides_with_api_keys(
model_overrides: dict,
user_config: UserConfiguration,
user_config: EffectiveAIModelConfiguration,
) -> dict:
"""Copy API keys from the global config into model_overrides where missing.
@ -74,9 +74,9 @@ def enrich_overrides_with_api_keys(
def resolve_effective_config(
user_config: UserConfiguration,
user_config: EffectiveAIModelConfiguration,
model_overrides: dict | None,
) -> UserConfiguration:
) -> EffectiveAIModelConfiguration:
"""Deep-merge workflow model_overrides onto global user config.
- If model_overrides is None or empty, returns a copy of user_config unchanged.

View file

@ -38,6 +38,7 @@ class OpenAIEmbeddingService(BaseEmbeddingService):
api_key: Optional[str] = None,
model_id: str = DEFAULT_MODEL_ID,
base_url: Optional[str] = None,
default_headers: Optional[Dict[str, str]] = None,
):
"""Initialize the OpenAI embedding service.
@ -60,6 +61,8 @@ class OpenAIEmbeddingService(BaseEmbeddingService):
field_name="base_url",
)
client_kwargs["base_url"] = base_url
if default_headers:
client_kwargs["default_headers"] = default_headers
self.client = AsyncOpenAI(**client_kwargs)
logger.info(f"OpenAI embedding service initialized with model: {model_id}")
else:

View file

@ -0,0 +1,78 @@
from __future__ import annotations
from typing import Any
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
from api.services.configuration.registry import ServiceProviders
MPS_CORRELATION_ID_CONTEXT_KEY = "mps_correlation_id"
def uses_managed_model_services_v2(
ai_model_config: EffectiveAIModelConfiguration | None,
) -> bool:
if (
ai_model_config is None
or getattr(ai_model_config, "managed_service_version", None) != 2
):
return False
return any(
_is_dograh_service(getattr(ai_model_config, section_name, None))
for section_name in ("llm", "tts", "stt", "embeddings")
)
def get_mps_correlation_id(initial_context: dict[str, Any] | None) -> str | None:
if not initial_context:
return None
correlation_id = initial_context.get(MPS_CORRELATION_ID_CONTEXT_KEY)
if correlation_id is None:
return None
return str(correlation_id)
async def ensure_mps_correlation_id(
*,
ai_model_config: EffectiveAIModelConfiguration,
workflow_run_id: int,
initial_context: dict[str, Any] | None,
) -> str | None:
existing = get_mps_correlation_id(initial_context)
if existing:
return existing
if not uses_managed_model_services_v2(ai_model_config):
return None
raise ValueError(
"Managed model services v2 requires workflow run authorization before "
f"the run starts. Missing correlation id for workflow_run_id={workflow_run_id}."
)
def _is_dograh_service(service: Any) -> bool:
provider = getattr(service, "provider", None)
return (
provider == ServiceProviders.DOGRAH or provider == ServiceProviders.DOGRAH.value
)
def get_dograh_service_api_key(
ai_model_config: EffectiveAIModelConfiguration,
) -> str | None:
for section_name in ("llm", "tts", "stt", "embeddings"):
service = getattr(ai_model_config, section_name, None)
if not _is_dograh_service(service):
continue
if hasattr(service, "get_all_api_keys"):
keys = service.get_all_api_keys()
if keys:
return keys[0]
api_key = getattr(service, "api_key", None)
if isinstance(api_key, str) and api_key:
return api_key
return None

View file

@ -0,0 +1,23 @@
from typing import Optional
from api.constants import DEPLOYMENT_MODE
from api.services.mps_service_key_client import mps_service_key_client
async def ensure_hosted_mps_billing_account_v2(
organization_id: int,
*,
created_by: Optional[str] = None,
) -> Optional[dict]:
"""Ensure hosted orgs have an MPS billing v2 account.
OSS deployments use legacy per-key quota accounting and do not create MPS
billing accounts.
"""
if DEPLOYMENT_MODE == "oss":
return None
return await mps_service_key_client.ensure_billing_account_v2(
organization_id=organization_id,
created_by=created_by,
)

View file

@ -4,6 +4,7 @@ This client communicates with the Model Proxy Service (MPS) for service key mana
Service keys are stored and managed entirely in MPS, not in the local database.
"""
import asyncio
from typing import List, Optional
import httpx
@ -353,6 +354,278 @@ class MPSServiceKeyClient:
response=response,
)
async def create_credit_purchase_url(
self,
organization_id: int,
created_by: Optional[str] = None,
return_url: Optional[str] = None,
billing_details: Optional[dict] = None,
) -> dict:
"""Create a short-lived MPS checkout URL for adding organization credits."""
payload = {
"created_by": created_by,
"return_url": return_url,
"billing_details": billing_details or {},
}
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.post(
f"{self.base_url}/api/v1/billing/accounts/{organization_id}/checkout-sessions",
json=payload,
headers=self._get_headers(
organization_id=organization_id,
created_by=created_by,
),
)
if response.status_code == 200:
return response.json()
logger.error(
"Failed to create MPS credit purchase URL: "
f"{response.status_code} - {response.text}"
)
raise httpx.HTTPStatusError(
f"Failed to create MPS credit purchase URL: {response.text}",
request=response.request,
response=response,
)
async def get_credit_ledger(
self,
organization_id: int,
page: int = 1,
limit: int = 50,
created_by: Optional[str] = None,
) -> dict:
"""Get the MPS v2 billing account balance and recent credit ledger."""
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.get(
f"{self.base_url}/api/v1/billing/accounts/{organization_id}/ledger",
params={"page": page, "limit": limit},
headers=self._get_headers(
organization_id=organization_id,
created_by=created_by,
),
)
if response.status_code == 200:
return response.json()
logger.error(
"Failed to get MPS credit ledger: "
f"{response.status_code} - {response.text}"
)
raise httpx.HTTPStatusError(
f"Failed to get MPS credit ledger: {response.text}",
request=response.request,
response=response,
)
async def get_billing_account_status(
self,
organization_id: int,
created_by: Optional[str] = None,
) -> Optional[dict]:
"""Get an existing MPS v2 billing account without creating one."""
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.get(
f"{self.base_url}/api/v1/billing/accounts/{organization_id}/status",
headers=self._get_headers(
organization_id=organization_id,
created_by=created_by,
),
)
if response.status_code == 200:
return response.json()
logger.error(
"Failed to get MPS billing account status: "
f"{response.status_code} - {response.text}"
)
raise httpx.HTTPStatusError(
f"Failed to get MPS billing account status: {response.text}",
request=response.request,
response=response,
)
async def ensure_billing_account_v2(
self,
organization_id: int,
created_by: Optional[str] = None,
) -> dict:
"""Create or return the MPS v2 billing account for an organization."""
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.get(
f"{self.base_url}/api/v1/billing/accounts/{organization_id}/balance",
headers=self._get_headers(
organization_id=organization_id,
created_by=created_by,
),
)
if response.status_code == 200:
return response.json()
logger.error(
"Failed to ensure MPS billing account v2: "
f"{response.status_code} - {response.text}"
)
raise httpx.HTTPStatusError(
f"Failed to ensure MPS billing account v2: {response.text}",
request=response.request,
response=response,
)
async def authorize_workflow_run_start(
self,
*,
organization_id: int,
workflow_run_id: int | None = None,
service_key: Optional[str] = None,
require_correlation_id: bool = False,
minimum_credits: float | None = None,
metadata: Optional[dict] = None,
created_by: Optional[str] = None,
) -> dict:
"""Authorize a hosted workflow run and optionally mint its MPS correlation."""
payload = {
"workflow_run_id": workflow_run_id,
"service_key": service_key,
"require_correlation_id": require_correlation_id,
"metadata": metadata or {},
}
if minimum_credits is not None:
payload["minimum_credits"] = minimum_credits
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.post(
f"{self.base_url}/api/v1/billing/accounts/{organization_id}/run-authorization",
json=payload,
headers=self._get_headers(
organization_id=organization_id,
created_by=created_by,
),
)
if response.status_code == 200:
return response.json()
logger.error(
"Failed to authorize MPS workflow run start: "
f"{response.status_code} - {response.text}"
)
raise httpx.HTTPStatusError(
f"Failed to authorize MPS workflow run start: {response.text}",
request=response.request,
response=response,
)
async def create_correlation_id(
self,
*,
service_key: str,
workflow_run_id: int | None = None,
) -> dict:
"""Mint a server-generated correlation ID for managed model services."""
payload: dict[str, int] = {}
if workflow_run_id is not None:
payload["workflow_run_id"] = workflow_run_id
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.post(
f"{self.base_url}/api/v1/service-keys/correlation-id/self",
json=payload,
headers={
"Authorization": f"Bearer {service_key}",
"Content-Type": "application/json",
},
)
if response.status_code == 200:
return response.json()
logger.error(
"Failed to create correlation ID: "
f"{response.status_code} - {response.text}"
)
raise httpx.HTTPStatusError(
f"Failed to create correlation ID: {response.text}",
request=response.request,
response=response,
)
async def report_platform_usage(
self,
*,
organization_id: int,
correlation_id: Optional[str] = None,
duration_seconds: Optional[float] = None,
workflow_run_id: int | None = None,
metadata: Optional[dict] = None,
max_attempts: int = 3,
) -> dict:
"""Report hosted Dograh platform usage for a completed workflow run."""
if DEPLOYMENT_MODE == "oss":
raise ValueError("OSS deployments must not report platform usage to MPS")
if not correlation_id and duration_seconds is None:
raise ValueError(
"Platform usage reports require correlation_id or duration_seconds"
)
payload: dict = {
"metadata": metadata or {},
}
if correlation_id:
payload["correlation_id"] = correlation_id
if duration_seconds is not None:
payload["duration_seconds"] = duration_seconds
if workflow_run_id is not None:
payload["workflow_run_id"] = workflow_run_id
max_attempts = max(1, max_attempts)
last_response: httpx.Response | None = None
async with httpx.AsyncClient(timeout=self.timeout) as client:
for attempt in range(1, max_attempts + 1):
response = await client.post(
(
f"{self.base_url}/api/v1/billing/accounts/"
f"{organization_id}/platform-usage"
),
json=payload,
headers=self._get_headers(organization_id=organization_id),
)
last_response = response
if response.status_code == 200:
return response.json()
should_retry = (
response.status_code == 409
and "usage_not_ready" in response.text
and attempt < max_attempts
)
if should_retry:
await asyncio.sleep(attempt)
continue
logger.error(
"Failed to report platform usage: "
f"{response.status_code} - {response.text}"
)
raise httpx.HTTPStatusError(
f"Failed to report platform usage: {response.text}",
request=response.request,
response=response,
)
raise httpx.HTTPStatusError(
"Failed to report platform usage",
request=last_response.request,
response=last_response,
)
async def transcribe_audio(
self,
audio_data: bytes,

View file

@ -0,0 +1,50 @@
from typing import Literal, Optional
from pydantic import BaseModel
from api.db import db_client
from api.db.models import UserModel
from api.services.configuration.ai_model_configuration import (
get_resolved_ai_model_configuration,
)
class OrganizationModelServicesContext(BaseModel):
config_source: Literal["organization_v2", "legacy_user_v1", "empty"]
has_model_configuration_v2: bool
managed_service_version: Optional[int] = None
uses_managed_service_v2: bool
class OrganizationContextResponse(BaseModel):
organization_id: Optional[int] = None
organization_provider_id: Optional[str] = None
model_services: OrganizationModelServicesContext
async def get_organization_context(user: UserModel) -> OrganizationContextResponse:
organization_id = user.selected_organization_id
organization = (
await db_client.get_organization_by_id(organization_id)
if organization_id
else None
)
resolved = await get_resolved_ai_model_configuration(
user_id=user.id,
organization_id=organization_id,
)
managed_service_version = resolved.effective.managed_service_version
return OrganizationContextResponse(
organization_id=organization_id,
organization_provider_id=organization.provider_id if organization else None,
model_services=OrganizationModelServicesContext(
config_source=resolved.source,
has_model_configuration_v2=resolved.source == "organization_v2",
managed_service_version=managed_service_version,
uses_managed_service_v2=(
resolved.source == "organization_v2" and managed_service_version == 2
),
),
)

View file

@ -0,0 +1,62 @@
from inspect import isawaitable
from loguru import logger
from pydantic import ValidationError
from api.db import db_client
from api.enums import OrganizationConfigurationKey
from api.schemas.organization_preferences import OrganizationPreferences
async def get_organization_preferences(
organization_id: int | None,
db=None,
) -> OrganizationPreferences:
if organization_id is None:
return OrganizationPreferences()
db = db or db_client
row = await _get_configuration(
db,
organization_id,
OrganizationConfigurationKey.ORGANIZATION_PREFERENCES.value,
)
if row is None:
row = await _get_configuration(
db,
organization_id,
OrganizationConfigurationKey.MODEL_CONFIGURATION_PREFERENCES.value,
)
return _parse_preferences(row.value if row is not None else None, organization_id)
async def upsert_organization_preferences(
organization_id: int,
preferences: OrganizationPreferences,
) -> OrganizationPreferences:
await db_client.upsert_configuration(
organization_id,
OrganizationConfigurationKey.ORGANIZATION_PREFERENCES.value,
preferences.model_dump(mode="json", exclude_none=True),
)
return preferences
async def _get_configuration(db, organization_id: int, key: str):
row = db.get_configuration(organization_id, key)
if isawaitable(row):
row = await row
return row
def _parse_preferences(value, organization_id: int) -> OrganizationPreferences:
if not value or not isinstance(value, dict):
return OrganizationPreferences()
try:
return OrganizationPreferences.model_validate(value)
except ValidationError as exc:
logger.warning(
"Invalid organization preferences for organization "
f"{organization_id}: {exc}. Returning defaults."
)
return OrganizationPreferences()

View file

@ -15,6 +15,29 @@ from api.utils.credential_auth import build_auth_header
PRE_CALL_FETCH_TIMEOUT_SECONDS = 10
def _extract_initial_context(response_data: Dict[str, Any]) -> Dict[str, Any]:
"""Pull the context variables out of a pre-call fetch response.
The canonical key is ``initial_context``. The legacy ``dynamic_variables``
key is still accepted for backward compatibility, so existing endpoints
keep working; ``initial_context`` takes precedence when both are present.
Either key may appear at the top level or nested under ``call_inbound``:
{"call_inbound": {"initial_context": {...}}} | {"initial_context": {...}}
{"call_inbound": {"dynamic_variables": {...}}} | {"dynamic_variables": {...}}
"""
container = response_data.get("call_inbound")
if not isinstance(container, dict):
container = response_data
for key in ("initial_context", "dynamic_variables"):
value = container.get(key)
if isinstance(value, dict):
return value
return {}
async def execute_pre_call_fetch(
*,
url: str,
@ -77,24 +100,16 @@ async def execute_pre_call_fetch(
)
return {}
# Extract dynamic_variables from Retell-compatible response
# Supports: {call_inbound: {dynamic_variables: {...}}}
# or: {dynamic_variables: {...}}
dynamic_vars = {}
call_inbound = response_data.get("call_inbound")
if isinstance(call_inbound, dict):
dynamic_vars = call_inbound.get("dynamic_variables", {})
elif "dynamic_variables" in response_data:
dynamic_vars = response_data["dynamic_variables"]
if not isinstance(dynamic_vars, dict):
dynamic_vars = {}
# Extract the variables to merge into initial_context. Prefers
# the canonical `initial_context` key, falling back to the
# legacy `dynamic_variables` key for backward compatibility.
initial_context_vars = _extract_initial_context(response_data)
logger.info(
f"Pre-call fetch: success ({response.status_code}), "
f"dynamic_variables keys: {list(dynamic_vars.keys())}"
f"initial_context keys: {list(initial_context_vars.keys())}"
)
return dynamic_vars
return initial_context_vars
else:
logger.warning(
f"Pre-call fetch: HTTP {response.status_code} - "

View file

@ -162,15 +162,13 @@ async def run_pipeline_telephony(
workflow_id: Workflow being executed.
workflow_run_id: Workflow run row.
user_id: Owner of the workflow.
call_id: Provider call identifier (stored in cost_info for billing).
call_id: Provider call identifier.
transport_kwargs: Provider-specific kwargs forwarded to the transport
factory (e.g. stream_sid + call_sid for Twilio).
"""
logger.debug(f"Running {provider_name} pipeline for workflow_run {workflow_run_id}")
set_current_run_id(workflow_run_id)
await db_client.update_workflow_run(workflow_run_id, cost_info={"call_id": call_id})
workflow = await db_client.get_workflow(workflow_id, user_id)
if workflow:
set_current_org_id(workflow.organization_id)
@ -195,14 +193,17 @@ async def run_pipeline_telephony(
# Resolve effective user config here so the transport can tune its
# bot-stopped-speaking fallback based on is_realtime; pass the resolved
# values into _run_pipeline so it doesn't fetch them again.
from api.services.configuration.resolve import resolve_effective_config
from api.services.configuration.ai_model_configuration import (
get_effective_ai_model_configuration_for_workflow,
)
user_config = await db_client.get_user_configurations(user_id)
run_configs = (
(workflow_run.definition.workflow_configurations or {}) if workflow_run else {}
)
user_config = resolve_effective_config(
user_config, run_configs.get("model_overrides")
user_config = await get_effective_ai_model_configuration_for_workflow(
user_id=user_id,
organization_id=workflow.organization_id if workflow else None,
workflow_configurations=run_configs,
)
is_realtime = bool(user_config.is_realtime and user_config.realtime is not None)
@ -272,15 +273,18 @@ async def run_pipeline_smallwebrtc(
# Resolve workflow_run + effective user_config here so the transport can
# tune its bot-stopped-speaking fallback based on is_realtime. _run_pipeline
# reuses these via kwargs so we don't fetch twice.
from api.services.configuration.resolve import resolve_effective_config
from api.services.configuration.ai_model_configuration import (
get_effective_ai_model_configuration_for_workflow,
)
workflow_run = await db_client.get_workflow_run(workflow_run_id, user_id)
user_config = await db_client.get_user_configurations(user_id)
run_configs = (
(workflow_run.definition.workflow_configurations or {}) if workflow_run else {}
)
user_config = resolve_effective_config(
user_config, run_configs.get("model_overrides")
user_config = await get_effective_ai_model_configuration_for_workflow(
user_id=user_id,
organization_id=workflow.organization_id if workflow else None,
workflow_configurations=run_configs,
)
is_realtime = bool(user_config.is_realtime and user_config.realtime is not None)
@ -334,7 +338,7 @@ async def _run_pipeline(
if workflow_run.is_completed:
raise HTTPException(status_code=400, detail="Workflow run already completed")
merged_call_context_vars = workflow_run.initial_context
merged_call_context_vars = dict(workflow_run.initial_context or {})
# If there is some extra call_context_vars, fold them in. Persistence
# happens once below, after runtime_configuration is also resolved.
if call_context_vars:
@ -380,15 +384,31 @@ async def _run_pipeline(
# Resolve model overrides from the version onto global user config (skip
# when the caller already resolved it).
if resolved_user_config is None:
from api.services.configuration.resolve import resolve_effective_config
from api.services.configuration.ai_model_configuration import (
get_effective_ai_model_configuration_for_workflow,
)
user_config = await db_client.get_user_configurations(user_id)
user_config = resolve_effective_config(
user_config, run_configs.get("model_overrides")
user_config = await get_effective_ai_model_configuration_for_workflow(
user_id=user_id,
organization_id=workflow.organization_id,
workflow_configurations=run_configs,
)
else:
user_config = resolved_user_config
from api.services.managed_model_services import (
MPS_CORRELATION_ID_CONTEXT_KEY,
ensure_mps_correlation_id,
)
mps_correlation_id = await ensure_mps_correlation_id(
ai_model_config=user_config,
workflow_run_id=workflow_run_id,
initial_context=merged_call_context_vars,
)
if mps_correlation_id:
merged_call_context_vars[MPS_CORRELATION_ID_CONTEXT_KEY] = mps_correlation_id
# Detect realtime mode (speech-to-speech services like OpenAI Realtime, Gemini Live)
is_realtime = user_config.is_realtime and user_config.realtime is not None
@ -400,11 +420,23 @@ async def _run_pipeline(
# Realtime services don't implement run_inference, so create a
# separate text LLM for variable extraction and other out-of-band
# inference calls.
inference_llm = create_llm_service(user_config)
inference_llm = create_llm_service(
user_config,
correlation_id=mps_correlation_id,
)
else:
stt = create_stt_service(user_config, audio_config, keyterms=keyterms)
tts = create_tts_service(user_config, audio_config)
llm = create_llm_service(user_config)
stt = create_stt_service(
user_config,
audio_config,
keyterms=keyterms,
correlation_id=mps_correlation_id,
)
tts = create_tts_service(
user_config,
audio_config,
correlation_id=mps_correlation_id,
)
llm = create_llm_service(user_config, correlation_id=mps_correlation_id)
inference_llm = None
# Stamp the providers/models actually resolved for this run onto
@ -508,10 +540,17 @@ async def _run_pipeline(
embeddings_endpoint = None
embeddings_api_version = None
if user_config and user_config.embeddings:
from api.services.configuration.ai_model_configuration import (
apply_managed_embeddings_base_url,
)
embeddings_api_key = user_config.embeddings.api_key
embeddings_model = user_config.embeddings.model
embeddings_provider = getattr(user_config.embeddings, "provider", None)
embeddings_base_url = getattr(user_config.embeddings, "base_url", None)
embeddings_base_url = apply_managed_embeddings_base_url(
provider=embeddings_provider,
base_url=getattr(user_config.embeddings, "base_url", None),
)
embeddings_endpoint = getattr(user_config.embeddings, "endpoint", None)
embeddings_api_version = getattr(user_config.embeddings, "api_version", None)
@ -679,7 +718,10 @@ async def _run_pipeline(
# Create a separate LLM instance for the voicemail sub-pipeline
# (can't share with main pipeline as it would mess up frame linking)
if voicemail_config.get("use_workflow_llm", True):
voicemail_llm = create_llm_service(user_config)
voicemail_llm = create_llm_service(
user_config,
correlation_id=mps_correlation_id,
)
else:
voicemail_llm = create_llm_service_from_provider(
provider=voicemail_config.get("provider", "openai"),

View file

@ -78,7 +78,10 @@ def _validate_runtime_service_url(url: str, field_name: str) -> None:
def create_stt_service(
user_config, audio_config: "AudioConfig", keyterms: list[str] | None = None
user_config,
audio_config: "AudioConfig",
keyterms: list[str] | None = None,
correlation_id: str | None = None,
):
"""Create and return appropriate STT service based on user configuration
@ -160,6 +163,7 @@ def create_stt_service(
return DograhSTTService(
base_url=base_url,
api_key=user_config.stt.api_key,
correlation_id=correlation_id,
settings=DograhSTTSettings(
model=user_config.stt.model,
language=language,
@ -286,7 +290,9 @@ def create_stt_service(
)
def create_tts_service(user_config, audio_config: "AudioConfig"):
def create_tts_service(
user_config, audio_config: "AudioConfig", correlation_id: str | None = None
):
"""Create and return appropriate TTS service based on user configuration
Args:
@ -404,6 +410,7 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
return DograhTTSService(
base_url=base_url,
api_key=user_config.tts.api_key,
correlation_id=correlation_id,
settings=DograhTTSSettings(
model=user_config.tts.model,
voice=user_config.tts.voice,
@ -564,6 +571,7 @@ def create_llm_service_from_provider(
model: str,
api_key: str | None,
*,
correlation_id: str | None = None,
base_url: str | None = None,
endpoint: str | None = None,
aws_access_key: str | None = None,
@ -637,6 +645,7 @@ def create_llm_service_from_provider(
return DograhLLMService(
base_url=f"{MPS_API_URL}/api/v1/llm",
api_key=api_key,
correlation_id=correlation_id,
settings=OpenAILLMSettings(model=model),
)
elif provider == ServiceProviders.AWS_BEDROCK.value:
@ -851,7 +860,7 @@ def create_realtime_llm_service(user_config, audio_config: "AudioConfig"):
)
def create_llm_service(user_config):
def create_llm_service(user_config, correlation_id: str | None = None):
"""Create and return appropriate LLM service based on user configuration."""
provider = user_config.llm.provider
model = user_config.llm.model
@ -880,4 +889,10 @@ def create_llm_service(user_config):
elif provider == ServiceProviders.SARVAM.value:
kwargs["temperature"] = user_config.llm.temperature
return create_llm_service_from_provider(provider, model, api_key, **kwargs)
return create_llm_service_from_provider(
provider,
model,
api_key,
correlation_id=correlation_id,
**kwargs,
)

View file

@ -1,76 +0,0 @@
# Pricing Module
This module contains pricing models and registries for different AI services used in workflow cost calculations.
## Structure
```
pricing/
├── __init__.py # Main module exports
├── models.py # Base pricing model classes
├── llm.py # LLM pricing configurations
├── tts.py # TTS pricing configurations
├── stt.py # STT pricing configurations
├── registry.py # Combined pricing registry
└── README.md # This file
```
## Pricing Models
### TokenPricingModel
Used for LLM services that charge based on tokens:
- `prompt_token_price`: Cost per prompt token
- `completion_token_price`: Cost per completion token
- `cache_read_discount`: Discount for cache read tokens (default 50%)
- `cache_creation_multiplier`: Premium for cache creation tokens (default 25%)
### CharacterPricingModel
Used for TTS services that charge based on character count:
- `character_price`: Cost per character
### TimePricingModel
Used for STT services that charge based on time:
- `second_price`: Cost per second
## Adding New Pricing
### Adding a New LLM Model
Edit `llm.py` and add the model to the appropriate provider:
```python
ServiceProviders.OPENAI: {
"new-model": TokenPricingModel(
prompt_token_price=Decimal("2.00") / 1000000,
completion_token_price=Decimal("8.00") / 1000000,
),
# ... existing models
}
```
### Adding a New Provider
1. Add pricing configurations to the appropriate service file (llm.py, tts.py, stt.py)
2. The registry will automatically include them
### Adding a New Service Type
1. Create a new pricing file (e.g., `image.py`)
2. Define the pricing models
3. Import and add to `registry.py`
## Usage
The pricing registry is automatically imported and used by the cost calculator:
```python
from api.services.pricing import PRICING_REGISTRY
from api.services.workflow.cost_calculator import cost_calculator
# The cost calculator uses the pricing registry automatically
result = cost_calculator.calculate_total_cost(usage_info)
```
## Maintenance
- Update pricing when providers change their rates
- All prices should use `Decimal` for precision
- Include comments with current pricing from provider documentation
- Test changes with existing test suite

View file

@ -1,9 +0,0 @@
"""
Pricing module for workflow cost calculation.
This module contains pricing models and registries for different AI services.
"""
from .registry import PRICING_REGISTRY
__all__ = ["PRICING_REGISTRY"]

View file

@ -1,228 +0,0 @@
"""
Cost Calculator for Workflow Runs
This module provides a comprehensive cost calculation system for workflow runs based on usage metrics
from different AI service providers (OpenAI, Groq, Deepgram, etc.).
Features:
- Token-based pricing for LLM services with cache optimization support
- Character-based pricing for TTS services
- Time-based pricing for STT services
- Configurable pricing models that can be updated
- Support for multiple providers and models
- Automatic provider inference from model names
- JSON serialization support for database storage
Usage:
from api.tasks.cost_calculator import cost_calculator
usage_info = {
"llm": {
"processor_name|||gpt-4o": {
"prompt_tokens": 1000,
"completion_tokens": 500,
"total_tokens": 1500,
"cache_read_input_tokens": 0,
"cache_creation_input_tokens": 0
}
},
"tts": {
"processor_name|||aura-2-helena-en": 2000 # character count
}
}
cost_breakdown = cost_calculator.calculate_total_cost(usage_info)
print(f"Total cost: ${cost_breakdown['total']:.6f}")
"""
from decimal import Decimal
from typing import Any, Dict, Optional, Tuple
from api.services.configuration.registry import ServiceProviders
from api.services.pricing import PRICING_REGISTRY
from api.services.pricing.models import (
PricingModel,
)
class CostCalculator:
"""Main cost calculator class"""
def __init__(self, pricing_registry: Dict = None):
self.pricing_registry = pricing_registry or PRICING_REGISTRY
def get_pricing_model(
self, service_type: str, provider: str, model: str
) -> Optional[PricingModel]:
"""Get pricing model for a specific service, provider, and model"""
try:
service_pricing = self.pricing_registry.get(service_type, {})
# Try to get pricing for the specific provider
provider_pricing = service_pricing.get(provider, {})
pricing_model = provider_pricing.get(model) or provider_pricing.get(
"default"
)
if pricing_model:
return pricing_model
# If not found, try the "default" provider for this service type
default_provider_pricing = service_pricing.get("default", {})
return default_provider_pricing.get(model) or default_provider_pricing.get(
"default"
)
except (KeyError, AttributeError):
return None
def calculate_llm_cost(
self, provider: str, model: str, usage: Dict[str, int]
) -> Decimal:
"""Calculate cost for LLM usage"""
pricing_model = self.get_pricing_model("llm", provider, model)
if not pricing_model:
return Decimal("0")
return pricing_model.calculate_cost(usage)
def calculate_tts_cost(
self, provider: str, model: str, character_count: int
) -> Decimal:
"""Calculate cost for TTS usage"""
pricing_model = self.get_pricing_model("tts", provider, model)
if not pricing_model:
return Decimal("0")
return pricing_model.calculate_cost(character_count)
def calculate_stt_cost(self, provider: str, model: str, seconds: float) -> Decimal:
"""Calculate cost for STT usage"""
pricing_model = self.get_pricing_model("stt", provider, model)
if not pricing_model:
return Decimal("0")
return pricing_model.calculate_cost(seconds)
def calculate_total_cost(self, usage_info: Dict) -> Dict[str, Any]:
llm_cost_total = Decimal("0")
tts_cost_total = Decimal("0")
stt_cost_total = Decimal("0")
# Calculate LLM costs
llm_usage = usage_info.get("llm", {})
for key, usage in llm_usage.items():
processor, model = self._parse_key(key)
# Try to determine provider from processor name or model
provider = self._infer_provider_from_model(model, "llm")
cost = self.calculate_llm_cost(provider, model, usage)
llm_cost_total += cost
# Calculate TTS costs
tts_usage = usage_info.get("tts", {})
for key, character_count in tts_usage.items():
processor, model = self._parse_key(key)
# Handle the case where model is "None" - infer from processor
if model.lower() in ["none", "null", ""]:
provider = self._infer_provider_from_processor(processor, "tts")
model = "default" # Use default model for the provider
else:
provider = self._infer_provider_from_model(model, "tts")
cost = self.calculate_tts_cost(provider, model, character_count)
tts_cost_total += cost
# Calculate STT costs from explicit stt usage
stt_usage = usage_info.get("stt", {})
for key, seconds in stt_usage.items():
processor, model = self._parse_key(key)
provider = self._infer_provider_from_model(model, "stt")
cost = self.calculate_stt_cost(provider, model, seconds)
stt_cost_total += cost
total_cost = llm_cost_total + tts_cost_total + stt_cost_total
return {
"llm_cost": float(llm_cost_total),
"tts_cost": float(tts_cost_total),
"stt_cost": float(stt_cost_total),
"total": float(total_cost),
}
def _parse_key(self, key) -> Tuple[str, str]:
"""Parse key which is in format 'processor|||model'"""
if isinstance(key, str) and "|||" in key:
parts = key.split("|||", 1)
return parts[0], parts[1]
else:
# Fallback for backwards compatibility or malformed keys
return str(key), "unknown"
def _infer_provider_from_model(self, model: str, service_type: str) -> str:
"""Infer provider from model name"""
if not model:
return "unknown"
model_lower = model.lower()
# OpenAI models
if any(keyword in model_lower for keyword in ["gpt", "whisper", "openai"]):
return ServiceProviders.OPENAI
# Groq models
if any(keyword in model_lower for keyword in ["groq"]):
return ServiceProviders.GROQ
# Elevenlabs models
if any(keyword in model_lower for keyword in ["eleven"]):
return ServiceProviders.ELEVENLABS
# Deepgram models
if any(
keyword in model_lower
for keyword in ["deepgram", "nova", "phonecall", "general"]
):
return ServiceProviders.DEEPGRAM
# Default to first available provider for the service type
service_providers = self.pricing_registry.get(service_type, {})
if service_providers:
return list(service_providers.keys())[0]
return "unknown"
def _infer_provider_from_processor(self, processor: str, service_type: str) -> str:
"""Infer provider from processor name"""
if not processor:
return "unknown"
processor_lower = processor.lower()
# OpenAI processors
if any(keyword in processor_lower for keyword in ["openai", "gpt"]):
return ServiceProviders.OPENAI
# Groq processors
if any(keyword in processor_lower for keyword in ["groq"]):
return ServiceProviders.GROQ
# Deepgram processors
if any(keyword in processor_lower for keyword in ["deepgram"]):
return ServiceProviders.DEEPGRAM
# Default to first available provider for the service type
service_providers = self.pricing_registry.get(service_type, {})
if service_providers:
return list(service_providers.keys())[0]
return "unknown"
def update_pricing(
self, service_type: str, provider: str, model: str, pricing_model: PricingModel
):
"""Update pricing for a specific service/provider/model combination"""
if service_type not in self.pricing_registry:
self.pricing_registry[service_type] = {}
if provider not in self.pricing_registry[service_type]:
self.pricing_registry[service_type][provider] = {}
self.pricing_registry[service_type][provider][model] = pricing_model
# Global cost calculator instance
cost_calculator = CostCalculator()

View file

@ -1,44 +0,0 @@
"""
Embeddings pricing models for different providers.
Prices are per token for embedding models.
"""
from decimal import Decimal
from typing import Dict
from api.services.configuration.registry import ServiceProviders
from .models import PricingModel
class EmbeddingPricingModel(PricingModel):
"""Pricing model for token-based embedding services."""
def __init__(self, token_price: Decimal):
"""Initialize with price per token.
Args:
token_price: Cost per token for embedding
"""
self.token_price = token_price
def calculate_cost(self, token_count: int) -> Decimal:
"""Calculate cost for embedding token usage."""
return Decimal(token_count) * self.token_price
# Embeddings pricing registry
EMBEDDINGS_PRICING: Dict[str, Dict[str, EmbeddingPricingModel]] = {
ServiceProviders.OPENAI: {
"text-embedding-3-small": EmbeddingPricingModel(
token_price=Decimal("0.02") / 1_000_000, # $0.02 per 1M tokens
),
"text-embedding-3-large": EmbeddingPricingModel(
token_price=Decimal("0.13") / 1_000_000, # $0.13 per 1M tokens
),
"text-embedding-ada-002": EmbeddingPricingModel(
token_price=Decimal("0.10") / 1_000_000, # $0.10 per 1M tokens (legacy)
),
},
}

View file

@ -1,143 +0,0 @@
"""
LLM pricing models for different providers.
Prices are per 1000 tokens for most models, with some newer models priced per million tokens.
"""
from decimal import Decimal
from typing import Dict
from api.services.configuration.registry import ServiceProviders
from .models import TokenPricingModel
# LLM pricing registry
LLM_PRICING: Dict[str, Dict[str, TokenPricingModel]] = {
ServiceProviders.OPENAI: {
"gpt-3.5-turbo": TokenPricingModel(
prompt_token_price=Decimal("0.0015") / 1000, # $0.0015 per 1K tokens
completion_token_price=Decimal("0.002") / 1000, # $0.002 per 1K tokens
),
"gpt-4": TokenPricingModel(
prompt_token_price=Decimal("0.03") / 1000, # $0.03 per 1K tokens
completion_token_price=Decimal("0.06") / 1000, # $0.06 per 1K tokens
),
"gpt-4.1": TokenPricingModel(
prompt_token_price=Decimal("2.00") / 1000000, # $2.00 per 1M tokens
completion_token_price=Decimal("8.00") / 1000000, # $8.00 per 1M tokens
),
"gpt-4.1-mini": TokenPricingModel(
prompt_token_price=Decimal("0.40") / 1000000, # $0.40 per 1M tokens
completion_token_price=Decimal("1.60") / 1000000, # $1.60 per 1M tokens
),
"gpt-4.1-nano": TokenPricingModel(
prompt_token_price=Decimal("0.10") / 1000000, # $0.10 per 1M tokens
completion_token_price=Decimal("0.40") / 1000000, # $0.40 per 1M tokens
),
"gpt-4.5-preview": TokenPricingModel(
prompt_token_price=Decimal("75.00") / 1000000, # $75.00 per 1M tokens
completion_token_price=Decimal("150.00") / 1000000, # $150.00 per 1M tokens
),
"gpt-4o": TokenPricingModel(
prompt_token_price=Decimal("2.50") / 1000000, # $2.50 per 1M tokens - FIXED
completion_token_price=Decimal("10.00")
/ 1000000, # $10.00 per 1M tokens - FIXED
),
"gpt-4o-audio-preview": TokenPricingModel(
prompt_token_price=Decimal("2.50") / 1000000, # $2.50 per 1M tokens
completion_token_price=Decimal("10.00") / 1000000, # $10.00 per 1M tokens
),
"gpt-4o-realtime-preview": TokenPricingModel(
prompt_token_price=Decimal("5.00") / 1000000, # $5.00 per 1M tokens
completion_token_price=Decimal("20.00") / 1000000, # $20.00 per 1M tokens
),
"gpt-4o-mini": TokenPricingModel(
prompt_token_price=Decimal("0.15") / 1000000, # $0.15 per 1M tokens
completion_token_price=Decimal("0.60") / 1000000, # $0.60 per 1M tokens
),
"gpt-4o-mini-audio-preview": TokenPricingModel(
prompt_token_price=Decimal("0.15") / 1000000, # $0.15 per 1M tokens
completion_token_price=Decimal("0.60") / 1000000, # $0.60 per 1M tokens
),
"gpt-4o-mini-realtime-preview": TokenPricingModel(
prompt_token_price=Decimal("0.60") / 1000000, # $0.60 per 1M tokens
completion_token_price=Decimal("2.40") / 1000000, # $2.40 per 1M tokens
),
"gpt-4o-search-preview": TokenPricingModel(
prompt_token_price=Decimal("2.50") / 1000000, # $2.50 per 1M tokens
completion_token_price=Decimal("10.00") / 1000000, # $10.00 per 1M tokens
),
"gpt-4o-mini-search-preview": TokenPricingModel(
prompt_token_price=Decimal("0.15") / 1000000, # $0.15 per 1M tokens
completion_token_price=Decimal("0.60") / 1000000, # $0.60 per 1M tokens
),
"o1": TokenPricingModel(
prompt_token_price=Decimal("15.00") / 1000000, # $15.00 per 1M tokens
completion_token_price=Decimal("60.00") / 1000000, # $60.00 per 1M tokens
),
"o1-pro": TokenPricingModel(
prompt_token_price=Decimal("150.00") / 1000000, # $150.00 per 1M tokens
completion_token_price=Decimal("600.00") / 1000000, # $600.00 per 1M tokens
),
"o1-mini": TokenPricingModel(
prompt_token_price=Decimal("1.10") / 1000000, # $1.10 per 1M tokens
completion_token_price=Decimal("4.40") / 1000000, # $4.40 per 1M tokens
),
"o3": TokenPricingModel(
prompt_token_price=Decimal("10.00") / 1000000, # $10.00 per 1M tokens
completion_token_price=Decimal("40.00") / 1000000, # $40.00 per 1M tokens
),
"o3-mini": TokenPricingModel(
prompt_token_price=Decimal("1.10") / 1000000, # $1.10 per 1M tokens
completion_token_price=Decimal("4.40") / 1000000, # $4.40 per 1M tokens
),
"o4-mini": TokenPricingModel(
prompt_token_price=Decimal("1.10") / 1000000, # $1.10 per 1M tokens
completion_token_price=Decimal("4.40") / 1000000, # $4.40 per 1M tokens
),
"computer-use-preview": TokenPricingModel(
prompt_token_price=Decimal("3.00") / 1000000, # $3.00 per 1M tokens
completion_token_price=Decimal("12.00") / 1000000, # $12.00 per 1M tokens
),
"gpt-image-1": TokenPricingModel(
prompt_token_price=Decimal("5.00") / 1000000, # $5.00 per 1M tokens
completion_token_price=Decimal("0") / 1000000, # No output pricing shown
),
"codex-mini-latest": TokenPricingModel(
prompt_token_price=Decimal("1.50") / 1000000, # $1.50 per 1M tokens
completion_token_price=Decimal("6.00") / 1000000, # $6.00 per 1M tokens
),
# Transcription models
"gpt-4o-transcribe": TokenPricingModel(
prompt_token_price=Decimal("2.50") / 1000000, # $2.50 per 1M tokens
completion_token_price=Decimal("10.00") / 1000000, # $10.00 per 1M tokens
),
"gpt-4o-mini-transcribe": TokenPricingModel(
prompt_token_price=Decimal("1.25") / 1000000, # $1.25 per 1M tokens
completion_token_price=Decimal("5.00") / 1000000, # $5.00 per 1M tokens
),
# TTS models with token-based pricing
"gpt-4o-mini-tts": TokenPricingModel(
prompt_token_price=Decimal("0.60") / 1000000, # $0.60 per 1M tokens
completion_token_price=Decimal("0")
/ 1000000, # No completion tokens for TTS
),
},
ServiceProviders.GROQ: {
"llama-3.3-70b-versatile": TokenPricingModel(
prompt_token_price=Decimal("0.00059") / 1000, # $0.00059 per 1K tokens
completion_token_price=Decimal("0.00079") / 1000, # $0.00079 per 1K tokens
),
"deepseek-r1-distill-llama-70b": TokenPricingModel(
prompt_token_price=Decimal("0.00059") / 1000, # Assuming similar pricing
completion_token_price=Decimal("0.00079") / 1000,
),
},
ServiceProviders.AZURE: {
"gpt-4.1-mini": TokenPricingModel(
prompt_token_price=Decimal("0.44") / 1000000, # $0.40 per 1M tokens
completion_token_price=Decimal("8.80")
/ 1000000, # $1.60 per 1M tokens if using data zone
)
},
}

View file

@ -1,89 +0,0 @@
"""
Base pricing models for different service types.
"""
from decimal import Decimal
from enum import Enum
from typing import Any, Dict
class CostType(Enum):
LLM_TOKENS = "llm_tokens"
TTS_CHARACTERS = "tts_characters"
STT_SECONDS = "stt_seconds"
class PricingModel:
"""Base class for pricing models"""
def calculate_cost(self, usage: Any) -> Decimal:
"""Calculate cost based on usage"""
raise NotImplementedError
class TokenPricingModel(PricingModel):
"""Pricing model for token-based services (LLM)"""
def __init__(
self,
prompt_token_price: Decimal,
completion_token_price: Decimal,
cache_read_discount: Decimal = Decimal("0.5"), # 50% discount for cache reads
cache_creation_multiplier: Decimal = Decimal(
"1.25"
), # 25% premium for cache creation
):
self.prompt_token_price = prompt_token_price
self.completion_token_price = completion_token_price
self.cache_read_discount = cache_read_discount
self.cache_creation_multiplier = cache_creation_multiplier
def calculate_cost(self, usage: Dict[str, int]) -> Decimal:
"""Calculate cost for LLM token usage"""
prompt_tokens = usage.get("prompt_tokens", 0)
completion_tokens = usage.get("completion_tokens", 0)
cache_read_tokens = usage.get("cache_read_input_tokens") or 0
cache_creation_tokens = usage.get("cache_creation_input_tokens") or 0
# Base cost
prompt_cost = Decimal(prompt_tokens) * self.prompt_token_price
completion_cost = Decimal(completion_tokens) * self.completion_token_price
# Cache adjustments
cache_read_savings = (
Decimal(cache_read_tokens)
* self.prompt_token_price
* self.cache_read_discount
)
cache_creation_premium = (
Decimal(cache_creation_tokens)
* self.prompt_token_price
* (self.cache_creation_multiplier - 1)
)
total_cost = (
prompt_cost + completion_cost - cache_read_savings + cache_creation_premium
)
return max(total_cost, Decimal("0")) # Ensure non-negative
class CharacterPricingModel(PricingModel):
"""Pricing model for character-based services (TTS)"""
def __init__(self, character_price: Decimal):
self.character_price = character_price
def calculate_cost(self, character_count: int) -> Decimal:
"""Calculate cost for TTS character usage"""
return Decimal(character_count) * self.character_price
class TimePricingModel(PricingModel):
"""Pricing model for time-based services (STT)"""
def __init__(self, second_price: Decimal):
self.second_price = second_price
def calculate_cost(self, seconds: float) -> Decimal:
"""Calculate cost for STT time usage"""
return Decimal(str(seconds)) * self.second_price

View file

@ -1,18 +0,0 @@
"""
Main pricing registry that combines all service type pricing models.
"""
from typing import Dict
from .embeddings import EMBEDDINGS_PRICING
from .llm import LLM_PRICING
from .stt import STT_PRICING
from .tts import TTS_PRICING
# Combined pricing registry for all service types
PRICING_REGISTRY: Dict = {
"llm": LLM_PRICING,
"tts": TTS_PRICING,
"stt": STT_PRICING,
"embeddings": EMBEDDINGS_PRICING,
}

View file

@ -1,13 +0,0 @@
"""Format workflow run usage for public API responses."""
def format_public_usage_info(usage_info: dict | None) -> dict | None:
if not usage_info:
return None
return {
"llm": usage_info.get("llm") or {},
"tts": usage_info.get("tts") or {},
"stt": usage_info.get("stt") or {},
"call_duration_seconds": usage_info.get("call_duration_seconds"),
}

View file

@ -1,26 +0,0 @@
"""
STT (Speech-to-Text) pricing models for different providers.
Prices are per second for STT services.
"""
from decimal import Decimal
from typing import Dict
from api.services.configuration.registry import ServiceProviders
from .models import TimePricingModel
# STT pricing registry
STT_PRICING: Dict[str, Dict[str, TimePricingModel]] = {
ServiceProviders.DEEPGRAM: {
"nova-3-general": TimePricingModel(Decimal("0.0077") / 60),
"nova-2": TimePricingModel(Decimal("0.0058") / 60),
"default": TimePricingModel(Decimal("0.0077") / 60),
},
ServiceProviders.OPENAI: {
"gpt-4o-transcribe": TimePricingModel(Decimal("0.015") / 60),
"default": TimePricingModel(Decimal("0.015") / 60),
},
"default": {"default": TimePricingModel(Decimal("0.0077") / 60)},
}

View file

@ -1,30 +0,0 @@
"""
TTS (Text-to-Speech) pricing models for different providers.
Prices are per character for TTS services.
"""
from decimal import Decimal
from typing import Dict
from api.services.configuration.registry import ServiceProviders
from .models import CharacterPricingModel
# TTS pricing registry
TTS_PRICING: Dict[str, Dict[str, CharacterPricingModel]] = {
ServiceProviders.OPENAI: {
"gpt-4o-mini-tts": CharacterPricingModel(Decimal("0.6") / 1_00_00_000),
"default": CharacterPricingModel(Decimal("0.6") / 1_00_00_000),
},
ServiceProviders.DEEPGRAM: {
"aura-2": CharacterPricingModel(Decimal("0.030") / 1_000),
"aura-1": CharacterPricingModel(Decimal("0.015") / 1_000),
"default": CharacterPricingModel(Decimal("0.030") / 1_000),
},
ServiceProviders.ELEVENLABS: {
# 6400 usd per 250*1e6 characters
"default": CharacterPricingModel(Decimal("0.0256") / 1_000)
},
"default": {"default": CharacterPricingModel(Decimal("0.030") / 1_000)},
}

View file

@ -1,230 +0,0 @@
from decimal import Decimal
from loguru import logger
from api.db import db_client
from api.enums import WorkflowRunMode
from api.services.pricing.cost_calculator import cost_calculator
from api.services.telephony.factory import get_telephony_provider_for_run
async def _fetch_telephony_cost(workflow_run) -> dict | None:
"""Fetch telephony call cost. Returns a dict with cost_usd and provider_name, or None."""
if (
workflow_run.mode
not in [WorkflowRunMode.TWILIO.value, WorkflowRunMode.VONAGE.value]
or not workflow_run.cost_info
):
return None
call_id = workflow_run.cost_info.get("call_id")
if not call_id:
logger.warning(f"call_id not found in cost_info")
return None
provider_name = workflow_run.mode.lower() if workflow_run.mode else ""
workflow = await db_client.get_workflow_by_id(workflow_run.workflow_id)
if not workflow:
logger.warning("Workflow not found for workflow run")
raise Exception("Workflow not found")
provider = await get_telephony_provider_for_run(
workflow_run, workflow.organization_id
)
call_cost_info = await provider.get_call_cost(call_id)
if call_cost_info.get("status") == "error":
logger.error(
f"Failed to fetch {provider_name} call cost: {call_cost_info.get('error')}"
)
return None
cost_usd = call_cost_info.get("cost_usd", 0.0)
logger.info(
f"{provider_name.title()} call cost: ${cost_usd:.6f} USD for call {call_id}"
)
return {"cost_usd": cost_usd, "provider_name": provider_name}
async def _update_organization_usage(
org, dograh_tokens: float, duration_seconds: float, charge_usd: float | None
) -> None:
"""Update organization usage after a workflow run."""
org_id = org.id
await db_client.update_usage_after_run(
org_id, dograh_tokens, duration_seconds, charge_usd
)
if charge_usd is not None:
logger.info(
f"Updated organization usage with ${charge_usd:.2f} USD ({dograh_tokens} Dograh Tokens) and {duration_seconds}s duration for org {org_id}"
)
else:
logger.info(
f"Updated organization usage with {dograh_tokens} Dograh Tokens and {duration_seconds}s duration for org {org_id}"
)
async def _get_pricing_organization(workflow_run):
workflow = getattr(workflow_run, "workflow", None)
organization_id = getattr(workflow, "organization_id", None)
if organization_id is None and workflow and workflow.user:
organization_id = workflow.user.selected_organization_id
if organization_id is None:
return None
return await db_client.get_organization_by_id(organization_id)
async def _build_usage_cost_snapshot(
usage_info: dict | None,
*,
workflow_run=None,
include_telephony_cost: bool = False,
organization=None,
calculated_at: str | None = None,
) -> dict | None:
if not usage_info:
logger.warning("No usage info available for workflow run")
return None
cost_breakdown = cost_calculator.calculate_total_cost(usage_info)
if include_telephony_cost and workflow_run is not None:
try:
telephony_cost = await _fetch_telephony_cost(workflow_run)
if telephony_cost:
telephony_cost_usd = telephony_cost["cost_usd"]
provider_name = telephony_cost["provider_name"]
cost_breakdown["telephony_call"] = telephony_cost_usd
cost_breakdown[f"{provider_name}_call"] = telephony_cost_usd
cost_breakdown["total"] = (
float(cost_breakdown["total"]) + telephony_cost_usd
)
except Exception as e:
logger.error(f"Failed to fetch telephony call cost: {e}")
# Don't fail the whole cost calculation if telephony API fails
total_cost_usd = Decimal(str(cost_breakdown["total"]))
dograh_tokens = float(total_cost_usd * Decimal("100"))
if organization is None and workflow_run is not None:
organization = await _get_pricing_organization(workflow_run)
charge_usd = None
if organization and organization.price_per_second_usd:
duration_seconds = usage_info.get("call_duration_seconds", 0)
charge_usd = float(
Decimal(str(duration_seconds))
* Decimal(str(organization.price_per_second_usd))
)
cost_info = {
"cost_breakdown": cost_breakdown,
"total_cost_usd": float(total_cost_usd),
"dograh_token_usage": dograh_tokens,
"calculated_at": calculated_at
or (workflow_run.created_at.isoformat() if workflow_run is not None else None),
"call_duration_seconds": usage_info.get("call_duration_seconds", 0),
}
if charge_usd is not None:
cost_info["charge_usd"] = charge_usd
cost_info["price_per_second_usd"] = organization.price_per_second_usd
return cost_info
async def build_workflow_run_cost_info(workflow_run) -> dict | None:
cost_info = await _build_usage_cost_snapshot(
workflow_run.usage_info,
workflow_run=workflow_run,
include_telephony_cost=True,
calculated_at=workflow_run.created_at.isoformat(),
)
if cost_info is None:
return None
return {
**(workflow_run.cost_info or {}),
**cost_info,
}
async def save_workflow_run_cost_info(
workflow_run_id: int, cost_info: dict | None
) -> None:
if cost_info is None:
return
await db_client.update_workflow_run(run_id=workflow_run_id, cost_info=cost_info)
async def apply_workflow_run_usage_to_organization(
workflow_run, cost_info: dict | None
) -> None:
if cost_info is None:
return
org = await _get_pricing_organization(workflow_run)
if not org:
return
await _update_organization_usage(
org,
float(cost_info.get("dograh_token_usage") or 0),
float(cost_info.get("call_duration_seconds") or 0),
cost_info.get("charge_usd"),
)
async def apply_usage_delta_to_organization(
workflow_run, usage_info: dict | None
) -> dict | None:
org = await _get_pricing_organization(workflow_run)
if not org:
return None
cost_info = await _build_usage_cost_snapshot(usage_info, organization=org)
if cost_info is None:
return None
await _update_organization_usage(
org,
float(cost_info.get("dograh_token_usage") or 0),
float(cost_info.get("call_duration_seconds") or 0),
cost_info.get("charge_usd"),
)
return cost_info
async def calculate_workflow_run_cost(workflow_run_id: int):
logger.debug("Calculating cost for workflow run")
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
if not workflow_run:
logger.warning("Workflow run not found")
return
try:
cost_info = await build_workflow_run_cost_info(workflow_run)
if cost_info is None:
return
await save_workflow_run_cost_info(workflow_run_id, cost_info)
try:
await apply_workflow_run_usage_to_organization(workflow_run, cost_info)
except Exception as e:
org = await _get_pricing_organization(workflow_run)
if org:
logger.error(
f"Failed to update organization usage for org {org.id}: {e}"
)
else:
logger.error(f"Failed to update organization usage: {e}")
# Don't fail the whole cost calculation if usage update fails
logger.info(
f"Calculated cost for workflow run: ${cost_info['total_cost_usd']:.6f} USD ({cost_info['dograh_token_usage']} Dograh Tokens)"
)
except Exception as e:
logger.error(f"Error calculating cost for workflow run: {e}")
raise

View file

@ -5,15 +5,38 @@ across different endpoints (WebRTC signaling, telephony, public API triggers).
"""
from dataclasses import dataclass
from typing import Any
from loguru import logger
from api.constants import DEPLOYMENT_MODE
from api.db import db_client
from api.db.models import UserModel
from api.services.configuration.ai_model_configuration import (
get_effective_ai_model_configuration_for_workflow,
)
from api.services.configuration.registry import ServiceProviders
from api.services.configuration.resolve import resolve_effective_config
from api.services.managed_model_services import (
MPS_CORRELATION_ID_CONTEXT_KEY,
get_dograh_service_api_key,
uses_managed_model_services_v2,
)
from api.services.mps_service_key_client import mps_service_key_client
MINIMUM_DOGRAH_CREDITS_FOR_CALL = 0.10
LEGACY_QUOTA_EXCEEDED_MESSAGE = (
"You have exhausted your trial credits. "
"Please email founders@dograh.com for additional Dograh credits "
"or change providers in Models configurations."
)
BILLING_V2_QUOTA_EXCEEDED_MESSAGE = (
"You have exhausted your Dograh credits. "
"Please purchase more credits from /billing "
"or change providers in Models configurations."
)
@dataclass
class QuotaCheckResult:
@ -24,104 +47,359 @@ class QuotaCheckResult:
error_code: str = ""
async def check_dograh_quota(
user: UserModel, workflow_id: int | None = None
) -> QuotaCheckResult:
"""Check if user has sufficient Dograh quota for making a call.
This function checks if the user is using any Dograh services (LLM, STT, TTS)
and validates that they have sufficient credits remaining.
When ``workflow_id`` is provided, the workflow's per-workflow
``model_overrides`` are merged onto the user's global config so the quota
check runs against the credentials that will actually be used for the call
(rather than always falling back to the user's defaults).
Args:
user: The user to check quota for
workflow_id: Optional workflow whose ``model_overrides`` should be
applied when resolving the effective service config.
Returns:
QuotaCheckResult with has_quota=True if user has sufficient quota or
is not using Dograh services, or has_quota=False with error_message
if quota is insufficient.
"""
def _safe_float(value: Any, default: float = 0.0) -> float:
try:
# Get user configurations
user_config = await db_client.get_user_configurations(user.id)
return float(value)
except (TypeError, ValueError):
return default
if workflow_id is not None:
workflow = await db_client.get_workflow_by_id(workflow_id)
if workflow:
model_overrides = (workflow.workflow_configurations or {}).get(
"model_overrides"
def _insufficient_billing_v2_quota_result() -> QuotaCheckResult:
return QuotaCheckResult(
has_quota=False,
error_code="insufficient_credits",
error_message=BILLING_V2_QUOTA_EXCEEDED_MESSAGE,
)
def _insufficient_legacy_quota_result() -> QuotaCheckResult:
return QuotaCheckResult(
has_quota=False,
error_code="quota_exceeded",
error_message=LEGACY_QUOTA_EXCEEDED_MESSAGE,
)
def _service_uses_dograh(service: Any) -> bool:
provider = getattr(service, "provider", None)
return (
provider == ServiceProviders.DOGRAH or provider == ServiceProviders.DOGRAH.value
)
def _dograh_api_keys(user_config: Any) -> set[str]:
api_keys: set[str] = set()
for section_name in ("llm", "stt", "tts", "embeddings"):
service = getattr(user_config, section_name, None)
if not _service_uses_dograh(service):
continue
if hasattr(service, "get_all_api_keys"):
all_api_keys = [
api_key
for api_key in service.get_all_api_keys()
if isinstance(api_key, str) and api_key
]
if all_api_keys:
api_keys.update(all_api_keys)
continue
api_key = getattr(service, "api_key", None)
if api_key:
api_keys.add(api_key)
return api_keys
async def _store_run_correlation_id(
workflow_run_id: int | None,
correlation_id: str | None,
) -> None:
if not workflow_run_id or not correlation_id:
return
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
if not workflow_run:
logger.warning(
"Could not store MPS correlation id for missing workflow run {}",
workflow_run_id,
)
return
initial_context = dict(workflow_run.initial_context or {})
if initial_context.get(MPS_CORRELATION_ID_CONTEXT_KEY) == correlation_id:
return
initial_context[MPS_CORRELATION_ID_CONTEXT_KEY] = correlation_id
await db_client.update_workflow_run(
workflow_run_id,
initial_context=initial_context,
)
async def _authorize_hosted_workflow_run_start(
*,
workflow_owner: UserModel,
organization_id: int | None,
workflow_id: int | None,
workflow_run_id: int | None,
user_config: Any,
) -> tuple[QuotaCheckResult, bool]:
"""Authorize hosted v2 billing and return whether MPS handled enforcement."""
if DEPLOYMENT_MODE == "oss" or organization_id is None:
return QuotaCheckResult(has_quota=True), False
requires_correlation = bool(
workflow_run_id and uses_managed_model_services_v2(user_config)
)
service_key = (
get_dograh_service_api_key(user_config) if requires_correlation else None
)
if requires_correlation and not service_key:
return (
QuotaCheckResult(
has_quota=False,
error_code="invalid_service_key",
error_message=(
"You have invalid keys in your model configuration. "
"Please validate the service keys."
),
),
True,
)
try:
authorization = await mps_service_key_client.authorize_workflow_run_start(
organization_id=organization_id,
workflow_run_id=workflow_run_id,
service_key=service_key,
require_correlation_id=requires_correlation,
minimum_credits=MINIMUM_DOGRAH_CREDITS_FOR_CALL,
created_by=(
str(workflow_owner.provider_id)
if workflow_owner.provider_id is not None
else None
),
metadata={
"dograh_user_id": str(workflow_owner.id),
"workflow_id": workflow_id,
},
)
except Exception as e:
logger.error(
"Failed to authorize workflow start with MPS for org {}: {}",
organization_id,
e,
)
return (
QuotaCheckResult(
has_quota=False,
error_code="quota_check_failed",
error_message="Could not verify Dograh credits. Please try again.",
),
True,
)
billing_mode = authorization.get("billing_mode")
if billing_mode != "v2":
return QuotaCheckResult(has_quota=True), False
remaining = _safe_float(authorization.get("remaining_credits"))
if (
not authorization.get("allowed", False)
or remaining < MINIMUM_DOGRAH_CREDITS_FOR_CALL
):
logger.warning(
"Insufficient Dograh billing v2 credits for org {}: {:.2f} credits remaining",
organization_id,
remaining,
)
return _insufficient_billing_v2_quota_result(), True
try:
await _store_run_correlation_id(
workflow_run_id,
authorization.get("correlation_id"),
)
except Exception as e:
logger.error(
"Failed to store MPS correlation id for workflow_run_id {}: {}",
workflow_run_id,
e,
)
return (
QuotaCheckResult(
has_quota=False,
error_code="quota_check_failed",
error_message="Could not verify Dograh credits. Please try again.",
),
True,
)
logger.info(
"Dograh billing v2 run authorization passed for org {}: {:.2f} credits remaining",
organization_id,
remaining,
)
return QuotaCheckResult(has_quota=True), True
async def _authorize_legacy_dograh_keys(
*,
dograh_api_keys: set[str],
organization_id: int | None,
workflow_owner: UserModel,
) -> QuotaCheckResult:
for api_key in dograh_api_keys:
try:
usage = await mps_service_key_client.check_service_key_usage(
api_key,
organization_id=organization_id,
created_by=workflow_owner.provider_id,
)
remaining = usage.get("remaining_credits", 0.0)
# Require at least $0.10 for a short call
if remaining < MINIMUM_DOGRAH_CREDITS_FOR_CALL:
logger.warning(
f"Insufficient Dograh credits for key ...{api_key[-8:]}: "
f"${remaining:.2f} remaining"
)
if model_overrides:
user_config = resolve_effective_config(user_config, model_overrides)
return _insufficient_legacy_quota_result()
# Check if user is using any Dograh service
using_dograh = False
dograh_api_keys = set()
if user_config.llm and user_config.llm.provider == ServiceProviders.DOGRAH:
using_dograh = True
dograh_api_keys.add(user_config.llm.api_key)
if user_config.stt and user_config.stt.provider == ServiceProviders.DOGRAH:
using_dograh = True
dograh_api_keys.add(user_config.stt.api_key)
if user_config.tts and user_config.tts.provider == ServiceProviders.DOGRAH:
using_dograh = True
dograh_api_keys.add(user_config.tts.api_key)
# If not using Dograh, quota check passes
if not using_dograh:
return QuotaCheckResult(has_quota=True)
# Check quota for ALL Dograh keys
for api_key in dograh_api_keys:
try:
usage = await mps_service_key_client.check_service_key_usage(
api_key, created_by=user.provider_id
)
remaining = usage.get("remaining_credits", 0.0)
# Require at least $0.10 for a short call
if remaining < 0.10:
logger.warning(
f"Insufficient Dograh credits for key ...{api_key[-8:]}: "
f"${remaining:.2f} remaining"
)
return QuotaCheckResult(
has_quota=False,
error_code="quota_exceeded",
error_message=(
"You have exhausted your trial credits. "
"Please email founders@dograh.com for additional Dograh credits "
"or change providers in Models configurations."
),
)
logger.info(
f"Dograh quota check passed for key ...{api_key[-8:]}: "
f"{remaining:.2f} credits remaining"
)
except Exception as e:
logger.error(f"Failed to check quota for Dograh key: {str(e)}")
error_str = str(e)
if "404" in error_str or "not found" in error_str.lower():
return QuotaCheckResult(
has_quota=False,
error_code="invalid_service_key",
error_message="You have invalid keys in your model configuration. Please validate the service keys.",
)
logger.info(
f"Dograh quota check passed for key ...{api_key[-8:]}: "
f"{remaining:.2f} credits remaining"
)
except Exception as e:
logger.error(f"Failed to check quota for Dograh key: {str(e)}")
error_str = str(e)
if "404" in error_str or "not found" in error_str.lower():
return QuotaCheckResult(
has_quota=False,
error_code="quota_check_failed",
error_message="Could not verify Dograh credits. Please try again.",
error_code="invalid_service_key",
error_message="You have invalid keys in your model configuration. Please validate the service keys.",
)
return QuotaCheckResult(
has_quota=False,
error_code="quota_check_failed",
error_message="Could not verify Dograh credits. Please try again.",
)
return QuotaCheckResult(has_quota=True)
async def _authorize_oss_managed_v2_correlation(
*,
workflow_id: int,
workflow_run_id: int | None,
user_config: Any,
) -> QuotaCheckResult:
if not workflow_run_id or not uses_managed_model_services_v2(user_config):
return QuotaCheckResult(has_quota=True)
service_key = get_dograh_service_api_key(user_config)
if not service_key:
return QuotaCheckResult(
has_quota=False,
error_code="invalid_service_key",
error_message=(
"You have invalid keys in your model configuration. "
"Please validate the service keys."
),
)
try:
response = await mps_service_key_client.create_correlation_id(
service_key=service_key,
workflow_run_id=workflow_run_id,
)
await _store_run_correlation_id(
workflow_run_id,
response.get("correlation_id"),
)
except Exception as e:
logger.error(
"Failed to authorize OSS managed v2 workflow start for workflow {} run {}: {}",
workflow_id,
workflow_run_id,
e,
)
return QuotaCheckResult(
has_quota=False,
error_code="quota_check_failed",
error_message="Could not verify Dograh credits. Please try again.",
)
return QuotaCheckResult(has_quota=True)
async def authorize_workflow_run_start(
*,
workflow_id: int,
workflow_run_id: int | None = None,
actor_user: UserModel | None = None,
) -> QuotaCheckResult:
"""Authorize a workflow run before any billable call/text runtime starts.
The workflow organization is the billing subject for hosted v2. The workflow
owner is used only to resolve the effective model configuration and legacy
service-key metadata.
"""
try:
workflow = await db_client.get_workflow_by_id(workflow_id)
if not workflow:
return QuotaCheckResult(
has_quota=False,
error_code="workflow_not_found",
error_message="Workflow not found",
)
actor_org_id = getattr(actor_user, "selected_organization_id", None)
if actor_org_id is not None and actor_org_id != workflow.organization_id:
logger.warning(
"Workflow start authorization denied: actor org {} does not match workflow {} org {}",
actor_org_id,
workflow_id,
workflow.organization_id,
)
return QuotaCheckResult(
has_quota=False,
error_code="workflow_not_found",
error_message="Workflow not found",
)
workflow_owner = await db_client.get_user_by_id(workflow.user_id)
if not workflow_owner:
return QuotaCheckResult(
has_quota=False,
error_code="user_not_found",
error_message="User not found",
)
user_config = await get_effective_ai_model_configuration_for_workflow(
user_id=workflow_owner.id,
organization_id=workflow.organization_id,
workflow_configurations=workflow.workflow_configurations,
)
if DEPLOYMENT_MODE != "oss":
hosted_result, hosted_enforced = await _authorize_hosted_workflow_run_start(
workflow_owner=workflow_owner,
organization_id=workflow.organization_id,
workflow_id=workflow.id,
workflow_run_id=workflow_run_id,
user_config=user_config,
)
if hosted_enforced or not hosted_result.has_quota:
return hosted_result
dograh_api_keys = _dograh_api_keys(user_config)
if not dograh_api_keys:
return QuotaCheckResult(has_quota=True)
legacy_result = await _authorize_legacy_dograh_keys(
dograh_api_keys=dograh_api_keys,
organization_id=(
None if DEPLOYMENT_MODE == "oss" else workflow.organization_id
),
workflow_owner=workflow_owner,
)
if not legacy_result.has_quota:
return legacy_result
if DEPLOYMENT_MODE == "oss":
return await _authorize_oss_managed_v2_correlation(
workflow_id=workflow.id,
workflow_run_id=workflow_run_id,
user_config=user_config,
)
return QuotaCheckResult(has_quota=True)
@ -129,30 +407,3 @@ async def check_dograh_quota(
logger.error(f"Error during quota check: {str(e)}")
# On unexpected error, allow the call to proceed
return QuotaCheckResult(has_quota=True)
async def check_dograh_quota_by_user_id(
user_id: int, workflow_id: int | None = None
) -> QuotaCheckResult:
"""Check Dograh quota by user ID.
Convenience function that fetches the user and then checks quota. When
``workflow_id`` is provided, the workflow's ``model_overrides`` are
applied so the quota check evaluates the credentials that will actually
be used for the call.
Args:
user_id: The ID of the user to check quota for
workflow_id: Optional workflow whose per-workflow overrides should
be applied to the user's config before checking quota.
Returns:
QuotaCheckResult with quota status
"""
user = await db_client.get_user_by_id(user_id)
if not user:
return QuotaCheckResult(
has_quota=False,
error_message="User not found",
)
return await check_dograh_quota(user, workflow_id=workflow_id)

View file

@ -53,7 +53,7 @@ def build_run_report_csv(runs: List[Any]) -> io.StringIO:
for run in runs:
initial = run.initial_context or {}
gathered = run.gathered_context or {}
cost = run.cost_info or {}
usage = run.usage_info or {}
call_tags = gathered.get("call_tags", [])
if isinstance(call_tags, list):
@ -67,7 +67,7 @@ def build_run_report_csv(runs: List[Any]) -> io.StringIO:
run.created_at.isoformat() if run.created_at else "",
initial.get("phone_number", ""),
gathered.get("mapped_call_disposition", ""),
cost.get("call_duration_seconds", ""),
usage.get("call_duration_seconds", ""),
]
extracted = gathered.get("extracted_variables", {})

View file

@ -26,7 +26,7 @@ from loguru import logger
from api.constants import REDIS_URL
from api.db import db_client
from api.enums import CallType, WorkflowRunMode
from api.services.quota_service import check_dograh_quota_by_user_id
from api.services.quota_service import authorize_workflow_run_start
from api.services.telephony.call_transfer_manager import get_call_transfer_manager
from api.services.telephony.transfer_event_protocol import (
TransferEvent,
@ -564,19 +564,7 @@ class ARIConnection:
user_id = workflow.user_id
# 3. Check quota (apply per-workflow model_overrides).
quota_result = await check_dograh_quota_by_user_id(
user_id, workflow_id=inbound_workflow_id
)
if not quota_result.has_quota:
logger.warning(
f"[ARI org={self.organization_id}] Quota exceeded for user {user_id} "
f"— hanging up inbound call {channel_id}"
)
await self._delete_channel(channel_id)
return
# 4. Create workflow run
# 3. Create workflow run
call_id = channel_id
workflow_run = await db_client.create_workflow_run(
name=f"ARI Inbound {caller_number}",
@ -602,6 +590,20 @@ class ARIConnection:
f"(caller={caller_number}, called={called_number})"
)
# 4. Check quota after the run exists so hosted v2 can mint and
# store the MPS correlation id before the pipeline starts.
quota_result = await authorize_workflow_run_start(
workflow_id=inbound_workflow_id,
workflow_run_id=workflow_run.id,
)
if not quota_result.has_quota:
logger.warning(
f"[ARI org={self.organization_id}] Quota exceeded for user {user_id} "
f"— hanging up inbound call {channel_id}"
)
await self._delete_channel(channel_id)
return
# 5. Answer the inbound channel
await self._answer_channel(channel_id)

View file

@ -103,7 +103,8 @@ async def handle_cloudonix_cdr(request: Request):
return {"status": "error", "message": "Missing domain field"}
# Extract call_id to find workflow run
call_id = cdr_data.get("session").get("token")
session = cdr_data.get("session")
call_id = session.get("token") if isinstance(session, dict) else None
logger.info(f"Cloudonix CDR data for call id {call_id} - {cdr_data}")
if not call_id:
logger.warning("Cloudonix CDR missing call_id field")

View file

@ -6,9 +6,8 @@ provider registry — see ProviderSpec.router.
import json
from datetime import UTC, datetime
from typing import Optional
from fastapi import APIRouter, Header, Request
from fastapi import APIRouter, HTTPException, Request
from loguru import logger
from pipecat.utils.run_context import set_current_run_id
from starlette.responses import HTMLResponse
@ -29,6 +28,30 @@ from api.utils.telephony_helper import (
router = APIRouter()
async def _verify_vobiz_callback(
provider,
webhook_url: str,
callback_data: dict,
headers: dict,
raw_body: str,
*,
log_prefix: str,
) -> None:
"""Verify a Vobiz callback signature, failing closed.
Vobiz signs every callback, so a missing signature header is an invalid
request ``provider.verify_inbound_signature`` returns ``False`` for both
missing and forged signatures. Reject with HTTP 403 (per Vobiz's
callback-validation docs) so the caller never reaches status processing.
"""
is_valid = await provider.verify_inbound_signature(
webhook_url, callback_data, headers, raw_body
)
if not is_valid:
logger.warning(f"{log_prefix} Invalid or missing Vobiz callback signature")
raise HTTPException(status_code=403, detail="Invalid webhook signature")
@router.post("/vobiz-xml", include_in_schema=False)
async def handle_vobiz_xml_webhook(
workflow_id: int, user_id: int, workflow_run_id: int, organization_id: int
@ -65,8 +88,6 @@ async def handle_vobiz_xml_webhook(
async def handle_vobiz_hangup_callback(
workflow_run_id: int,
request: Request,
x_vobiz_signature: Optional[str] = Header(None),
x_vobiz_timestamp: Optional[str] = Header(None),
):
"""Handle Vobiz hangup callback (sent when call ends).
@ -75,82 +96,23 @@ async def handle_vobiz_hangup_callback(
"""
set_current_run_id(workflow_run_id)
# Logging all headers and body to understand what Vobiz actually sends
all_headers = dict(request.headers)
logger.info(
f"[run {workflow_run_id}] Vobiz hangup callback - Headers: {json.dumps(all_headers)}"
)
# Parse the callback data from the raw body so signed webhooks can verify
# the exact bytes Vobiz sent without draining the request stream first.
callback_data, raw_body = await parse_webhook_request(request)
# TODO: Remove this debug logging after Vobiz team clarifies webhook authentication
logger.info(
f"[run {workflow_run_id}] Vobiz hangup callback - Body: {json.dumps(callback_data)}"
)
logger.info(
f"[run {workflow_run_id}] Received Vobiz hangup callback {json.dumps(callback_data)}"
)
# Verify signature if Vobiz provided any supported signature header.
has_vobiz_signature = any(
header in all_headers
for header in (
"x-vobiz-signature-v3",
"x-vobiz-signature-ma-v3",
"x-vobiz-signature-v2",
"x-vobiz-signature-ma-v2",
)
)
if has_vobiz_signature:
# We need the workflow run to get organization for provider credentials
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
if not workflow_run:
logger.warning(
f"[run {workflow_run_id}] Workflow run not found for signature verification"
)
return {"status": "error", "reason": "workflow_run_not_found"}
workflow = await db_client.get_workflow_by_id(workflow_run.workflow_id)
if not workflow:
logger.warning(
f"[run {workflow_run_id}] Workflow not found for signature verification"
)
return {"status": "error", "reason": "workflow_not_found"}
provider = await get_telephony_provider_for_run(
workflow_run, workflow.organization_id
)
# Verify signature
backend_endpoint, _ = await get_backend_endpoints()
webhook_url = f"{backend_endpoint}/api/v1/telephony/vobiz/hangup-callback/{workflow_run_id}"
is_valid = await provider.verify_inbound_signature(
webhook_url,
callback_data,
all_headers,
raw_body,
)
if not is_valid:
logger.warning(
f"[run {workflow_run_id}] Invalid Vobiz hangup callback signature"
)
return {"status": "error", "reason": "invalid_signature"}
logger.info(f"[run {workflow_run_id}] Vobiz hangup callback signature verified")
else:
# Get workflow run for processing (signature verification already got it if needed)
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
if not workflow_run:
logger.warning(
f"[run {workflow_run_id}] Workflow run not found for Vobiz hangup callback"
)
return {"status": "ignored", "reason": "workflow_run_not_found"}
# Get workflow and provider
workflow = await db_client.get_workflow_by_id(workflow_run.workflow_id)
if not workflow:
logger.warning(f"[run {workflow_run_id}] Workflow not found")
@ -160,6 +122,21 @@ async def handle_vobiz_hangup_callback(
workflow_run, workflow.organization_id
)
# Fail closed: Vobiz signs every callback, so reject unsigned/forged ones
# before they can mutate call state.
backend_endpoint, _ = await get_backend_endpoints()
webhook_url = (
f"{backend_endpoint}/api/v1/telephony/vobiz/hangup-callback/{workflow_run_id}"
)
await _verify_vobiz_callback(
provider,
webhook_url,
callback_data,
all_headers,
raw_body,
log_prefix=f"[run {workflow_run_id}]",
)
logger.debug(
f"[run {workflow_run_id}] Processing Vobiz hangup with provider: {provider.PROVIDER_NAME}"
)
@ -167,10 +144,6 @@ async def handle_vobiz_hangup_callback(
# Parse the callback data into generic format
parsed_data = provider.parse_status_callback(callback_data)
logger.debug(
f"[run {workflow_run_id}] Parsed Vobiz callback data: {json.dumps(parsed_data)}"
)
# Create StatusCallbackRequest from parsed data
status_update = StatusCallbackRequest(
call_id=parsed_data["call_id"],
@ -194,8 +167,6 @@ async def handle_vobiz_hangup_callback(
async def handle_vobiz_ring_callback(
workflow_run_id: int,
request: Request,
x_vobiz_signature: Optional[str] = Header(None),
x_vobiz_timestamp: Optional[str] = Header(None),
):
"""Handle Vobiz ring callback (sent when call starts ringing).
@ -204,84 +175,46 @@ async def handle_vobiz_ring_callback(
"""
set_current_run_id(workflow_run_id)
# Logging all headers and body to understand what Vobiz actually sends
all_headers = dict(request.headers)
logger.info(
f"[run {workflow_run_id}] Vobiz ring callback - Headers: {json.dumps(all_headers)}"
)
# Parse the callback data from the raw body so signed webhooks can verify
# the exact bytes Vobiz sent without draining the request stream first.
callback_data, raw_body = await parse_webhook_request(request)
# TODO: Remove this debug logging after Vobiz team clarifies webhook authentication
logger.info(
f"[run {workflow_run_id}] Vobiz ring callback - Body: {json.dumps(callback_data)}"
)
logger.info(
f"[run {workflow_run_id}] Received Vobiz ring callback {json.dumps(callback_data)}"
)
# Verify signature if Vobiz provided any supported signature header.
has_vobiz_signature = any(
header in all_headers
for header in (
"x-vobiz-signature-v3",
"x-vobiz-signature-ma-v3",
"x-vobiz-signature-v2",
"x-vobiz-signature-ma-v2",
)
)
if has_vobiz_signature:
# We need the workflow run to get organization for provider credentials
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
if not workflow_run:
logger.warning(
f"[run {workflow_run_id}] Workflow run not found for signature verification"
)
return {"status": "error", "reason": "workflow_run_not_found"}
workflow = await db_client.get_workflow_by_id(workflow_run.workflow_id)
if not workflow:
logger.warning(
f"[run {workflow_run_id}] Workflow not found for signature verification"
)
return {"status": "error", "reason": "workflow_not_found"}
provider = await get_telephony_provider_for_run(
workflow_run, workflow.organization_id
)
# Verify signature
backend_endpoint, _ = await get_backend_endpoints()
webhook_url = (
f"{backend_endpoint}/api/v1/telephony/vobiz/ring-callback/{workflow_run_id}"
)
is_valid = await provider.verify_inbound_signature(
webhook_url,
callback_data,
all_headers,
raw_body,
)
if not is_valid:
logger.warning(
f"[run {workflow_run_id}] Invalid Vobiz ring callback signature"
)
return {"status": "error", "reason": "invalid_signature"}
logger.info(f"[run {workflow_run_id}] Vobiz ring callback signature verified")
else:
# Get workflow run for processing (signature verification already got it if needed)
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
if not workflow_run:
logger.warning(
f"[run {workflow_run_id}] Workflow run not found for Vobiz ring callback"
)
return {"status": "ignored", "reason": "workflow_run_not_found"}
workflow = await db_client.get_workflow_by_id(workflow_run.workflow_id)
if not workflow:
logger.warning(f"[run {workflow_run_id}] Workflow not found")
return {"status": "ignored", "reason": "workflow_not_found"}
provider = await get_telephony_provider_for_run(
workflow_run, workflow.organization_id
)
# Fail closed: reject unsigned/forged ring callbacks before logging them.
backend_endpoint, _ = await get_backend_endpoints()
webhook_url = (
f"{backend_endpoint}/api/v1/telephony/vobiz/ring-callback/{workflow_run_id}"
)
await _verify_vobiz_callback(
provider,
webhook_url,
callback_data,
all_headers,
raw_body,
log_prefix=f"[run {workflow_run_id}]",
)
# Log the ringing event
telephony_callback_logs = workflow_run.logs.get("telephony_status_callbacks", [])
ring_log = {
@ -308,15 +241,10 @@ async def handle_vobiz_ring_callback(
async def handle_vobiz_hangup_callback_by_workflow(
workflow_id: int,
request: Request,
x_vobiz_signature: Optional[str] = Header(None),
x_vobiz_timestamp: Optional[str] = Header(None),
):
"""Handle Vobiz hangup callback with workflow_id - finds workflow run by call_id."""
all_headers = dict(request.headers)
logger.info(
f"[workflow {workflow_id}] Vobiz hangup callback - Headers: {json.dumps(all_headers)}"
)
try:
callback_data, raw_body = await parse_webhook_request(request)
@ -364,35 +292,18 @@ async def handle_vobiz_hangup_callback_by_workflow(
workflow_run, workflow.organization_id
)
has_vobiz_signature = any(
header in all_headers
for header in (
"x-vobiz-signature-v3",
"x-vobiz-signature-ma-v3",
"x-vobiz-signature-v2",
"x-vobiz-signature-ma-v2",
)
# Fail closed: Vobiz signs every callback, so reject unsigned/forged ones
# before they can mutate call state.
backend_endpoint, _ = await get_backend_endpoints()
webhook_url = f"{backend_endpoint}/api/v1/telephony/vobiz/hangup-callback/workflow/{workflow_id}"
await _verify_vobiz_callback(
provider,
webhook_url,
callback_data,
all_headers,
raw_body,
log_prefix=f"[workflow {workflow_id}]",
)
if has_vobiz_signature:
backend_endpoint, _ = await get_backend_endpoints()
webhook_url = f"{backend_endpoint}/api/v1/telephony/vobiz/hangup-callback/workflow/{workflow_id}"
is_valid = await provider.verify_inbound_signature(
webhook_url,
callback_data,
all_headers,
raw_body,
)
if not is_valid:
logger.warning(
f"[workflow {workflow_id}] Invalid Vobiz hangup callback signature"
)
return {"status": "error", "message": "invalid_signature"}
logger.info(
f"[workflow {workflow_id}] Vobiz hangup callback signature verified"
)
try:
parsed_data = provider.parse_status_callback(callback_data)

View file

@ -66,34 +66,6 @@ async def handle_vonage_events(
logger.error(f"[run {workflow_run_id}] Workflow run not found")
return {"status": "error", "message": "Workflow run not found"}
# For a completed call that includes cost info, capture it immediately
if event_data.get("status") == "completed":
# Vonage sometimes includes price info in the webhook
if "price" in event_data or "rate" in event_data:
try:
if workflow_run.cost_info:
# Store immediate cost info if available
cost_info = workflow_run.cost_info.copy()
if "price" in event_data:
cost_info["vonage_webhook_price"] = float(event_data["price"])
if "rate" in event_data:
cost_info["vonage_webhook_rate"] = float(event_data["rate"])
if "duration" in event_data:
cost_info["vonage_webhook_duration"] = int(
event_data["duration"]
)
await db_client.update_workflow_run(
run_id=workflow_run_id, cost_info=cost_info
)
logger.info(
f"[run {workflow_run_id}] Captured Vonage cost info from webhook"
)
except Exception as e:
logger.error(
f"[run {workflow_run_id}] Failed to capture Vonage cost from webhook: {e}"
)
# Get workflow and provider
workflow = await db_client.get_workflow_by_id(workflow_run.workflow_id)
if not workflow:

View file

@ -114,11 +114,13 @@ class StatusCallbackRequest(BaseModel):
"NOANSWER": "no-answer",
}
disposition = data.get("disposition", "")
disposition = data.get("disposition") or ""
status = disposition_map.get(disposition.upper(), disposition.lower())
session = data.get("session")
call_id = session.get("token") if isinstance(session, dict) else ""
return cls(
call_id=data.get("session").get("token"),
call_id=call_id or "",
status=status,
from_number=data.get("from"),
to_number=data.get("to"),

View file

@ -35,6 +35,7 @@ import asyncio
from loguru import logger
from api.services.managed_model_services import MPS_CORRELATION_ID_CONTEXT_KEY
from api.services.workflow import pipecat_engine_callbacks as engine_callbacks
from api.services.workflow.mcp_tool_session import McpToolSession
from api.services.workflow.pipecat_engine_context_composer import (
@ -382,6 +383,9 @@ class PipecatEngine:
embeddings_provider=self._embeddings_provider,
embeddings_endpoint=self._embeddings_endpoint,
embeddings_api_version=self._embeddings_api_version,
correlation_id=self._call_context_vars.get(
MPS_CORRELATION_ID_CONTEXT_KEY
),
tracing_context=self._get_otel_context(),
)

View file

@ -2,7 +2,6 @@
import random
from api.db import db_client
from api.db.models import WorkflowRunModel
from api.services.workflow.dto import QANodeData
@ -43,7 +42,7 @@ async def resolve_llm_config(
async def resolve_user_llm_config(
workflow_run: WorkflowRunModel,
) -> tuple[str, str, str, dict]:
"""Resolve the user's configured LLM (from UserConfiguration).
"""Resolve the user's configured LLM (from EffectiveAIModelConfiguration).
Returns:
(provider, model, api_key, service_kwargs) tuple
@ -54,7 +53,27 @@ async def resolve_user_llm_config(
llm_config: dict = {}
if user_id:
user_configuration = await db_client.get_user_configurations(user_id)
from api.services.configuration.ai_model_configuration import (
get_effective_ai_model_configuration_for_workflow,
)
workflow_configurations = {}
if workflow_run.definition:
workflow_configurations = (
workflow_run.definition.workflow_configurations or {}
)
elif workflow_run.workflow:
workflow_configurations = (
workflow_run.workflow.workflow_configurations or {}
)
user_configuration = await get_effective_ai_model_configuration_for_workflow(
user_id=user_id,
organization_id=workflow_run.workflow.organization_id
if workflow_run.workflow
else None,
workflow_configurations=workflow_configurations,
)
llm_config = user_configuration.model_dump(exclude_none=True).get("llm", {})
provider = llm_config.get("provider", "openai")

View file

@ -0,0 +1,41 @@
"""Format workflow run usage for public API responses."""
def format_public_usage_info(usage_info: dict | None) -> dict | None:
if not usage_info:
return None
return {
"llm": usage_info.get("llm") or {},
"tts": usage_info.get("tts") or {},
"stt": usage_info.get("stt") or {},
"call_duration_seconds": usage_info.get("call_duration_seconds"),
}
def format_public_cost_info(
cost_info: dict | None, usage_info: dict | None
) -> dict | None:
"""Return the legacy response shape without doing local cost accounting."""
duration = None
if usage_info and usage_info.get("call_duration_seconds") is not None:
duration = int(round(usage_info.get("call_duration_seconds") or 0))
elif cost_info and cost_info.get("call_duration_seconds") is not None:
duration = int(round(cost_info.get("call_duration_seconds") or 0))
dograh_token_usage = 0
if cost_info:
if "dograh_token_usage" in cost_info:
dograh_token_usage = cost_info.get("dograh_token_usage") or 0
elif "total_cost_usd" in cost_info:
dograh_token_usage = round(
float(cost_info.get("total_cost_usd", 0)) * 100, 2
)
if duration is None and dograh_token_usage == 0:
return None
return {
"dograh_token_usage": dograh_token_usage,
"call_duration_seconds": duration,
}

View file

@ -32,7 +32,6 @@ from pipecat.utils.run_context import set_current_org_id
from api.db import db_client
from api.enums import WorkflowRunMode, WorkflowRunState
from api.services.configuration.resolve import resolve_effective_config
from api.services.pipecat.audio_config import create_audio_config
from api.services.pipecat.pipeline_builder import create_pipeline_task
from api.services.pipecat.pipeline_metrics_aggregator import (
@ -410,14 +409,31 @@ async def execute_text_chat_pending_turn(
run_definition = workflow_run.definition
run_configs = run_definition.workflow_configurations or {}
user_config = await db_client.get_user_configurations(workflow_run.workflow.user.id)
user_config = resolve_effective_config(
user_config, run_configs.get("model_overrides")
from api.services.configuration.ai_model_configuration import (
get_effective_ai_model_configuration_for_workflow,
)
user_config = await get_effective_ai_model_configuration_for_workflow(
user_id=workflow_run.workflow.user.id,
organization_id=workflow.organization_id,
workflow_configurations=run_configs,
)
if user_config.llm is None:
raise ValueError("Text chat requires an LLM configuration")
llm = create_llm_service(user_config)
from api.services.managed_model_services import (
MPS_CORRELATION_ID_CONTEXT_KEY,
ensure_mps_correlation_id,
)
base_initial_context = dict(workflow_run.initial_context or {})
mps_correlation_id = await ensure_mps_correlation_id(
ai_model_config=user_config,
workflow_run_id=workflow_run_id,
initial_context=base_initial_context,
)
llm = create_llm_service(user_config, correlation_id=mps_correlation_id)
inference_llm = llm
runtime_configuration = {
@ -425,9 +441,15 @@ async def execute_text_chat_pending_turn(
"llm_model": user_config.llm.model,
}
initial_context = {
**(workflow_run.initial_context or {}),
**base_initial_context,
"runtime_configuration": runtime_configuration,
}
if mps_correlation_id:
initial_context[MPS_CORRELATION_ID_CONTEXT_KEY] = mps_correlation_id
await db_client.update_workflow_run(
workflow_run_id,
initial_context=initial_context,
)
workflow_graph = WorkflowGraph(
ReactFlowDTO.model_validate(run_definition.workflow_json)
@ -466,9 +488,17 @@ async def execute_text_chat_pending_turn(
embeddings_model = None
embeddings_base_url = None
if user_config.embeddings:
from api.services.configuration.ai_model_configuration import (
apply_managed_embeddings_base_url,
)
embeddings_api_key = user_config.embeddings.api_key
embeddings_model = user_config.embeddings.model
embeddings_base_url = getattr(user_config.embeddings, "base_url", None)
embeddings_provider = getattr(user_config.embeddings, "provider", None)
embeddings_base_url = apply_managed_embeddings_base_url(
provider=embeddings_provider,
base_url=getattr(user_config.embeddings, "base_url", None),
)
has_recordings = await db_client.has_active_recordings(workflow.organization_id)
context_compaction_enabled = (workflow.workflow_configurations or {}).get(
@ -606,8 +636,10 @@ async def execute_text_chat_pending_turn(
"Transportless text chat pipeline failed while closing run {}",
workflow_run_id,
)
await engine.close_mcp_sessions()
await engine.cleanup()
raise
await engine.close_mcp_sessions()
await engine.cleanup()
gathered_context = await engine.get_gathered_context()

View file

@ -4,17 +4,11 @@ from datetime import UTC, datetime
from typing import Any
from uuid import uuid4
from loguru import logger
from api.db import db_client
from api.db.models import WorkflowRunTextSessionModel
from api.db.workflow_run_text_session_client import (
WorkflowRunTextSessionRevisionConflictError,
)
from api.services.pricing.workflow_run_cost import (
apply_usage_delta_to_organization,
build_workflow_run_cost_info,
)
from api.services.workflow.text_chat_logs import (
build_text_chat_realtime_feedback_events,
)
@ -261,20 +255,6 @@ async def execute_pending_text_chat_turn(
state=execution.state,
is_completed=execution.is_completed,
)
workflow_run = await db_client.get_workflow_run_by_id(run_id)
if workflow_run:
try:
# Apply the per-turn delta so org usage tracks cumulative run cost
# without replaying the full session totals on every turn.
await apply_usage_delta_to_organization(workflow_run, execution.usage)
except Exception as e:
logger.error(
f"Failed to update organization usage for text chat run {run_id}: {e}"
)
cost_info = await build_workflow_run_cost_info(workflow_run)
if cost_info is not None:
await db_client.update_workflow_run(run_id, cost_info=cost_info)
return await _reload_text_chat_session(run_id)

View file

@ -29,6 +29,7 @@ async def retrieve_from_knowledge_base(
embeddings_provider: Optional[str] = None,
embeddings_endpoint: Optional[str] = None,
embeddings_api_version: Optional[str] = None,
correlation_id: Optional[str] = None,
tracing_context=None,
) -> Dict[str, Any]:
"""Retrieve relevant information from the knowledge base using vector similarity search.
@ -75,6 +76,7 @@ async def retrieve_from_knowledge_base(
embeddings_provider,
embeddings_endpoint,
embeddings_api_version,
correlation_id,
)
# Create span with parent context
@ -115,6 +117,7 @@ async def retrieve_from_knowledge_base(
embeddings_provider,
embeddings_endpoint,
embeddings_api_version,
correlation_id,
)
# Add result metadata to span
@ -192,6 +195,7 @@ async def retrieve_from_knowledge_base(
embeddings_provider,
embeddings_endpoint,
embeddings_api_version,
correlation_id,
)
else:
# Tracing is disabled - perform retrieval without tracing
@ -206,6 +210,7 @@ async def retrieve_from_knowledge_base(
embeddings_provider,
embeddings_endpoint,
embeddings_api_version,
correlation_id,
)
@ -220,6 +225,7 @@ async def _perform_retrieval(
embeddings_provider: Optional[str] = None,
embeddings_endpoint: Optional[str] = None,
embeddings_api_version: Optional[str] = None,
correlation_id: Optional[str] = None,
) -> Dict[str, Any]:
"""Internal function to perform the actual retrieval operation.
@ -272,11 +278,20 @@ async def _perform_retrieval(
api_version=embeddings_api_version or "2024-02-15-preview",
)
else:
default_headers = None
if (
embeddings_provider == ServiceProviders.DOGRAH.value
and correlation_id
):
default_headers = {
"X-Dograh-Correlation-Id": correlation_id,
}
embedding_service = OpenAIEmbeddingService(
db_client=db_client,
api_key=embeddings_api_key,
model_id=embeddings_model or "text-embedding-3-small",
base_url=embeddings_base_url,
default_headers=default_headers,
)
results = await embedding_service.search_similar_chunks(

View file

@ -0,0 +1,111 @@
"""Workflow-run billing hooks.
Dograh does not rate or deduct credits locally. MPS owns credit accounting.
For hosted deployments, Dograh reports completed platform usage to MPS.
When a server-minted MPS correlation id exists, MPS uses model-service usage
as the canonical duration. Otherwise Dograh reports the completed run duration.
"""
from typing import Any
from loguru import logger
from api.constants import DEPLOYMENT_MODE
from api.db import db_client
from api.services.managed_model_services import get_mps_correlation_id
from api.services.mps_service_key_client import mps_service_key_client
def _workflow_run_organization_id(workflow_run) -> int | None:
workflow = getattr(workflow_run, "workflow", None)
return getattr(workflow, "organization_id", None)
def _duration_seconds_from_usage_info(workflow_run) -> float | None:
usage_info: dict[str, Any] = getattr(workflow_run, "usage_info", None) or {}
duration = usage_info.get("call_duration_seconds")
try:
duration_seconds = float(duration)
except (TypeError, ValueError):
return None
return duration_seconds if duration_seconds > 0 else None
async def _organization_uses_mps_billing_v2(organization_id: int) -> bool:
account = await mps_service_key_client.get_billing_account_status(
organization_id=organization_id
)
return bool(account and account.get("billing_mode") == "v2")
async def report_workflow_run_platform_usage(workflow_run) -> None:
"""Report hosted platform usage for a completed workflow run to MPS."""
if DEPLOYMENT_MODE == "oss":
return
if not getattr(workflow_run, "is_completed", False):
return
organization_id = _workflow_run_organization_id(workflow_run)
if organization_id is None:
logger.warning(
"Skipping platform usage report for workflow run {}: no organization_id",
workflow_run.id,
)
return
correlation_id = get_mps_correlation_id(
getattr(workflow_run, "initial_context", None)
)
duration_seconds = (
None if correlation_id else _duration_seconds_from_usage_info(workflow_run)
)
if not correlation_id and duration_seconds is None:
logger.warning(
"Skipping platform usage report for workflow run {}: no billable duration",
workflow_run.id,
)
return
try:
if not await _organization_uses_mps_billing_v2(organization_id):
return
result = await mps_service_key_client.report_platform_usage(
organization_id=organization_id,
correlation_id=correlation_id,
duration_seconds=duration_seconds,
workflow_run_id=workflow_run.id,
metadata={
"source": "workflow_run_completion",
"workflow_id": getattr(workflow_run, "workflow_id", None),
"duration_source": (
"mps_correlation" if correlation_id else "dograh_usage_info"
),
},
)
logger.info(
"Reported platform usage for workflow run {} to MPS: {}",
workflow_run.id,
result,
)
except Exception as e:
logger.error(
"Failed to report platform usage for workflow run {}: {}",
workflow_run.id,
e,
)
async def report_completed_workflow_run_platform_usage(workflow_run_id: int) -> None:
"""Load a completed workflow run and report platform usage to MPS."""
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
if not workflow_run:
logger.warning(
"Skipping platform usage report: workflow run {} not found",
workflow_run_id,
)
return
await report_workflow_run_platform_usage(workflow_run)

View file

@ -45,10 +45,8 @@ from api.tasks.campaign_tasks import (
)
from api.tasks.knowledge_base_processing import process_knowledge_base_document
from api.tasks.run_integrations import run_integrations_post_workflow_run
from api.tasks.s3_upload import (
process_workflow_completion,
upload_voicemail_audio_to_s3,
)
from api.tasks.s3_upload import upload_voicemail_audio_to_s3
from api.tasks.workflow_completion import process_workflow_completion
class WorkerSettings:

View file

@ -157,15 +157,31 @@ async def process_knowledge_base_document(
embeddings_endpoint = None
embeddings_api_version = None
if document.created_by:
user_config = await db_client.get_user_configurations(document.created_by)
if user_config.embeddings:
embeddings_provider = getattr(user_config.embeddings, "provider", None)
embeddings_api_key = user_config.embeddings.api_key
embeddings_model = user_config.embeddings.model
embeddings_base_url = getattr(user_config.embeddings, "base_url", None)
embeddings_endpoint = getattr(user_config.embeddings, "endpoint", None)
from api.services.configuration.ai_model_configuration import (
apply_managed_embeddings_base_url,
get_resolved_ai_model_configuration,
)
resolved_config = await get_resolved_ai_model_configuration(
user_id=document.created_by,
organization_id=document.organization_id,
)
effective_config = resolved_config.effective
if effective_config.embeddings:
embeddings_provider = getattr(
effective_config.embeddings, "provider", None
)
embeddings_api_key = effective_config.embeddings.api_key
embeddings_model = effective_config.embeddings.model
embeddings_base_url = apply_managed_embeddings_base_url(
provider=embeddings_provider,
base_url=getattr(effective_config.embeddings, "base_url", None),
)
embeddings_endpoint = getattr(
effective_config.embeddings, "endpoint", None
)
embeddings_api_version = getattr(
user_config.embeddings, "api_version", None
effective_config.embeddings, "api_version", None
)
logger.info(
f"Using user embeddings config: provider={embeddings_provider}, "

View file

@ -1,13 +1,9 @@
import os
from typing import Optional
from loguru import logger
from pipecat.utils.run_context import set_current_run_id
from api.db import db_client
from api.services.pricing.workflow_run_cost import calculate_workflow_run_cost
from api.services.storage import get_current_storage_backend, storage_fs
from api.tasks.run_integrations import run_integrations_post_workflow_run
from api.services.storage import storage_fs
async def upload_voicemail_audio_to_s3(
@ -69,110 +65,3 @@ async def upload_voicemail_audio_to_s3(
logger.warning(
f"Failed to clean up temp voicemail audio file {temp_file_path}: {e}"
)
async def process_workflow_completion(
_ctx,
workflow_run_id: int,
audio_temp_path: Optional[str] = None,
transcript_temp_path: Optional[str] = None,
):
"""Process workflow completion: upload artifacts and run integrations.
This task combines audio upload, transcript upload, and webhook integrations
into a single sequential task to ensure integrations run after uploads complete.
Args:
_ctx: ARQ context (unused)
workflow_run_id: The workflow run ID
audio_temp_path: Optional path to temp audio file
transcript_temp_path: Optional path to temp transcript file
"""
run_id = str(workflow_run_id)
set_current_run_id(run_id)
logger.info(f"Processing workflow completion for run {workflow_run_id}")
storage_backend = get_current_storage_backend()
# Step 1: Upload audio if provided
if audio_temp_path:
try:
if os.path.exists(audio_temp_path):
file_size = os.path.getsize(audio_temp_path)
logger.debug(f"Audio file size: {file_size} bytes")
recording_url = f"recordings/{workflow_run_id}.wav"
logger.info(
f"Uploading audio to {storage_backend.name} - workflow_run_id: {workflow_run_id}"
)
await storage_fs.aupload_file(audio_temp_path, recording_url)
await db_client.update_workflow_run(
run_id=workflow_run_id,
recording_url=recording_url,
storage_backend=storage_backend.value,
)
logger.info(f"Successfully uploaded audio: {recording_url}")
else:
logger.warning(f"Audio temp file not found: {audio_temp_path}")
except Exception as e:
logger.error(f"Error uploading audio for workflow {workflow_run_id}: {e}")
finally:
if audio_temp_path and os.path.exists(audio_temp_path):
try:
os.remove(audio_temp_path)
logger.debug(f"Cleaned up temp audio file: {audio_temp_path}")
except Exception as e:
logger.warning(f"Failed to clean up temp audio file: {e}")
# Step 2: Upload transcript if provided
if transcript_temp_path:
try:
if os.path.exists(transcript_temp_path):
file_size = os.path.getsize(transcript_temp_path)
logger.debug(f"Transcript file size: {file_size} bytes")
transcript_url = f"transcripts/{workflow_run_id}.txt"
logger.info(
f"Uploading transcript to {storage_backend.name} - workflow_run_id: {workflow_run_id}"
)
await storage_fs.aupload_file(transcript_temp_path, transcript_url)
await db_client.update_workflow_run(
run_id=workflow_run_id,
transcript_url=transcript_url,
storage_backend=storage_backend.value,
)
logger.info(f"Successfully uploaded transcript: {transcript_url}")
else:
logger.warning(
f"Transcript temp file not found: {transcript_temp_path}"
)
except Exception as e:
logger.error(
f"Error uploading transcript for workflow {workflow_run_id}: {e}"
)
finally:
if transcript_temp_path and os.path.exists(transcript_temp_path):
try:
os.remove(transcript_temp_path)
logger.debug(
f"Cleaned up temp transcript file: {transcript_temp_path}"
)
except Exception as e:
logger.warning(f"Failed to clean up temp transcript file: {e}")
# Step 3: Run integrations including QA analysis (after uploads are complete)
try:
await run_integrations_post_workflow_run(_ctx, workflow_run_id)
except Exception as e:
logger.error(f"Error running integrations for workflow {workflow_run_id}: {e}")
# Step 4: Calculate cost after integrations (so QA token usage is included)
try:
await calculate_workflow_run_cost(workflow_run_id)
except Exception as e:
logger.error(f"Error calculating cost for workflow {workflow_run_id}: {e}")
logger.info(f"Completed workflow completion processing for run {workflow_run_id}")

View file

@ -0,0 +1,121 @@
import os
from typing import Optional
from loguru import logger
from pipecat.utils.run_context import set_current_run_id
from api.db import db_client
from api.services.storage import get_current_storage_backend, storage_fs
from api.services.workflow_run_billing import (
report_completed_workflow_run_platform_usage,
)
from api.tasks.run_integrations import run_integrations_post_workflow_run
async def process_workflow_completion(
_ctx,
workflow_run_id: int,
audio_temp_path: Optional[str] = None,
transcript_temp_path: Optional[str] = None,
):
"""Process workflow completion: upload artifacts and run integrations.
This task combines audio upload, transcript upload, and webhook integrations
into a single sequential task to ensure integrations run after uploads complete.
Args:
_ctx: ARQ context (unused)
workflow_run_id: The workflow run ID
audio_temp_path: Optional path to temp audio file
transcript_temp_path: Optional path to temp transcript file
"""
run_id = str(workflow_run_id)
set_current_run_id(run_id)
logger.info(f"Processing workflow completion for run {workflow_run_id}")
storage_backend = get_current_storage_backend()
# Step 1: Upload audio if provided
if audio_temp_path:
try:
if os.path.exists(audio_temp_path):
file_size = os.path.getsize(audio_temp_path)
logger.debug(f"Audio file size: {file_size} bytes")
recording_url = f"recordings/{workflow_run_id}.wav"
logger.info(
f"Uploading audio to {storage_backend.name} - workflow_run_id: {workflow_run_id}"
)
await storage_fs.aupload_file(audio_temp_path, recording_url)
await db_client.update_workflow_run(
run_id=workflow_run_id,
recording_url=recording_url,
storage_backend=storage_backend.value,
)
logger.info(f"Successfully uploaded audio: {recording_url}")
else:
logger.warning(f"Audio temp file not found: {audio_temp_path}")
except Exception as e:
logger.error(f"Error uploading audio for workflow {workflow_run_id}: {e}")
finally:
if audio_temp_path and os.path.exists(audio_temp_path):
try:
os.remove(audio_temp_path)
logger.debug(f"Cleaned up temp audio file: {audio_temp_path}")
except Exception as e:
logger.warning(f"Failed to clean up temp audio file: {e}")
# Step 2: Upload transcript if provided
if transcript_temp_path:
try:
if os.path.exists(transcript_temp_path):
file_size = os.path.getsize(transcript_temp_path)
logger.debug(f"Transcript file size: {file_size} bytes")
transcript_url = f"transcripts/{workflow_run_id}.txt"
logger.info(
f"Uploading transcript to {storage_backend.name} - workflow_run_id: {workflow_run_id}"
)
await storage_fs.aupload_file(transcript_temp_path, transcript_url)
await db_client.update_workflow_run(
run_id=workflow_run_id,
transcript_url=transcript_url,
storage_backend=storage_backend.value,
)
logger.info(f"Successfully uploaded transcript: {transcript_url}")
else:
logger.warning(
f"Transcript temp file not found: {transcript_temp_path}"
)
except Exception as e:
logger.error(
f"Error uploading transcript for workflow {workflow_run_id}: {e}"
)
finally:
if transcript_temp_path and os.path.exists(transcript_temp_path):
try:
os.remove(transcript_temp_path)
logger.debug(
f"Cleaned up temp transcript file: {transcript_temp_path}"
)
except Exception as e:
logger.warning(f"Failed to clean up temp transcript file: {e}")
# Step 3: Run integrations including QA analysis (after uploads are complete)
try:
await run_integrations_post_workflow_run(_ctx, workflow_run_id)
except Exception as e:
logger.error(f"Error running integrations for workflow {workflow_run_id}: {e}")
# Step 4: Notify MPS after completion. MPS owns credit accounting.
try:
await report_completed_workflow_run_platform_usage(workflow_run_id)
except Exception as e:
logger.error(
f"Error reporting platform usage for workflow {workflow_run_id}: {e}"
)
logger.info(f"Completed workflow completion processing for run {workflow_run_id}")

View file

@ -203,7 +203,7 @@ async def create_workflow_run_rows(
Returns:
Tuple of (workflow_run, user, workflow).
"""
from api.schemas.user_configuration import UserConfiguration
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
org = OrganizationModel(provider_id=f"test-org-{provider_id_suffix}")
async_session.add(org)
@ -218,7 +218,7 @@ async def create_workflow_run_rows(
await db_session.update_user_configuration(
user_id=user.id,
configuration=UserConfiguration.model_validate(USER_CONFIGURATION),
configuration=EffectiveAIModelConfiguration.model_validate(USER_CONFIGURATION),
)
workflow = await db_session.create_workflow(

View file

@ -0,0 +1,119 @@
"""Regression tests for Cloudonix CDR webhook handling.
A Cloudonix CDR webhook is a public, unauthenticated endpoint that parses
arbitrary external JSON. A partial / malformed payload (missing ``session``,
or a ``null`` ``session`` / ``disposition``) must produce a graceful error
response, not an unhandled ``AttributeError`` (HTTP 500).
"""
from unittest.mock import AsyncMock, patch
import pytest
from starlette.requests import Request
from api.services.telephony.providers.cloudonix.routes import handle_cloudonix_cdr
from api.services.telephony.status_processor import StatusCallbackRequest
def _json_request(body: bytes) -> Request:
async def receive():
return {"type": "http.request", "body": body, "more_body": False}
return Request(
{
"type": "http",
"method": "POST",
"scheme": "https",
"server": ("example.test", 443),
"path": "/api/v1/telephony/cloudonix/cdr",
"query_string": b"",
"headers": [(b"content-type", b"application/json")],
},
receive,
)
@pytest.mark.asyncio
async def test_cdr_route_handles_payload_without_session():
"""A CDR payload missing the ``session`` object returns a graceful error
instead of raising ``AttributeError`` on ``None.get("token")``."""
request = _json_request(b'{"domain": "acme.cloudonix.io", "disposition": "ANSWER"}')
with patch(
"api.services.telephony.providers.cloudonix.routes.db_client"
) as db_client:
db_client.get_workflow_run_by_call_id = AsyncMock(return_value=None)
result = await handle_cloudonix_cdr(request)
assert result == {"status": "error", "message": "Missing call_id field"}
@pytest.mark.asyncio
async def test_cdr_route_handles_null_session():
"""A CDR payload with an explicit ``null`` session is handled gracefully."""
request = _json_request(b'{"domain": "acme.cloudonix.io", "session": null}')
with patch(
"api.services.telephony.providers.cloudonix.routes.db_client"
) as db_client:
db_client.get_workflow_run_by_call_id = AsyncMock(return_value=None)
result = await handle_cloudonix_cdr(request)
assert result == {"status": "error", "message": "Missing call_id field"}
@pytest.mark.asyncio
async def test_cdr_route_handles_string_session():
"""A CDR payload with a non-object session is handled gracefully."""
request = _json_request(b'{"domain": "acme.cloudonix.io", "session": "abc"}')
with patch(
"api.services.telephony.providers.cloudonix.routes.db_client"
) as db_client:
db_client.get_workflow_run_by_call_id = AsyncMock(return_value=None)
result = await handle_cloudonix_cdr(request)
assert result == {"status": "error", "message": "Missing call_id field"}
def test_from_cloudonix_cdr_tolerates_missing_session_and_disposition():
"""``from_cloudonix_cdr`` must not crash on a partial CDR payload."""
# Missing both session and disposition.
req = StatusCallbackRequest.from_cloudonix_cdr({"domain": "acme.cloudonix.io"})
assert req.call_id == ""
assert req.status == ""
# Explicit null values.
req = StatusCallbackRequest.from_cloudonix_cdr(
{"session": None, "disposition": None}
)
assert req.call_id == ""
assert req.status == ""
def test_from_cloudonix_cdr_tolerates_string_session():
"""``from_cloudonix_cdr`` treats a non-object session as missing call_id."""
req = StatusCallbackRequest.from_cloudonix_cdr(
{"session": "abc", "disposition": "ANSWER"}
)
assert req.call_id == ""
assert req.status == "completed"
def test_from_cloudonix_cdr_maps_disposition_and_session_token():
"""Normal, well-formed CDR payloads still map correctly."""
req = StatusCallbackRequest.from_cloudonix_cdr(
{
"session": {"token": "abc123"},
"disposition": "BUSY",
"from": "+15551230001",
"to": "+15551230002",
"billsec": 12,
}
)
assert req.call_id == "abc123"
assert req.status == "busy"
assert req.duration == "12"

View file

@ -6,11 +6,13 @@ from unittest.mock import AsyncMock, patch
from urllib.parse import urlencode
import pytest
from fastapi import HTTPException
from starlette.requests import Request
from api.services.telephony.providers.vobiz.provider import VobizProvider
from api.services.telephony.providers.vobiz.routes import (
handle_vobiz_hangup_callback,
handle_vobiz_hangup_callback_by_workflow,
handle_vobiz_ring_callback,
)
@ -225,3 +227,154 @@ async def test_vobiz_verify_inbound_signature_rejects_missing_signature():
{},
{},
)
@pytest.mark.asyncio
async def test_vobiz_hangup_callback_rejects_missing_signature():
"""An unsigned hangup callback must be rejected before status processing."""
provider = _provider()
form_data = {
"CallUUID": "call-123",
"CallStatus": "completed",
"From": "15551230001",
"To": "15551230002",
"Direction": "outbound",
"Duration": "12",
}
# No x-vobiz-signature-* headers — the callback is unsigned.
request = _request(
path="/api/v1/telephony/vobiz/hangup-callback/123",
form_data=form_data,
)
with (
patch("api.services.telephony.providers.vobiz.routes.db_client") as db_client,
patch(
"api.services.telephony.providers.vobiz.routes.get_telephony_provider_for_run",
new_callable=AsyncMock,
return_value=provider,
),
patch(
"api.services.telephony.providers.vobiz.routes.get_backend_endpoints",
new_callable=AsyncMock,
return_value=("https://example.test", "wss://example.test"),
),
patch(
"api.services.telephony.providers.vobiz.routes._process_status_update",
new_callable=AsyncMock,
) as process_status,
):
db_client.get_workflow_run_by_id = AsyncMock(
return_value=SimpleNamespace(workflow_id=7)
)
db_client.get_workflow_by_id = AsyncMock(
return_value=SimpleNamespace(organization_id=11)
)
with pytest.raises(HTTPException) as exc_info:
await handle_vobiz_hangup_callback(
workflow_run_id=123,
request=request,
)
assert exc_info.value.status_code == 403
process_status.assert_not_awaited()
@pytest.mark.asyncio
async def test_vobiz_ring_callback_rejects_missing_signature():
"""An unsigned ring callback must be rejected before it is logged."""
provider = _provider()
form_data = {
"CallUUID": "call-123",
"CallStatus": "ringing",
"From": "15551230001",
"To": "15551230002",
}
# No x-vobiz-signature-* headers — the callback is unsigned.
request = _request(
path="/api/v1/telephony/vobiz/ring-callback/123",
form_data=form_data,
)
workflow_run = SimpleNamespace(workflow_id=7, logs={})
with (
patch("api.services.telephony.providers.vobiz.routes.db_client") as db_client,
patch(
"api.services.telephony.providers.vobiz.routes.get_telephony_provider_for_run",
new_callable=AsyncMock,
return_value=provider,
),
patch(
"api.services.telephony.providers.vobiz.routes.get_backend_endpoints",
new_callable=AsyncMock,
return_value=("https://example.test", "wss://example.test"),
),
):
db_client.get_workflow_run_by_id = AsyncMock(return_value=workflow_run)
db_client.get_workflow_by_id = AsyncMock(
return_value=SimpleNamespace(organization_id=11)
)
db_client.update_workflow_run = AsyncMock()
with pytest.raises(HTTPException) as exc_info:
await handle_vobiz_ring_callback(
workflow_run_id=123,
request=request,
)
assert exc_info.value.status_code == 403
db_client.update_workflow_run.assert_not_awaited()
@pytest.mark.asyncio
async def test_vobiz_hangup_callback_by_workflow_rejects_missing_signature():
"""An unsigned by-workflow hangup callback must be rejected before processing."""
provider = _provider()
form_data = {
"CallUUID": "call-123",
"CallStatus": "completed",
"From": "15551230001",
"To": "15551230002",
"Direction": "outbound",
"Duration": "12",
}
# No x-vobiz-signature-* headers — the callback is unsigned.
request = _request(
path="/api/v1/telephony/vobiz/hangup-callback/workflow/7",
form_data=form_data,
)
with (
patch("api.services.telephony.providers.vobiz.routes.db_client") as db_client,
patch(
"api.services.telephony.providers.vobiz.routes.get_telephony_provider_for_run",
new_callable=AsyncMock,
return_value=provider,
),
patch(
"api.services.telephony.providers.vobiz.routes.get_backend_endpoints",
new_callable=AsyncMock,
return_value=("https://example.test", "wss://example.test"),
),
patch(
"api.services.telephony.providers.vobiz.routes._process_status_update",
new_callable=AsyncMock,
) as process_status,
):
db_client.get_workflow_by_id = AsyncMock(
return_value=SimpleNamespace(organization_id=11)
)
db_client.get_workflow_run_by_call_id = AsyncMock(
return_value=SimpleNamespace(id=123, workflow_id=7)
)
with pytest.raises(HTTPException) as exc_info:
await handle_vobiz_hangup_callback_by_workflow(
workflow_id=7,
request=request,
)
assert exc_info.value.status_code == 403
process_status.assert_not_awaited()

View file

@ -0,0 +1,459 @@
from types import SimpleNamespace
from unittest.mock import AsyncMock
import pytest
from pydantic import ValidationError
from api.schemas.ai_model_configuration import (
DograhManagedAIModelConfiguration,
EffectiveAIModelConfiguration,
OrganizationAIModelConfigurationResponse,
OrganizationAIModelConfigurationV2,
compile_ai_model_configuration_v2,
)
from api.services.configuration.ai_model_configuration import (
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY,
check_for_masked_keys_in_ai_model_configuration_v2,
convert_legacy_ai_model_configuration_to_v2,
mask_ai_model_configuration_v2,
merge_ai_model_configuration_v2_secrets,
migrate_workflow_configuration_model_override_to_v2,
)
from api.services.configuration.check_validity import UserConfigurationValidator
from api.services.configuration.masking import mask_key
from api.services.configuration.registry import (
DeepgramSTTConfiguration,
DograhLLMService,
DograhSTTService,
DograhTTSService,
ElevenlabsTTSConfiguration,
GoogleLLMService,
GoogleRealtimeLLMConfiguration,
OpenAIEmbeddingsConfiguration,
OpenAILLMService,
)
def test_dograh_v2_compiles_to_effective_managed_pipeline_with_embeddings():
config = OrganizationAIModelConfigurationV2(
mode="dograh",
dograh=DograhManagedAIModelConfiguration(
api_key="mps-secret",
voice="default",
speed=1.2,
language="multi",
),
)
effective = compile_ai_model_configuration_v2(config)
assert effective.is_realtime is False
assert effective.llm.provider == "dograh"
assert effective.llm.model == "default"
assert effective.tts.provider == "dograh"
assert effective.tts.speed == 1.2
assert effective.stt.provider == "dograh"
assert effective.stt.language == "multi"
assert effective.embeddings.provider == "dograh"
assert effective.embeddings.model == "default"
assert effective.managed_service_version == 2
def test_dograh_v2_rejects_non_predefined_speed():
with pytest.raises(ValidationError):
OrganizationAIModelConfigurationV2(
mode="dograh",
dograh=DograhManagedAIModelConfiguration(
api_key="mps-secret",
speed=1.5,
),
)
def test_byok_v2_rejects_dograh_provider():
with pytest.raises(ValidationError):
OrganizationAIModelConfigurationV2.model_validate(
{
"mode": "byok",
"byok": {
"mode": "pipeline",
"pipeline": {
"llm": {
"provider": "dograh",
"api_key": "mps-secret",
"model": "default",
},
"tts": {
"provider": "dograh",
"api_key": "mps-secret",
"model": "default",
"voice": "default",
},
"stt": {
"provider": "dograh",
"api_key": "mps-secret",
"model": "default",
},
},
},
}
)
@pytest.mark.asyncio
async def test_byok_realtime_validator_does_not_require_stt_or_tts():
config = OrganizationAIModelConfigurationV2.model_validate(
{
"mode": "byok",
"byok": {
"mode": "realtime",
"realtime": {
"realtime": {
"provider": "google_realtime",
"api_key": "google-realtime-key",
"model": "gemini-3.1-flash-live-preview",
"voice": "Puck",
"language": "en",
},
"llm": {
"provider": "google",
"api_key": "google-llm-key",
"model": "gemini-2.0-flash",
},
},
},
}
)
effective = compile_ai_model_configuration_v2(config)
assert effective.is_realtime is True
assert effective.stt is None
assert effective.tts is None
assert await UserConfigurationValidator().validate(effective) == {
"status": [{"model": "all", "message": "ok"}]
}
@pytest.mark.asyncio
async def test_pipeline_validator_requires_stt_and_tts_when_not_realtime():
effective = EffectiveAIModelConfiguration(
llm=GoogleLLMService(
provider="google",
api_key="google-llm-key",
model="gemini-2.0-flash",
),
realtime=GoogleRealtimeLLMConfiguration(
provider="google_realtime",
api_key="google-realtime-key",
model="gemini-3.1-flash-live-preview",
voice="Puck",
language="en",
),
is_realtime=False,
)
with pytest.raises(ValueError) as exc_info:
await UserConfigurationValidator().validate(effective)
assert exc_info.value.args[0] == [
{"model": "stt", "message": "API key is missing"},
{"model": "tts", "message": "API key is missing"},
]
def test_masked_dograh_key_is_preserved_when_saving_same_mode():
existing = OrganizationAIModelConfigurationV2(
mode="dograh",
dograh=DograhManagedAIModelConfiguration(api_key="mps-real-secret"),
)
incoming = OrganizationAIModelConfigurationV2(
mode="dograh",
dograh=DograhManagedAIModelConfiguration(api_key=mask_key("mps-real-secret")),
)
merged = merge_ai_model_configuration_v2_secrets(incoming, existing)
assert merged.dograh.api_key == "mps-real-secret"
check_for_masked_keys_in_ai_model_configuration_v2(merged)
def test_masked_v2_configuration_masks_nested_service_keys():
config = OrganizationAIModelConfigurationV2(
mode="byok",
byok={
"mode": "pipeline",
"pipeline": {
"llm": {
"provider": "openai",
"api_key": "sk-real-secret",
"model": "gpt-4.1",
},
"tts": {
"provider": "elevenlabs",
"api_key": "el-real-secret",
"model": "eleven_flash_v2_5",
"voice": "Rachel",
},
"stt": {
"provider": "deepgram",
"api_key": "dg-real-secret",
"model": "nova-3-general",
},
},
},
)
masked = mask_ai_model_configuration_v2(config)
assert masked["byok"]["pipeline"]["llm"]["api_key"] == mask_key("sk-real-secret")
assert masked["byok"]["pipeline"]["tts"]["api_key"] == mask_key("el-real-secret")
assert masked["byok"]["pipeline"]["stt"]["api_key"] == mask_key("dg-real-secret")
def test_legacy_all_dograh_pipeline_converts_to_dograh_v2():
legacy = EffectiveAIModelConfiguration(
llm=DograhLLMService(
provider="dograh",
api_key=["mps-secret"],
model="default",
),
tts=DograhTTSService(
provider="dograh",
api_key=["mps-secret"],
model="default",
voice="default",
speed=1.0,
),
stt=DograhSTTService(
provider="dograh",
api_key=["mps-secret"],
model="default",
language="multi",
),
)
config = convert_legacy_ai_model_configuration_to_v2(legacy)
assert config.mode == "dograh"
assert config.dograh.api_key == "mps-secret"
def test_legacy_mixed_dograh_pipeline_converts_to_dograh_v2():
legacy = EffectiveAIModelConfiguration(
llm=OpenAILLMService(
provider="openai",
api_key="sk-llm",
model="gpt-4.1",
),
tts=DograhTTSService(
provider="dograh",
api_key="mps-tts",
model="default",
voice="default",
),
stt=DograhSTTService(
provider="dograh",
api_key="mps-stt",
model="default",
),
embeddings=OpenAIEmbeddingsConfiguration(
provider="openai",
api_key="sk-emb",
model="text-embedding-3-small",
),
)
config = convert_legacy_ai_model_configuration_to_v2(legacy)
assert config.mode == "dograh"
assert config.dograh.api_key == "mps-tts"
assert config.dograh.voice == "default"
def test_legacy_byok_pipeline_converts_to_byok_v2():
legacy = EffectiveAIModelConfiguration(
llm=OpenAILLMService(
provider="openai",
api_key="sk-llm",
model="gpt-4.1",
),
tts=ElevenlabsTTSConfiguration(
provider="elevenlabs",
api_key="el-tts",
model="eleven_flash_v2_5",
voice="Rachel",
),
stt=DeepgramSTTConfiguration(
provider="deepgram",
api_key="dg-stt",
model="nova-3-general",
),
embeddings=OpenAIEmbeddingsConfiguration(
provider="openai",
api_key="sk-emb",
model="text-embedding-3-small",
),
)
config = convert_legacy_ai_model_configuration_to_v2(legacy)
assert config.mode == "byok"
assert config.byok.mode == "pipeline"
assert config.byok.pipeline.llm.provider == "openai"
assert config.byok.pipeline.tts.provider == "elevenlabs"
def test_workflow_model_override_migration_removes_v1_override_and_sets_v2():
base = EffectiveAIModelConfiguration(
llm=OpenAILLMService(
provider="openai",
api_key="sk-llm",
model="gpt-4.1",
),
tts=ElevenlabsTTSConfiguration(
provider="elevenlabs",
api_key="el-tts",
model="eleven_flash_v2_5",
voice="Rachel",
),
stt=DeepgramSTTConfiguration(
provider="deepgram",
api_key="dg-stt",
model="nova-3-general",
),
)
workflow_configurations = {
"ambient_noise_configuration": {"enabled": False},
"model_overrides": {
"tts": {
"provider": "dograh",
"api_key": "mps-workflow",
"model": "default",
"voice": "default",
}
},
}
migrated, changed = migrate_workflow_configuration_model_override_to_v2(
workflow_configurations,
base,
)
assert changed is True
assert "model_overrides" not in migrated
assert migrated["ambient_noise_configuration"] == {"enabled": False}
v2_override = migrated[WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY]
assert v2_override["mode"] == "dograh"
assert v2_override["dograh"]["api_key"] == "mps-workflow"
def test_workflow_model_override_migration_removes_invalid_v1_override_marker():
base = EffectiveAIModelConfiguration()
workflow_configurations = {
"ambient_noise_configuration": {"enabled": False},
"model_overrides": None,
}
migrated, changed = migrate_workflow_configuration_model_override_to_v2(
workflow_configurations,
base,
)
assert changed is True
assert "model_overrides" not in migrated
assert migrated["ambient_noise_configuration"] == {"enabled": False}
@pytest.mark.asyncio
async def test_migrate_model_configuration_v2_initializes_hosted_mps_billing(
monkeypatch,
):
from api.routes import organization as organization_routes
legacy = EffectiveAIModelConfiguration(
llm=DograhLLMService(
provider="dograh",
api_key=["mps-secret"],
model="default",
),
tts=DograhTTSService(
provider="dograh",
api_key=["mps-secret"],
model="default",
voice="default",
),
stt=DograhSTTService(
provider="dograh",
api_key=["mps-secret"],
model="default",
),
)
expected_response = OrganizationAIModelConfigurationResponse(
configuration={"version": 2, "mode": "dograh"},
effective_configuration={},
source="organization_v2",
)
class FakeValidator:
async def validate(self, *args, **kwargs):
return {"status": [{"model": "all", "message": "ok"}]}
ensure_billing = AsyncMock(return_value={"billing_mode": "v2"})
upsert = AsyncMock()
migrate_workflows = AsyncMock()
monkeypatch.setattr(organization_routes, "DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
organization_routes,
"get_organization_ai_model_configuration_v2",
AsyncMock(return_value=None),
)
monkeypatch.setattr(
organization_routes.db_client,
"get_user_configurations",
AsyncMock(return_value=legacy),
)
monkeypatch.setattr(
organization_routes,
"UserConfigurationValidator",
lambda: FakeValidator(),
)
monkeypatch.setattr(
organization_routes,
"ensure_hosted_mps_billing_account_v2",
ensure_billing,
)
monkeypatch.setattr(
organization_routes,
"upsert_organization_ai_model_configuration_v2",
upsert,
)
monkeypatch.setattr(
organization_routes,
"migrate_workflow_model_configurations_to_v2",
migrate_workflows,
)
monkeypatch.setattr(
organization_routes,
"_model_configuration_v2_response",
AsyncMock(return_value=expected_response),
)
user = SimpleNamespace(
id=7,
provider_id="provider-123",
selected_organization_id=42,
)
response = await organization_routes.migrate_model_configuration_v2(
force=False,
user=user,
)
ensure_billing.assert_awaited_once_with(42, created_by="provider-123")
upsert.assert_awaited_once()
migrate_workflows.assert_awaited_once_with(
organization_id=42,
fallback_user_config=legacy,
)
assert response == expected_response

View file

@ -0,0 +1,68 @@
from types import SimpleNamespace
from unittest.mock import AsyncMock
import pytest
from api.services.auth import depends as auth_depends
@pytest.mark.asyncio
async def test_get_user_initializes_hosted_mps_billing_for_new_org(monkeypatch):
stack_user = {
"id": "stack-user-1",
"selected_team_id": "team-1",
"primary_email_verified": False,
}
user = SimpleNamespace(
id=7,
email=None,
provider_id="stack-user-1",
selected_organization_id=None,
)
organization = SimpleNamespace(id=42)
existing_config = SimpleNamespace(llm=object(), tts=None, stt=None)
ensure_billing = AsyncMock(return_value={"billing_mode": "v2"})
monkeypatch.setattr(auth_depends, "AUTH_PROVIDER", "stack")
monkeypatch.setattr(
auth_depends.stackauth,
"get_user",
AsyncMock(return_value=stack_user),
)
monkeypatch.setattr(
auth_depends.db_client,
"get_or_create_user_by_provider_id",
AsyncMock(return_value=(user, False)),
)
monkeypatch.setattr(
auth_depends.db_client,
"get_or_create_organization_by_provider_id",
AsyncMock(return_value=(organization, True)),
)
monkeypatch.setattr(
auth_depends.db_client,
"add_user_to_organization",
AsyncMock(),
)
monkeypatch.setattr(
auth_depends.db_client,
"update_user_selected_organization",
AsyncMock(),
)
monkeypatch.setattr(
auth_depends.db_client,
"get_user_configurations",
AsyncMock(return_value=existing_config),
)
monkeypatch.setattr(
auth_depends,
"ensure_hosted_mps_billing_account_v2",
ensure_billing,
)
result = await auth_depends.get_user(authorization="Bearer token")
assert result is user
assert result.selected_organization_id == 42
ensure_billing.assert_awaited_once_with(42, created_by="stack-user-1")

View file

@ -0,0 +1,45 @@
from types import SimpleNamespace
from unittest.mock import patch
from api.services.configuration.registry import (
CARTESIA_TTS_MODELS,
CartesiaTTSConfiguration,
ServiceProviders,
)
from api.services.pipecat.service_factory import create_tts_service
def test_cartesia_tts_configuration_defaults_to_sonic_3_5():
config = CartesiaTTSConfiguration(api_key="test-key")
assert config.provider == ServiceProviders.CARTESIA
assert config.model == "sonic-3.5"
assert CARTESIA_TTS_MODELS == ["sonic-3.5", "sonic-3"]
def test_create_cartesia_tts_service_passes_selected_model():
user_config = SimpleNamespace(
tts=SimpleNamespace(
provider=ServiceProviders.CARTESIA.value,
api_key="test-key",
model="sonic-3.5",
voice="test-voice-id",
speed=1.0,
volume=1.0,
)
)
audio_config = SimpleNamespace(
transport_out_sample_rate=24000,
transport_in_sample_rate=16000,
)
with patch(
"api.services.pipecat.service_factory.CartesiaTTSService"
) as mock_service:
create_tts_service(user_config, audio_config)
assert mock_service.call_count == 1
kwargs = mock_service.call_args.kwargs
assert kwargs["api_key"] == "test-key"
assert kwargs["settings"].model == "sonic-3.5"
assert kwargs["settings"].voice == "test-voice-id"

View file

@ -1,31 +0,0 @@
from api.services.pricing.cost_calculator import cost_calculator
def test_cost_calculator():
"""Test function to verify cost calculation works"""
sample_usage = {
"llm": {
"OpenAILLMService#0|||gpt-4.1-mini": {
"prompt_tokens": 45380,
"completion_tokens": 496,
"total_tokens": 45876,
"cache_read_input_tokens": 0,
"cache_creation_input_tokens": 0,
}
},
"tts": {"ElevenLabsTTSService#0|||eleven_flash_v2_5": 2399},
"stt": {"DeepgramSTTService#0|||nova-3-general": 177.21536946296692},
"call_duration_seconds": 179,
}
result = cost_calculator.calculate_total_cost(sample_usage)
assert result["llm_cost"] == 45380 * 0.40 / 1_000_000 + 496 * 1.60 / 1_000_000
assert result["tts_cost"] == 2399 * 0.0256 / 1_000
assert result["stt_cost"] == 177.21536946296692 / 60 * 0.0077
assert (
abs(
result["total"]
- (result["llm_cost"] + result["tts_cost"] + result["stt_cost"])
)
< 1e-10
)

View file

@ -0,0 +1,110 @@
import json
import pytest
from openai._types import NOT_GIVEN as OPENAI_NOT_GIVEN
from pipecat.frames.frames import TTSStartedFrame
from pipecat.services.dograh.llm import DograhLLMService
from pipecat.services.dograh.stt import DograhSTTService
from pipecat.services.dograh.tts import DograhTTSService
from pipecat.services.openai.base_llm import OpenAILLMSettings
from websockets.protocol import State
class _FakeWebSocket:
def __init__(self):
self.state = State.OPEN
self.messages: list[dict] = []
async def send(self, message: str) -> None:
self.messages.append(json.loads(message))
async def close(self, *args, **kwargs) -> None:
self.state = State.CLOSED
def test_dograh_llm_uses_explicit_mps_correlation_id():
service = DograhLLMService(
api_key="mps-secret",
correlation_id="mps-corr-123",
settings=OpenAILLMSettings(model="default"),
)
service._start_metadata = {"workflow_run_id": 99}
params = service.build_chat_completion_params(
{
"messages": [],
"tools": OPENAI_NOT_GIVEN,
"tool_choice": OPENAI_NOT_GIVEN,
}
)
assert params["metadata"]["correlation_id"] == "mps-corr-123"
assert params["metadata"]["mps_billing_version"] == "2"
@pytest.mark.asyncio
async def test_dograh_stt_config_uses_explicit_mps_correlation_id(monkeypatch):
fake_ws = _FakeWebSocket()
async def fake_connect(url, additional_headers):
return fake_ws
monkeypatch.setattr(
"pipecat.services.dograh.stt.websocket_connect",
fake_connect,
)
service = DograhSTTService(
api_key="mps-secret",
correlation_id="mps-corr-123",
sample_rate=16000,
)
service._start_metadata = {"workflow_run_id": 99}
await service._connect_websocket()
assert fake_ws.messages[0]["type"] == "config"
assert fake_ws.messages[0]["correlation_id"] == "mps-corr-123"
assert fake_ws.messages[0]["mps_billing_version"] == "2"
@pytest.mark.asyncio
async def test_dograh_tts_messages_use_explicit_mps_correlation_id(monkeypatch):
fake_ws = _FakeWebSocket()
async def fake_connect(url, additional_headers):
return fake_ws
monkeypatch.setattr(
"pipecat.services.dograh.tts.websocket_connect",
fake_connect,
)
service = DograhTTSService(
api_key="mps-secret",
correlation_id="mps-corr-123",
sample_rate=24000,
)
service._start_metadata = {"workflow_run_id": 99}
await service._connect_websocket()
assert fake_ws.messages[0]["type"] == "config"
assert fake_ws.messages[0]["correlation_id"] == "mps-corr-123"
assert fake_ws.messages[0]["mps_billing_version"] == "2"
async def _noop(*args, **kwargs):
return None
service.audio_context_available = lambda context_id: False
service.create_audio_context = _noop
service.start_ttfb_metrics = _noop
service.start_tts_usage_metrics = _noop
frames = []
async for frame in service.run_tts("hello", "ctx-1"):
frames.append(frame)
assert isinstance(frames[0], TTSStartedFrame)
assert fake_ws.messages[1]["type"] == "create_context"
assert fake_ws.messages[1]["correlation_id"] == "mps-corr-123"
assert fake_ws.messages[1]["mps_billing_version"] == "2"

View file

@ -270,6 +270,12 @@ class TestDispatcherThreadsTelephonyConfig:
"api.services.campaign.campaign_call_dispatcher.get_backend_endpoints",
AsyncMock(return_value=("https://example.com", None)),
),
patch(
"api.services.campaign.campaign_call_dispatcher.authorize_workflow_run_start",
AsyncMock(
return_value=SimpleNamespace(has_quota=True, error_message="")
),
),
):
mock_db.get_workflow_by_id = AsyncMock(return_value=SimpleNamespace(id=1))
mock_db.create_workflow_run = AsyncMock(return_value=workflow_run)

View file

@ -7,7 +7,7 @@ from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.xai.realtime import events
from api.schemas.user_configuration import UserConfiguration
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
from api.services.configuration.registry import GrokRealtimeLLMConfiguration
from api.services.pipecat.realtime.grok_realtime import (
DograhGrokRealtimeLLMService,
@ -120,7 +120,7 @@ async def test_completed_input_transcription_is_broadcast_as_finalized():
def test_factory_creates_dograh_grok_realtime_service():
user_config = UserConfiguration(
effective_config = EffectiveAIModelConfiguration(
is_realtime=True,
realtime=GrokRealtimeLLMConfiguration(
provider="grok_realtime",
@ -131,7 +131,7 @@ def test_factory_creates_dograh_grok_realtime_service():
)
service = create_realtime_llm_service(
user_config,
effective_config,
audio_config=SimpleNamespace(),
)

View file

@ -1,10 +1,11 @@
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
from fastapi import FastAPI
from fastapi.testclient import TestClient
from api.routes.user import router
from api.schemas.user_configuration import UserConfiguration
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
from api.services.auth.depends import get_user
from api.services.configuration.masking import mask_key
from api.services.configuration.registry import (
@ -14,14 +15,14 @@ from api.services.configuration.registry import (
)
def _make_test_app():
def _make_test_app(selected_organization_id=None):
app = FastAPI()
app.include_router(router)
mock_user = MagicMock()
mock_user.id = 1
mock_user.is_superuser = False
mock_user.selected_organization_id = None
mock_user.selected_organization_id = selected_organization_id
app.dependency_overrides[get_user] = lambda: mock_user
return app
@ -32,7 +33,7 @@ MASKED_KEY = mask_key(REAL_KEY) # "**************************cdef"
def _existing_openai_config():
return UserConfiguration(
return EffectiveAIModelConfiguration(
llm=OpenAILLMService(
provider="openai",
api_key=REAL_KEY,
@ -110,7 +111,7 @@ class TestMaskedKeyRejection:
client = TestClient(app)
new_key = "AIzaSyNewRealKey12345678"
updated = UserConfiguration(
updated = EffectiveAIModelConfiguration(
llm=GoogleLLMService(
provider="google",
api_key=new_key,
@ -177,7 +178,7 @@ class TestMaskedKeyRejection:
real_credentials = '{"type":"service_account","project_id":"demo-project"}'
masked_credentials = mask_key(real_credentials)
existing = UserConfiguration(
existing = EffectiveAIModelConfiguration(
llm=GoogleVertexLLMConfiguration(
provider="google_vertex",
api_key=None,
@ -210,3 +211,38 @@ class TestMaskedKeyRejection:
)
assert response.status_code == 200
def test_preference_only_update_does_not_validate_or_save_model_config(self):
"""Saving a test phone number through the legacy endpoint must not touch models."""
app = _make_test_app(selected_organization_id=11)
client = TestClient(app)
preferences = SimpleNamespace(test_phone_number=None, timezone=None)
with (
patch("api.routes.user.db_client") as mock_db,
patch("api.routes.user.UserConfigurationValidator") as mock_validator,
patch(
"api.routes.user.get_organization_preferences",
new=AsyncMock(return_value=preferences),
),
patch(
"api.routes.user.upsert_organization_preferences",
new=AsyncMock(return_value=preferences),
) as upsert_preferences,
):
existing = _existing_openai_config()
mock_db.get_user_configurations = AsyncMock(return_value=existing)
mock_db.update_user_configuration = AsyncMock()
mock_db.get_organization_by_id = AsyncMock(return_value=None)
mock_validator.return_value.validate = AsyncMock()
response = client.put(
"/user/configurations/user",
json={"test_phone_number": "+15551234567"},
)
assert response.status_code == 200
assert response.json()["test_phone_number"] == "+15551234567"
mock_db.update_user_configuration.assert_not_called()
mock_validator.return_value.validate.assert_not_called()
upsert_preferences.assert_awaited_once()

View file

@ -87,3 +87,387 @@ async def test_check_service_key_usage_uses_bearer_self_usage(monkeypatch):
"Content-Type": "application/json",
},
)
@pytest.mark.asyncio
async def test_create_correlation_id_uses_bearer_auth(monkeypatch):
calls = []
class FakeAsyncClient:
def __init__(self, timeout):
self.timeout = timeout
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return None
async def post(self, url, json, headers):
calls.append(("POST", url, json, headers))
return _Response(200, {"correlation_id": "mps-corr-123"})
monkeypatch.setattr(
"api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient
)
client = MPSServiceKeyClient()
assert await client.create_correlation_id(
service_key="mps_sk_paid",
workflow_run_id=42,
) == {"correlation_id": "mps-corr-123"}
assert calls == [
(
"POST",
f"{client.base_url}/api/v1/service-keys/correlation-id/self",
{"workflow_run_id": 42},
{
"Authorization": "Bearer mps_sk_paid",
"Content-Type": "application/json",
},
)
]
@pytest.mark.asyncio
async def test_get_billing_account_status_uses_hosted_org_auth(monkeypatch):
calls = []
class FakeAsyncClient:
def __init__(self, timeout):
self.timeout = timeout
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return None
async def get(self, url, headers):
calls.append(("GET", url, headers))
return _Response(200, {"organization_id": 42, "billing_mode": "v2"})
monkeypatch.setattr(
"api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient
)
monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
"api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret"
)
client = MPSServiceKeyClient()
assert await client.get_billing_account_status(organization_id=42) == {
"organization_id": 42,
"billing_mode": "v2",
}
assert calls == [
(
"GET",
f"{client.base_url}/api/v1/billing/accounts/42/status",
{
"Content-Type": "application/json",
"X-Secret-Key": "mps-secret",
"X-Organization-Id": "42",
},
)
]
@pytest.mark.asyncio
async def test_authorize_workflow_run_start_uses_hosted_org_auth(monkeypatch):
calls = []
class FakeAsyncClient:
def __init__(self, timeout):
self.timeout = timeout
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return None
async def post(self, url, json, headers):
calls.append(("POST", url, json, headers))
return _Response(
200,
{
"allowed": True,
"billing_mode": "v2",
"remaining_credits": "25.0000",
"correlation_id": "mps-corr-123",
},
)
monkeypatch.setattr(
"api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient
)
monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
"api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret"
)
client = MPSServiceKeyClient()
assert await client.authorize_workflow_run_start(
organization_id=42,
workflow_run_id=88,
service_key="mps_sk_paid",
require_correlation_id=True,
minimum_credits=0.1,
metadata={"workflow_id": 7},
created_by="provider-123",
) == {
"allowed": True,
"billing_mode": "v2",
"remaining_credits": "25.0000",
"correlation_id": "mps-corr-123",
}
assert calls == [
(
"POST",
f"{client.base_url}/api/v1/billing/accounts/42/run-authorization",
{
"workflow_run_id": 88,
"service_key": "mps_sk_paid",
"require_correlation_id": True,
"minimum_credits": 0.1,
"metadata": {"workflow_id": 7},
},
{
"Content-Type": "application/json",
"X-Secret-Key": "mps-secret",
"X-Organization-Id": "42",
},
)
]
@pytest.mark.asyncio
async def test_ensure_billing_account_v2_uses_balance_endpoint(monkeypatch):
calls = []
class FakeAsyncClient:
def __init__(self, timeout):
self.timeout = timeout
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return None
async def get(self, url, headers):
calls.append(("GET", url, headers))
return _Response(
200,
{
"id": 7,
"organization_id": 42,
"billing_mode": "v2",
"cached_balance_credits": "0.0000",
"currency": "USD",
},
)
monkeypatch.setattr(
"api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient
)
monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
"api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret"
)
client = MPSServiceKeyClient()
assert await client.ensure_billing_account_v2(
organization_id=42,
created_by="provider-123",
) == {
"id": 7,
"organization_id": 42,
"billing_mode": "v2",
"cached_balance_credits": "0.0000",
"currency": "USD",
}
assert calls == [
(
"GET",
f"{client.base_url}/api/v1/billing/accounts/42/balance",
{
"Content-Type": "application/json",
"X-Secret-Key": "mps-secret",
"X-Organization-Id": "42",
},
)
]
@pytest.mark.asyncio
async def test_get_credit_ledger_sends_page_and_limit(monkeypatch):
calls = []
class FakeAsyncClient:
def __init__(self, timeout):
self.timeout = timeout
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return None
async def get(self, url, params, headers):
calls.append(("GET", url, params, headers))
return _Response(
200,
{
"account": {"organization_id": 42},
"ledger_entries": [],
"total_count": 0,
"page": 3,
"limit": 25,
"total_pages": 0,
},
)
monkeypatch.setattr(
"api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient
)
monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
"api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret"
)
client = MPSServiceKeyClient()
assert await client.get_credit_ledger(
organization_id=42,
page=3,
limit=25,
) == {
"account": {"organization_id": 42},
"ledger_entries": [],
"total_count": 0,
"page": 3,
"limit": 25,
"total_pages": 0,
}
assert calls == [
(
"GET",
f"{client.base_url}/api/v1/billing/accounts/42/ledger",
{"page": 3, "limit": 25},
{
"Content-Type": "application/json",
"X-Secret-Key": "mps-secret",
"X-Organization-Id": "42",
},
)
]
@pytest.mark.asyncio
async def test_report_platform_usage_uses_hosted_secret_auth(monkeypatch):
calls = []
class FakeAsyncClient:
def __init__(self, timeout):
self.timeout = timeout
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return None
async def post(self, url, json, headers):
calls.append(("POST", url, json, headers))
return _Response(200, {"metered": True})
monkeypatch.setattr(
"api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient
)
monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
"api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret"
)
client = MPSServiceKeyClient()
assert await client.report_platform_usage(
organization_id=42,
correlation_id="mps-corr-123",
workflow_run_id=123,
metadata={"source": "workflow_run_completion"},
) == {"metered": True}
assert calls == [
(
"POST",
f"{client.base_url}/api/v1/billing/accounts/42/platform-usage",
{
"correlation_id": "mps-corr-123",
"workflow_run_id": 123,
"metadata": {"source": "workflow_run_completion"},
},
{
"Content-Type": "application/json",
"X-Secret-Key": "mps-secret",
"X-Organization-Id": "42",
},
)
]
@pytest.mark.asyncio
async def test_report_platform_usage_sends_duration_without_correlation(monkeypatch):
calls = []
class FakeAsyncClient:
def __init__(self, timeout):
self.timeout = timeout
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return None
async def post(self, url, json, headers):
calls.append(("POST", url, json, headers))
return _Response(200, {"metered": True})
monkeypatch.setattr(
"api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient
)
monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
"api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret"
)
client = MPSServiceKeyClient()
assert await client.report_platform_usage(
organization_id=42,
duration_seconds=87.0,
workflow_run_id=123,
metadata={"source": "workflow_run_completion"},
) == {"metered": True}
assert calls == [
(
"POST",
f"{client.base_url}/api/v1/billing/accounts/42/platform-usage",
{
"duration_seconds": 87.0,
"workflow_run_id": 123,
"metadata": {"source": "workflow_run_completion"},
},
{
"Content-Type": "application/json",
"X-Secret-Key": "mps-secret",
"X-Organization-Id": "42",
},
)
]

View file

@ -0,0 +1,99 @@
from types import SimpleNamespace
from unittest.mock import AsyncMock
import pytest
from api.routes import organization_usage
def test_is_mps_billing_v2_depends_only_on_account_mode():
assert organization_usage._is_mps_billing_v2({"billing_mode": "v2"}) is True
assert organization_usage._is_mps_billing_v2({"billing_mode": "v1"}) is False
assert organization_usage._is_mps_billing_v2({"billing_mode": "shadow"}) is False
assert organization_usage._is_mps_billing_v2(None) is False
@pytest.mark.asyncio
async def test_get_mps_billing_account_status_uses_user_provider_id(monkeypatch):
get_status = AsyncMock(return_value={"billing_mode": "v2"})
monkeypatch.setattr(
organization_usage.mps_service_key_client,
"get_billing_account_status",
get_status,
)
user = SimpleNamespace(provider_id="provider-123")
assert await organization_usage._get_mps_billing_account_status(user, 42) == {
"billing_mode": "v2"
}
get_status.assert_awaited_once_with(
organization_id=42,
created_by="provider-123",
)
@pytest.mark.asyncio
async def test_get_billing_credits_pages_v2_ledger(monkeypatch):
monkeypatch.setattr(organization_usage, "DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
organization_usage,
"_get_mps_billing_account_status",
AsyncMock(return_value={"billing_mode": "v2"}),
)
get_ledger = AsyncMock(
return_value={
"account": {
"id": 7,
"organization_id": 42,
"billing_mode": "v2",
"cached_balance_credits": 250,
"currency": "USD",
},
"ledger_entries": [
{
"id": 99,
"entry_type": "grant",
"origin": "account_creation",
"credits_delta": 250,
"balance_after": 250,
"created_at": "2026-06-12T00:00:00Z",
}
],
"total_debits_credits": 75,
"total_count": 101,
"page": 3,
"limit": 25,
"total_pages": 5,
}
)
monkeypatch.setattr(
organization_usage.mps_service_key_client,
"get_credit_ledger",
get_ledger,
)
user = SimpleNamespace(
provider_id="provider-123",
selected_organization_id=42,
)
response = await organization_usage.get_billing_credits(
page=3,
limit=25,
user=user,
)
get_ledger.assert_awaited_once_with(
organization_id=42,
page=3,
limit=25,
created_by="provider-123",
)
assert response.billing_version == "v2"
assert response.total_credits_used == 75
assert response.total_count == 101
assert response.page == 3
assert response.limit == 25
assert response.total_pages == 5
assert response.ledger_entries[0].id == 99

View file

@ -0,0 +1,66 @@
from api.services.pipecat.pre_call_fetch import _extract_initial_context
class TestExtractInitialContext:
"""Tests for _extract_initial_context, the pre-call fetch response parser."""
def test_initial_context_nested_under_call_inbound(self):
"""The canonical `initial_context` key nested under `call_inbound`."""
response = {"call_inbound": {"initial_context": {"customer_name": "Jane"}}}
assert _extract_initial_context(response) == {"customer_name": "Jane"}
def test_initial_context_at_top_level(self):
"""The canonical `initial_context` key at the top level."""
response = {"initial_context": {"customer_name": "Jane"}}
assert _extract_initial_context(response) == {"customer_name": "Jane"}
def test_legacy_dynamic_variables_nested(self):
"""The legacy `dynamic_variables` key still works nested under `call_inbound`."""
response = {"call_inbound": {"dynamic_variables": {"customer_name": "Jane"}}}
assert _extract_initial_context(response) == {"customer_name": "Jane"}
def test_legacy_dynamic_variables_at_top_level(self):
"""The legacy `dynamic_variables` key still works at the top level."""
response = {"dynamic_variables": {"customer_name": "Jane"}}
assert _extract_initial_context(response) == {"customer_name": "Jane"}
def test_initial_context_takes_precedence_over_legacy(self):
"""When both keys are present, `initial_context` wins."""
response = {
"call_inbound": {
"initial_context": {"source": "new"},
"dynamic_variables": {"source": "legacy"},
}
}
assert _extract_initial_context(response) == {"source": "new"}
def test_falls_back_to_legacy_when_initial_context_not_a_dict(self):
"""A non-dict `initial_context` falls back to `dynamic_variables`."""
response = {
"initial_context": None,
"dynamic_variables": {"customer_name": "Jane"},
}
assert _extract_initial_context(response) == {"customer_name": "Jane"}
def test_nested_values_preserved(self):
"""Nested objects pass through untouched for dot-notation access."""
response = {
"call_inbound": {
"initial_context": {"customer": {"address": {"city": "LA"}}}
}
}
assert _extract_initial_context(response) == {
"customer": {"address": {"city": "LA"}}
}
def test_empty_when_no_known_keys(self):
"""A response with neither key yields an empty dict."""
assert _extract_initial_context({"call_inbound": {"agent_id": 1}}) == {}
def test_empty_when_call_inbound_missing(self):
"""No `call_inbound` and no top-level keys yields an empty dict."""
assert _extract_initial_context({}) == {}
def test_non_dict_vars_yield_empty(self):
"""A non-dict value under a known key yields an empty dict."""
assert _extract_initial_context({"initial_context": "nope"}) == {}

View file

@ -57,7 +57,7 @@ def test_trigger_route_executes_as_workflow_owner():
with (
patch("api.routes.public_agent.db_client") as mock_db,
patch(
"api.routes.public_agent.check_dograh_quota_by_user_id",
"api.routes.public_agent.authorize_workflow_run_start",
new=quota_mock,
),
patch(
@ -92,7 +92,10 @@ def test_trigger_route_executes_as_workflow_owner():
)
assert response.status_code == 200
quota_mock.assert_awaited_once_with(workflow.user_id, workflow_id=workflow.id)
quota_mock.assert_awaited_once_with(
workflow_id=workflow.id,
workflow_run_id=501,
)
mock_db.get_workflow.assert_awaited_once_with(workflow.id, organization_id=11)
create_kwargs = mock_db.create_workflow_run.await_args.kwargs
@ -124,7 +127,7 @@ def test_workflow_uuid_route_uses_scoped_lookup_and_shared_execution():
with (
patch("api.routes.public_agent.db_client") as mock_db,
patch(
"api.routes.public_agent.check_dograh_quota_by_user_id",
"api.routes.public_agent.authorize_workflow_run_start",
new=quota_mock,
),
patch(

View file

@ -0,0 +1,274 @@
from types import SimpleNamespace
import pytest
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.testclient import TestClient
from api.routes.public_embed import PublicEmbedCORSMiddleware, router
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["https://app.dograh.com"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.add_middleware(PublicEmbedCORSMiddleware, api_prefix="/api/v1")
app.include_router(router, prefix="/api/v1")
client = TestClient(app, raise_server_exceptions=False)
_ACTIVE_TOKEN = SimpleNamespace(
id=10,
is_active=True,
expires_at=None,
allowed_domains=[],
workflow_id=1,
created_by=7,
usage_limit=None,
usage_count=0,
settings={},
)
_RESTRICTED_TOKEN = SimpleNamespace(
id=20,
is_active=True,
expires_at=None,
allowed_domains=["allowed.example.com"],
workflow_id=2,
created_by=7,
usage_limit=None,
usage_count=0,
settings={},
)
_LOCALHOST_TOKEN = SimpleNamespace(
id=30,
is_active=True,
expires_at=None,
allowed_domains=["localhost:3000", "localhost:3020"],
workflow_id=3,
created_by=7,
usage_limit=None,
usage_count=0,
settings={},
)
@pytest.fixture(autouse=True)
def _patch_db(monkeypatch):
async def _get_token(token):
if token == "valid":
return _ACTIVE_TOKEN
if token == "restricted":
return _RESTRICTED_TOKEN
if token == "localhost":
return _LOCALHOST_TOKEN
return None
async def _get_token_by_id(token_id):
if token_id == _ACTIVE_TOKEN.id:
return _ACTIVE_TOKEN
if token_id == _RESTRICTED_TOKEN.id:
return _RESTRICTED_TOKEN
if token_id == _LOCALHOST_TOKEN.id:
return _LOCALHOST_TOKEN
return None
async def _get_session(session_token):
if session_token == "session-valid":
return SimpleNamespace(embed_token_id=_ACTIVE_TOKEN.id, expires_at=None)
if session_token == "session-restricted":
return SimpleNamespace(embed_token_id=_RESTRICTED_TOKEN.id, expires_at=None)
return None
async def _create_workflow_run(**_kwargs):
return SimpleNamespace(id=123)
async def _noop(*_args, **_kwargs):
return None
monkeypatch.setattr(
"api.routes.public_embed.db_client.get_embed_token_by_token",
_get_token,
)
monkeypatch.setattr(
"api.routes.public_embed.db_client.get_embed_token_by_id",
_get_token_by_id,
)
monkeypatch.setattr(
"api.routes.public_embed.db_client.get_embed_session_by_token",
_get_session,
)
monkeypatch.setattr(
"api.routes.public_embed.db_client.create_workflow_run",
_create_workflow_run,
)
monkeypatch.setattr(
"api.routes.public_embed.db_client.create_embed_session",
_noop,
)
monkeypatch.setattr(
"api.routes.public_embed.db_client.increment_embed_token_usage",
_noop,
)
monkeypatch.setattr("api.routes.public_embed.TURN_SECRET", "test-secret")
monkeypatch.setattr(
"api.routes.public_embed.generate_turn_credentials",
lambda _user_id: {
"username": "turn-user",
"password": "turn-password",
"ttl": 3600,
"uris": ["turn:example.com:3478"],
},
)
def _assert_embed_cors(resp, origin: str):
assert resp.headers.get("access-control-allow-origin") == origin
assert "origin" in {
value.strip().lower() for value in resp.headers.get("vary", "").split(",")
}
def test_options_config_returns_acao_for_allowed_origin():
origin = "https://mysite.vercel.app"
resp = client.options(
"/api/v1/public/embed/config/valid",
headers={
"Origin": origin,
"Access-Control-Request-Method": "GET",
},
)
assert resp.status_code == 200
_assert_embed_cors(resp, origin)
def test_options_config_accepts_allowed_localhost_port():
origin = "http://localhost:3020"
resp = client.options(
"/api/v1/public/embed/config/localhost",
headers={
"Origin": origin,
"Access-Control-Request-Method": "GET",
},
)
assert resp.status_code == 200
_assert_embed_cors(resp, origin)
def test_options_config_rejects_unknown_token():
resp = client.options(
"/api/v1/public/embed/config/unknown",
headers={
"Origin": "https://mysite.vercel.app",
"Access-Control-Request-Method": "GET",
},
)
assert resp.status_code == 403
def test_options_config_rejects_disallowed_origin():
resp = client.options(
"/api/v1/public/embed/config/restricted",
headers={
"Origin": "https://notallowed.example.com",
"Access-Control-Request-Method": "GET",
},
)
assert resp.status_code == 403
def test_get_config_includes_acao_header():
origin = "https://mysite.vercel.app"
resp = client.get(
"/api/v1/public/embed/config/valid",
headers={"Origin": origin},
)
assert resp.status_code == 200
_assert_embed_cors(resp, origin)
def test_get_config_accepts_allowed_localhost_port():
origin = "http://localhost:3020"
resp = client.get(
"/api/v1/public/embed/config/localhost",
headers={"Origin": origin},
)
assert resp.status_code == 200
_assert_embed_cors(resp, origin)
def test_get_config_rejects_unlisted_localhost_port():
resp = client.get(
"/api/v1/public/embed/config/localhost",
headers={"Origin": "http://localhost:3021"},
)
assert resp.status_code == 403
def test_get_config_rejects_disallowed_origin():
resp = client.get(
"/api/v1/public/embed/config/restricted",
headers={"Origin": "https://notallowed.example.com"},
)
assert resp.status_code == 403
def test_init_includes_acao_header():
origin = "https://mysite.vercel.app"
resp = client.post(
"/api/v1/public/embed/init",
headers={"Origin": origin},
json={"token": "valid"},
)
assert resp.status_code == 200
_assert_embed_cors(resp, origin)
def test_turn_credentials_includes_acao_header():
origin = "https://mysite.vercel.app"
resp = client.get(
"/api/v1/public/embed/turn-credentials/session-valid",
headers={"Origin": origin},
)
assert resp.status_code == 200
_assert_embed_cors(resp, origin)
def test_options_init_returns_acao_for_allowed_origin():
origin = "https://mysite.vercel.app"
resp = client.options(
"/api/v1/public/embed/init",
headers={
"Origin": origin,
"Access-Control-Request-Method": "POST",
},
)
assert resp.status_code == 200
_assert_embed_cors(resp, origin)
def test_options_turn_credentials_returns_acao_for_allowed_origin():
origin = "https://mysite.vercel.app"
resp = client.options(
"/api/v1/public/embed/turn-credentials/session-valid",
headers={
"Origin": origin,
"Access-Control-Request-Method": "GET",
},
)
assert resp.status_code == 200
_assert_embed_cors(resp, origin)
def test_options_turn_credentials_rejects_disallowed_origin():
resp = client.options(
"/api/v1/public/embed/turn-credentials/session-restricted",
headers={
"Origin": "https://notallowed.example.com",
"Access-Control-Request-Method": "GET",
},
)
assert resp.status_code == 403

View file

@ -0,0 +1,369 @@
from types import SimpleNamespace
from unittest.mock import AsyncMock
import pytest
from api.services import quota_service
from api.services.configuration.registry import ServiceProviders
from api.services.managed_model_services import MPS_CORRELATION_ID_CONTEXT_KEY
def _dograh_config(
api_key: str = "mps_sk_12345678",
*,
managed_service_version: int = 1,
):
return SimpleNamespace(
managed_service_version=managed_service_version,
llm=SimpleNamespace(provider=ServiceProviders.DOGRAH, api_key=api_key),
stt=None,
tts=None,
embeddings=None,
)
def _byok_config():
return SimpleNamespace(
managed_service_version=2,
llm=SimpleNamespace(provider="openai", api_key="sk-openai"),
stt=None,
tts=None,
embeddings=None,
)
def _workflow():
return SimpleNamespace(
id=7,
user_id=123,
organization_id=42,
workflow_configurations={"model_overrides": {}},
)
def _workflow_owner():
return SimpleNamespace(
id=123,
provider_id="provider-123",
)
def _actor():
return SimpleNamespace(
id=456,
provider_id="actor-456",
selected_organization_id=42,
)
def _patch_workflow_context(monkeypatch, *, workflow=None, owner=None):
monkeypatch.setattr(
quota_service.db_client,
"get_workflow_by_id",
AsyncMock(return_value=workflow or _workflow()),
)
monkeypatch.setattr(
quota_service.db_client,
"get_user_by_id",
AsyncMock(return_value=owner or _workflow_owner()),
)
@pytest.mark.asyncio
async def test_authorize_workflow_run_uses_workflow_org_for_hosted_v2(
monkeypatch,
):
get_config = AsyncMock(return_value=_dograh_config())
authorize = AsyncMock(
return_value={
"allowed": True,
"billing_mode": "v2",
"remaining_credits": "25.0000",
}
)
check_usage = AsyncMock()
monkeypatch.setattr(quota_service, "DEPLOYMENT_MODE", "saas")
_patch_workflow_context(monkeypatch)
monkeypatch.setattr(
quota_service,
"get_effective_ai_model_configuration_for_workflow",
get_config,
)
monkeypatch.setattr(
quota_service.mps_service_key_client,
"authorize_workflow_run_start",
authorize,
)
monkeypatch.setattr(
quota_service.mps_service_key_client,
"check_service_key_usage",
check_usage,
)
result = await quota_service.authorize_workflow_run_start(workflow_id=7)
assert result.has_quota is True
get_config.assert_awaited_once_with(
user_id=123,
organization_id=42,
workflow_configurations={"model_overrides": {}},
)
authorize.assert_awaited_once_with(
organization_id=42,
workflow_run_id=None,
service_key=None,
require_correlation_id=False,
minimum_credits=quota_service.MINIMUM_DOGRAH_CREDITS_FOR_CALL,
created_by="provider-123",
metadata={"dograh_user_id": "123", "workflow_id": 7},
)
check_usage.assert_not_awaited()
@pytest.mark.asyncio
async def test_authorize_workflow_run_v2_insufficient_credits_prompts_billing(
monkeypatch,
):
get_config = AsyncMock(return_value=_byok_config())
authorize = AsyncMock(
return_value={
"allowed": False,
"billing_mode": "v2",
"remaining_credits": "0.0000",
"error": "insufficient_credits",
}
)
check_usage = AsyncMock()
monkeypatch.setattr(quota_service, "DEPLOYMENT_MODE", "saas")
_patch_workflow_context(monkeypatch)
monkeypatch.setattr(
quota_service,
"get_effective_ai_model_configuration_for_workflow",
get_config,
)
monkeypatch.setattr(
quota_service.mps_service_key_client,
"authorize_workflow_run_start",
authorize,
)
monkeypatch.setattr(
quota_service.mps_service_key_client,
"check_service_key_usage",
check_usage,
)
result = await quota_service.authorize_workflow_run_start(workflow_id=7)
assert result.has_quota is False
assert result.error_code == "insufficient_credits"
assert "/billing" in result.error_message
assert "founders@dograh.com" not in result.error_message
authorize.assert_awaited_once()
check_usage.assert_not_awaited()
@pytest.mark.asyncio
async def test_authorize_workflow_run_v1_uses_legacy_key_usage(
monkeypatch,
):
api_key = "mps_sk_12345678"
get_config = AsyncMock(return_value=_dograh_config(api_key))
authorize = AsyncMock(
return_value={
"allowed": True,
"billing_mode": "v1",
"remaining_credits": "0.0000",
}
)
check_usage = AsyncMock(
return_value={"total_credits_used": 500.0, "remaining_credits": 0.0}
)
monkeypatch.setattr(quota_service, "DEPLOYMENT_MODE", "saas")
_patch_workflow_context(monkeypatch)
monkeypatch.setattr(
quota_service,
"get_effective_ai_model_configuration_for_workflow",
get_config,
)
monkeypatch.setattr(
quota_service.mps_service_key_client,
"authorize_workflow_run_start",
authorize,
)
monkeypatch.setattr(
quota_service.mps_service_key_client,
"check_service_key_usage",
check_usage,
)
result = await quota_service.authorize_workflow_run_start(workflow_id=7)
assert result.has_quota is False
assert result.error_code == "quota_exceeded"
assert "founders@dograh.com" in result.error_message
assert "/billing" not in result.error_message
authorize.assert_awaited_once()
check_usage.assert_awaited_once_with(
api_key,
organization_id=42,
created_by="provider-123",
)
@pytest.mark.asyncio
async def test_authorize_workflow_run_managed_v2_stores_hosted_correlation(
monkeypatch,
):
api_key = "mps_sk_12345678"
workflow_run = SimpleNamespace(initial_context={"existing": "value"})
get_config = AsyncMock(
return_value=_dograh_config(api_key, managed_service_version=2)
)
authorize = AsyncMock(
return_value={
"allowed": True,
"billing_mode": "v2",
"remaining_credits": "25.0000",
"correlation_id": "mps-corr-123",
}
)
update_workflow_run = AsyncMock()
monkeypatch.setattr(quota_service, "DEPLOYMENT_MODE", "saas")
_patch_workflow_context(monkeypatch)
monkeypatch.setattr(
quota_service.db_client,
"get_workflow_run_by_id",
AsyncMock(return_value=workflow_run),
)
monkeypatch.setattr(
quota_service.db_client,
"update_workflow_run",
update_workflow_run,
)
monkeypatch.setattr(
quota_service,
"get_effective_ai_model_configuration_for_workflow",
get_config,
)
monkeypatch.setattr(
quota_service.mps_service_key_client,
"authorize_workflow_run_start",
authorize,
)
monkeypatch.setattr(
quota_service.mps_service_key_client,
"check_service_key_usage",
AsyncMock(),
)
result = await quota_service.authorize_workflow_run_start(
workflow_id=7,
workflow_run_id=88,
)
assert result.has_quota is True
authorize.assert_awaited_once_with(
organization_id=42,
workflow_run_id=88,
service_key=api_key,
require_correlation_id=True,
minimum_credits=quota_service.MINIMUM_DOGRAH_CREDITS_FOR_CALL,
created_by="provider-123",
metadata={"dograh_user_id": "123", "workflow_id": 7},
)
update_workflow_run.assert_awaited_once_with(
88,
initial_context={
"existing": "value",
MPS_CORRELATION_ID_CONTEXT_KEY: "mps-corr-123",
},
)
@pytest.mark.asyncio
async def test_authorize_workflow_run_oss_uses_key_paths_not_workflow_org(
monkeypatch,
):
api_key = "mps_sk_12345678"
workflow_run = SimpleNamespace(initial_context={})
get_config = AsyncMock(
return_value=_dograh_config(api_key, managed_service_version=2)
)
hosted_authorize = AsyncMock()
check_usage = AsyncMock(
return_value={"total_credits_used": 1.0, "remaining_credits": 499.0}
)
create_correlation = AsyncMock(return_value={"correlation_id": "oss-corr-123"})
update_workflow_run = AsyncMock()
monkeypatch.setattr(quota_service, "DEPLOYMENT_MODE", "oss")
_patch_workflow_context(monkeypatch)
monkeypatch.setattr(
quota_service.db_client,
"get_workflow_run_by_id",
AsyncMock(return_value=workflow_run),
)
monkeypatch.setattr(
quota_service.db_client,
"update_workflow_run",
update_workflow_run,
)
monkeypatch.setattr(
quota_service,
"get_effective_ai_model_configuration_for_workflow",
get_config,
)
monkeypatch.setattr(
quota_service.mps_service_key_client,
"authorize_workflow_run_start",
hosted_authorize,
)
monkeypatch.setattr(
quota_service.mps_service_key_client,
"check_service_key_usage",
check_usage,
)
monkeypatch.setattr(
quota_service.mps_service_key_client,
"create_correlation_id",
create_correlation,
)
result = await quota_service.authorize_workflow_run_start(
workflow_id=7,
workflow_run_id=88,
)
assert result.has_quota is True
hosted_authorize.assert_not_awaited()
check_usage.assert_awaited_once_with(
api_key,
organization_id=None,
created_by="provider-123",
)
create_correlation.assert_awaited_once_with(
service_key=api_key,
workflow_run_id=88,
)
update_workflow_run.assert_awaited_once_with(
88,
initial_context={MPS_CORRELATION_ID_CONTEXT_KEY: "oss-corr-123"},
)
@pytest.mark.asyncio
async def test_authorize_workflow_run_rejects_actor_from_another_org(monkeypatch):
monkeypatch.setattr(quota_service, "DEPLOYMENT_MODE", "saas")
_patch_workflow_context(monkeypatch)
result = await quota_service.authorize_workflow_run_start(
workflow_id=7,
actor_user=SimpleNamespace(selected_organization_id=999),
)
assert result.has_quota is False
assert result.error_code == "workflow_not_found"

View file

@ -2,14 +2,14 @@
TDD tests for resolve_effective_config().
This function deep-merges workflow-level model_overrides onto the global
UserConfiguration. Fields not overridden inherit from global.
EffectiveAIModelConfiguration. Fields not overridden inherit from global.
Module under test: api.services.configuration.resolve
"""
import pytest
from api.schemas.user_configuration import UserConfiguration
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
from api.services.configuration.masking import (
contains_masked_key,
mask_workflow_configurations,
@ -35,9 +35,9 @@ from api.services.configuration.resolve import (
@pytest.fixture
def global_config() -> UserConfiguration:
def global_config() -> EffectiveAIModelConfiguration:
"""A realistic global user configuration."""
return UserConfiguration(
return EffectiveAIModelConfiguration(
llm=OpenAILLMService(
provider="openai", api_key="sk-global-llm", model="gpt-4.1"
),
@ -59,9 +59,9 @@ def global_config() -> UserConfiguration:
@pytest.fixture
def global_config_realtime() -> UserConfiguration:
def global_config_realtime() -> EffectiveAIModelConfiguration:
"""Global config with realtime enabled."""
return UserConfiguration(
return EffectiveAIModelConfiguration(
llm=OpenAILLMService(
provider="openai", api_key="sk-global-llm", model="gpt-4.1"
),
@ -302,7 +302,7 @@ class TestRealtimeOverride:
class TestOverrideOnNullGlobal:
def test_override_stt_when_global_is_none(self):
"""When global has no STT config, override creates one from scratch."""
config = UserConfiguration(
config = EffectiveAIModelConfiguration(
llm=OpenAILLMService(provider="openai", api_key="sk-key", model="gpt-4.1"),
stt=None,
tts=None,
@ -325,7 +325,7 @@ class TestOverrideOnNullGlobal:
def test_override_realtime_when_global_is_none(self):
"""Realtime section can be created from override even if global has none."""
config = UserConfiguration(
config = EffectiveAIModelConfiguration(
llm=OpenAILLMService(provider="openai", api_key="sk-key", model="gpt-4.1"),
is_realtime=False,
realtime=None,

View file

@ -1,4 +1,4 @@
from api.services.pricing.run_usage_response import format_public_usage_info
from api.services.workflow.run_usage_response import format_public_usage_info
def test_format_public_usage_info():

View file

@ -1,5 +1,5 @@
from types import SimpleNamespace
from unittest.mock import AsyncMock, Mock, patch
from unittest.mock import ANY, AsyncMock, Mock, patch
from fastapi import FastAPI
from fastapi.testclient import TestClient
@ -54,7 +54,7 @@ def test_initiate_call_executes_as_workflow_owner_for_shared_org_workflow():
with (
patch("api.routes.telephony.db_client") as mock_db,
patch(
"api.routes.telephony.check_dograh_quota_by_user_id",
"api.routes.telephony.authorize_workflow_run_start",
new=quota_mock,
),
patch(
@ -88,7 +88,11 @@ def test_initiate_call_executes_as_workflow_owner_for_shared_org_workflow():
)
assert response.status_code == 200
quota_mock.assert_awaited_once_with(workflow.user_id, workflow_id=workflow.id)
quota_mock.assert_awaited_once_with(
workflow_id=workflow.id,
workflow_run_id=501,
actor_user=ANY,
)
mock_db.get_workflow.assert_awaited_once_with(workflow.id, organization_id=11)
create_call = mock_db.create_workflow_run.await_args
@ -103,6 +107,61 @@ def test_initiate_call_executes_as_workflow_owner_for_shared_org_workflow():
assert initiate_kwargs["workflow_id"] == workflow.id
assert initiate_kwargs["user_id"] == workflow.user_id
assert "user_id=99" in initiate_kwargs["webhook_url"]
mock_db.get_user_configurations.assert_not_called()
def test_initiate_call_uses_organization_preference_phone_number():
app = _make_test_app()
client = TestClient(app)
workflow = _workflow()
provider = _provider()
quota_mock = AsyncMock(
return_value=SimpleNamespace(has_quota=True, error_message="")
)
with (
patch("api.routes.telephony.db_client") as mock_db,
patch(
"api.routes.telephony.authorize_workflow_run_start",
new=quota_mock,
),
patch(
"api.routes.telephony.get_default_telephony_provider",
new=AsyncMock(return_value=provider),
),
patch(
"api.routes.telephony.get_backend_endpoints",
new=AsyncMock(return_value=("https://api.example.com", "wss://ignored")),
),
):
mock_db.get_user_configurations = AsyncMock(
return_value=SimpleNamespace(test_phone_number="+15550000000")
)
mock_db.get_configuration = Mock(
return_value=SimpleNamespace(value={"test_phone_number": "+15557654321"})
)
mock_db.get_default_telephony_configuration = AsyncMock(
return_value=SimpleNamespace(id=55)
)
mock_db.get_workflow = AsyncMock(return_value=workflow)
mock_db.create_workflow_run = AsyncMock(
return_value=SimpleNamespace(
id=501,
name="WR-TEL-OUT-00000001",
initial_context={},
)
)
mock_db.update_workflow_run = AsyncMock()
response = client.post(
"/telephony/initiate-call",
json={"workflow_id": workflow.id},
)
assert response.status_code == 200
assert provider.initiate_call.await_args.kwargs["to_number"] == "+15557654321"
mock_db.get_user_configurations.assert_not_called()
def test_initiate_call_rejects_existing_run_for_different_workflow():
@ -118,7 +177,7 @@ def test_initiate_call_rejects_existing_run_for_different_workflow():
with (
patch("api.routes.telephony.db_client") as mock_db,
patch(
"api.routes.telephony.check_dograh_quota_by_user_id",
"api.routes.telephony.authorize_workflow_run_start",
new=quota_mock,
),
patch(

View file

@ -10,7 +10,7 @@ from pipecat.processors.frame_processor import FrameDirection
from websockets.exceptions import ConnectionClosedError
from websockets.frames import Close
from api.schemas.user_configuration import UserConfiguration
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
from api.services.configuration.registry import UltravoxRealtimeLLMConfiguration
from api.services.pipecat.realtime.ultravox_realtime import (
_RESUMPTION_USER_MESSAGE,
@ -430,7 +430,7 @@ async def test_receive_messages_reports_unexpected_websocket_close():
def test_factory_creates_dograh_ultravox_realtime_service():
user_config = UserConfiguration(
effective_config = EffectiveAIModelConfiguration(
is_realtime=True,
realtime=UltravoxRealtimeLLMConfiguration(
provider="ultravox_realtime",
@ -441,7 +441,7 @@ def test_factory_creates_dograh_ultravox_realtime_service():
)
service = create_realtime_llm_service(
user_config,
effective_config,
audio_config=SimpleNamespace(),
)

View file

@ -0,0 +1,212 @@
from types import SimpleNamespace
from unittest.mock import AsyncMock
import pytest
from api.services import workflow_run_billing as workflow_run_billing_mod
from api.services.workflow_run_billing import (
report_completed_workflow_run_platform_usage,
report_workflow_run_platform_usage,
)
def _make_workflow_run():
return SimpleNamespace(
id=123,
workflow_id=456,
is_completed=True,
initial_context={"mps_correlation_id": "mps-corr-123"},
usage_info={"call_duration_seconds": 87},
workflow=SimpleNamespace(
organization_id=42,
user=SimpleNamespace(selected_organization_id=42),
),
)
@pytest.mark.asyncio
async def test_report_workflow_run_platform_usage_reports_hosted_completion(
monkeypatch,
):
workflow_run = _make_workflow_run()
get_status = AsyncMock(return_value={"billing_mode": "v2"})
report_usage = AsyncMock(return_value={"metered": True})
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"get_billing_account_status",
get_status,
)
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"report_platform_usage",
report_usage,
)
await report_workflow_run_platform_usage(workflow_run)
report_usage.assert_awaited_once_with(
organization_id=42,
correlation_id="mps-corr-123",
duration_seconds=None,
workflow_run_id=workflow_run.id,
metadata={
"source": "workflow_run_completion",
"workflow_id": workflow_run.workflow_id,
"duration_source": "mps_correlation",
},
)
@pytest.mark.asyncio
async def test_report_workflow_run_platform_usage_reports_duration_without_correlation(
monkeypatch,
):
workflow_run = _make_workflow_run()
workflow_run.initial_context = {}
get_status = AsyncMock(return_value={"billing_mode": "v2"})
report_usage = AsyncMock(return_value={"metered": True})
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"get_billing_account_status",
get_status,
)
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"report_platform_usage",
report_usage,
)
await report_workflow_run_platform_usage(workflow_run)
report_usage.assert_awaited_once_with(
organization_id=42,
correlation_id=None,
duration_seconds=87.0,
workflow_run_id=workflow_run.id,
metadata={
"source": "workflow_run_completion",
"workflow_id": workflow_run.workflow_id,
"duration_source": "dograh_usage_info",
},
)
@pytest.mark.asyncio
async def test_report_workflow_run_platform_usage_skips_non_v2_account(monkeypatch):
workflow_run = _make_workflow_run()
get_status = AsyncMock(return_value={"billing_mode": "v1"})
report_usage = AsyncMock()
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"get_billing_account_status",
get_status,
)
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"report_platform_usage",
report_usage,
)
await report_workflow_run_platform_usage(workflow_run)
get_status.assert_awaited_once_with(organization_id=42)
report_usage.assert_not_awaited()
@pytest.mark.asyncio
async def test_report_workflow_run_platform_usage_skips_missing_duration_without_correlation(
monkeypatch,
):
workflow_run = _make_workflow_run()
workflow_run.initial_context = {}
workflow_run.usage_info = {}
get_status = AsyncMock(return_value={"billing_mode": "v2"})
report_usage = AsyncMock()
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"get_billing_account_status",
get_status,
)
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"report_platform_usage",
report_usage,
)
await report_workflow_run_platform_usage(workflow_run)
get_status.assert_not_awaited()
report_usage.assert_not_awaited()
@pytest.mark.asyncio
async def test_report_workflow_run_platform_usage_skips_oss(monkeypatch):
workflow_run = _make_workflow_run()
report_usage = AsyncMock()
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "oss")
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"report_platform_usage",
report_usage,
)
await report_workflow_run_platform_usage(workflow_run)
report_usage.assert_not_awaited()
@pytest.mark.asyncio
async def test_report_workflow_run_platform_usage_skips_incomplete(monkeypatch):
workflow_run = _make_workflow_run()
workflow_run.is_completed = False
report_usage = AsyncMock()
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"report_platform_usage",
report_usage,
)
await report_workflow_run_platform_usage(workflow_run)
report_usage.assert_not_awaited()
@pytest.mark.asyncio
async def test_report_completed_workflow_run_platform_usage_loads_run(monkeypatch):
workflow_run = _make_workflow_run()
get_run = AsyncMock(return_value=workflow_run)
get_status = AsyncMock(return_value={"billing_mode": "v2"})
report_usage = AsyncMock(return_value={"metered": True})
monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
workflow_run_billing_mod.db_client,
"get_workflow_run_by_id",
get_run,
)
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"get_billing_account_status",
get_status,
)
monkeypatch.setattr(
workflow_run_billing_mod.mps_service_key_client,
"report_platform_usage",
report_usage,
)
await report_completed_workflow_run_platform_usage(workflow_run.id)
get_run.assert_awaited_once_with(workflow_run.id)
report_usage.assert_awaited_once()

Some files were not shown because too many files have changed in this diff Show more