mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-13 08:15:21 +02:00
Merge remote-tracking branch 'origin/main' into feat/user-onboarding
This commit is contained in:
commit
093e888ce4
148 changed files with 10908 additions and 2815 deletions
|
|
@ -1,3 +1,3 @@
|
|||
{
|
||||
".": "1.33.0"
|
||||
".": "1.34.0"
|
||||
}
|
||||
28
CHANGELOG.md
28
CHANGELOG.md
|
|
@ -1,5 +1,33 @@
|
|||
# Changelog
|
||||
|
||||
## 1.34.0 (2026-06-03)
|
||||
|
||||
<!-- Release notes generated using configuration in .github/release.yml at main -->
|
||||
|
||||
## What's Changed
|
||||
### Features
|
||||
* feat: add mcp guides for various topics and stages for bot building by @a6kme in https://github.com/dograh-hq/dograh/pull/380
|
||||
* feat: allow overriding base URL of OpenAI STT and TTS by @developer603 in https://github.com/dograh-hq/dograh/pull/377
|
||||
* feat: add Azure AI multi-provider support (TTS, STT, Embeddings, Realtime) by @vishaldhateria in https://github.com/dograh-hq/dograh/pull/381
|
||||
### Bug Fixes
|
||||
* fix: support object and array parameters in custom HTTP tools by @mvanhorn in https://github.com/dograh-hq/dograh/pull/373
|
||||
* fix(telephony): resolve transfer context via call-sid index instead of KEYS scan by @shiminshen in https://github.com/dograh-hq/dograh/pull/387
|
||||
* fix(webrtc): enforce embed allowed-domain policy on public signaling websocket by @shiminshen in https://github.com/dograh-hq/dograh/pull/388
|
||||
* fix: use runtime BACKEND_URL for proxying by @a6kme in https://github.com/dograh-hq/dograh/pull/411
|
||||
* fix: add CORS preflight handler and ACAO header for embed config endpoint by @nuthalapativarun in https://github.com/dograh-hq/dograh/pull/403
|
||||
### Other Changes
|
||||
* Add Sarvam LLM, update Sarvam STT models, expose usage_info on run detail by @abhaybabbar in https://github.com/dograh-hq/dograh/pull/351
|
||||
* fix: make email lookup case-insensitive in get_user_by_email by @developer603 in https://github.com/dograh-hq/dograh/pull/397
|
||||
|
||||
## New Contributors
|
||||
* @abhaybabbar made their first contribution in https://github.com/dograh-hq/dograh/pull/351
|
||||
* @mvanhorn made their first contribution in https://github.com/dograh-hq/dograh/pull/373
|
||||
* @developer603 made their first contribution in https://github.com/dograh-hq/dograh/pull/377
|
||||
* @vishaldhateria made their first contribution in https://github.com/dograh-hq/dograh/pull/381
|
||||
* @shiminshen made their first contribution in https://github.com/dograh-hq/dograh/pull/387
|
||||
|
||||
**Full Changelog**: https://github.com/dograh-hq/dograh/compare/dograh-v1.33.0...dograh-v1.34.0
|
||||
|
||||
## [1.33.0](https://github.com/dograh-hq/dograh/compare/dograh-v1.32.0...dograh-v1.33.0) (2026-05-31)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -75,13 +75,13 @@ An honest comparison on the axes that matter most to teams evaluating voice AI p
|
|||
##### Download and setup Dograh on your Local Machine
|
||||
|
||||
> **Note**
|
||||
> We collect anonymous usage data to improve the product. You can opt out by setting the `ENABLE_TELEMETRY` to `false` in the below command.
|
||||
> We collect anonymous usage data to improve the product. You can opt out by setting `ENABLE_TELEMETRY=false` before running the startup script.
|
||||
|
||||
> **Note**
|
||||
> If you wish to run the platform on a remote server instead, check out our [Documentation](https://docs.dograh.com/deployment/docker#option-2:-remote-server-deployment).
|
||||
|
||||
```bash
|
||||
curl -o docker-compose.yaml https://raw.githubusercontent.com/dograh-hq/dograh/main/docker-compose.yaml && REGISTRY=ghcr.io/dograh-hq ENABLE_TELEMETRY=true docker compose up --pull always
|
||||
curl -o docker-compose.yaml https://raw.githubusercontent.com/dograh-hq/dograh/main/docker-compose.yaml && curl -o start_docker.sh https://raw.githubusercontent.com/dograh-hq/dograh/main/scripts/start_docker.sh && chmod +x start_docker.sh && ./start_docker.sh
|
||||
```
|
||||
|
||||
> **Note**
|
||||
|
|
|
|||
|
|
@ -18,6 +18,9 @@ branch_labels: Union[str, Sequence[str], None] = None
|
|||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
DEPRECATED_QUOTA_COMMENT = "Deprecated. MPS owns quota and credit ledger state."
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
# 1) Create the `quota_type` enum *before* we add the column that references it.
|
||||
|
|
@ -34,7 +37,12 @@ def upgrade() -> None:
|
|||
sa.Column("organization_id", sa.Integer(), nullable=False),
|
||||
sa.Column("period_start", sa.DateTime(), nullable=False),
|
||||
sa.Column("period_end", sa.DateTime(), nullable=False),
|
||||
sa.Column("quota_dograh_tokens", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"quota_dograh_tokens",
|
||||
sa.Integer(),
|
||||
nullable=False,
|
||||
comment=DEPRECATED_QUOTA_COMMENT,
|
||||
),
|
||||
sa.Column("used_dograh_tokens", sa.Integer(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True),
|
||||
|
|
@ -63,7 +71,11 @@ def upgrade() -> None:
|
|||
op.add_column(
|
||||
"organizations",
|
||||
sa.Column(
|
||||
"quota_type", quota_type_enum, nullable=False, server_default="monthly"
|
||||
"quota_type",
|
||||
quota_type_enum,
|
||||
nullable=False,
|
||||
server_default="monthly",
|
||||
comment=DEPRECATED_QUOTA_COMMENT,
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
|
|
@ -73,6 +85,7 @@ def upgrade() -> None:
|
|||
sa.Integer(),
|
||||
nullable=False,
|
||||
server_default=sa.text("0"),
|
||||
comment=DEPRECATED_QUOTA_COMMENT,
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
|
|
@ -82,10 +95,17 @@ def upgrade() -> None:
|
|||
sa.Integer(),
|
||||
nullable=False,
|
||||
server_default=sa.text("LEAST(EXTRACT(DAY FROM CURRENT_DATE)::int, 28)"),
|
||||
comment=DEPRECATED_QUOTA_COMMENT,
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"organizations", sa.Column("quota_start_date", sa.DateTime(), nullable=True)
|
||||
"organizations",
|
||||
sa.Column(
|
||||
"quota_start_date",
|
||||
sa.DateTime(),
|
||||
nullable=True,
|
||||
comment=DEPRECATED_QUOTA_COMMENT,
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"organizations",
|
||||
|
|
@ -94,6 +114,7 @@ def upgrade() -> None:
|
|||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.text("false"),
|
||||
comment=DEPRECATED_QUOTA_COMMENT,
|
||||
),
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
|
|
|||
|
|
@ -5,28 +5,38 @@ Revises: 6bd9f67ec994
|
|||
Create Date: 2026-06-02 07:58:00.002359
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '384be6596b36'
|
||||
down_revision: Union[str, None] = '6bd9f67ec994'
|
||||
revision: str = "384be6596b36"
|
||||
down_revision: Union[str, None] = "6bd9f67ec994"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index(op.f('ix_users_email'), table_name='users')
|
||||
op.create_index('ix_users_email_lower', 'users', [sa.literal_column('lower(email)')], unique=True, postgresql_where=sa.text('email IS NOT NULL'))
|
||||
op.drop_index(op.f("ix_users_email"), table_name="users")
|
||||
op.create_index(
|
||||
"ix_users_email_lower",
|
||||
"users",
|
||||
[sa.literal_column("lower(email)")],
|
||||
unique=True,
|
||||
postgresql_where=sa.text("email IS NOT NULL"),
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index('ix_users_email_lower', table_name='users', postgresql_where=sa.text('email IS NOT NULL'))
|
||||
op.create_index(op.f('ix_users_email'), 'users', ['email'], unique=True)
|
||||
op.drop_index(
|
||||
"ix_users_email_lower",
|
||||
table_name="users",
|
||||
postgresql_where=sa.text("email IS NOT NULL"),
|
||||
)
|
||||
op.create_index(op.f("ix_users_email"), "users", ["email"], unique=True)
|
||||
# ### end Alembic commands ###
|
||||
|
|
|
|||
|
|
@ -18,6 +18,9 @@ branch_labels: Union[str, Sequence[str], None] = None
|
|||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
DEPRECATED_QUOTA_COMMENT = "Deprecated. MPS owns quota and credit ledger state."
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column(
|
||||
|
|
@ -26,7 +29,12 @@ def upgrade() -> None:
|
|||
)
|
||||
op.add_column(
|
||||
"organization_usage_cycles",
|
||||
sa.Column("quota_amount_usd", sa.Float(), nullable=True),
|
||||
sa.Column(
|
||||
"quota_amount_usd",
|
||||
sa.Float(),
|
||||
nullable=True,
|
||||
comment=DEPRECATED_QUOTA_COMMENT,
|
||||
),
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
|
|
|||
|
|
@ -117,6 +117,15 @@ app.add_middleware(
|
|||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
def _add_public_embed_cors_middleware() -> None:
|
||||
from api.routes.public_embed import PublicEmbedCORSMiddleware
|
||||
|
||||
app.add_middleware(PublicEmbedCORSMiddleware, api_prefix=API_PREFIX)
|
||||
|
||||
|
||||
_add_public_embed_cors_middleware()
|
||||
|
||||
api_router = APIRouter()
|
||||
|
||||
# include subrouters here
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from api.db.base_client import BaseDBClient
|
|||
from api.db.filters import apply_workflow_run_filters, get_workflow_run_order_clause
|
||||
from api.db.models import CampaignModel, QueuedRunModel, WorkflowRunModel
|
||||
from api.schemas.workflow import WorkflowRunResponseSchema
|
||||
from api.services.workflow.run_usage_response import format_public_cost_info
|
||||
|
||||
|
||||
class CampaignClient(BaseDBClient):
|
||||
|
|
@ -215,26 +216,9 @@ class CampaignClient(BaseDBClient):
|
|||
"is_completed": run.is_completed,
|
||||
"recording_url": run.recording_url,
|
||||
"transcript_url": run.transcript_url,
|
||||
"cost_info": {
|
||||
"dograh_token_usage": (
|
||||
run.cost_info.get("dograh_token_usage")
|
||||
if run.cost_info
|
||||
and "dograh_token_usage" in run.cost_info
|
||||
else round(
|
||||
float(run.cost_info.get("total_cost_usd", 0)) * 100,
|
||||
2,
|
||||
)
|
||||
if run.cost_info and "total_cost_usd" in run.cost_info
|
||||
else 0
|
||||
),
|
||||
"call_duration_seconds": int(
|
||||
round(run.cost_info.get("call_duration_seconds") or 0)
|
||||
)
|
||||
if run.cost_info
|
||||
else None,
|
||||
}
|
||||
if run.cost_info
|
||||
else None,
|
||||
"cost_info": format_public_cost_info(
|
||||
run.cost_info, run.usage_info
|
||||
),
|
||||
"definition_id": run.definition_id,
|
||||
"initial_context": run.initial_context,
|
||||
"gathered_context": run.gathered_context,
|
||||
|
|
@ -662,7 +646,7 @@ class CampaignClient(BaseDBClient):
|
|||
async with self.async_session() as session:
|
||||
conditions = [
|
||||
WorkflowRunModel.is_completed.is_(True),
|
||||
WorkflowRunModel.cost_info["call_duration_seconds"]
|
||||
WorkflowRunModel.usage_info["call_duration_seconds"]
|
||||
.as_string()
|
||||
.isnot(None),
|
||||
]
|
||||
|
|
@ -685,6 +669,7 @@ class CampaignClient(BaseDBClient):
|
|||
WorkflowRunModel.initial_context,
|
||||
WorkflowRunModel.gathered_context,
|
||||
WorkflowRunModel.cost_info,
|
||||
WorkflowRunModel.usage_info,
|
||||
WorkflowRunModel.public_access_token,
|
||||
)
|
||||
.where(*conditions)
|
||||
|
|
|
|||
|
|
@ -53,7 +53,7 @@ class DBClient(
|
|||
- UserClient: handles user and user configuration operations
|
||||
- OrganizationClient: handles organization operations
|
||||
- OrganizationConfigurationClient: handles organization configuration operations
|
||||
- OrganizationUsageClient: handles organization usage and quota operations
|
||||
- OrganizationUsageClient: handles organization usage reporting aggregates
|
||||
- IntegrationClient: handles integration operations
|
||||
- WorkflowTemplateClient: handles workflow template operations
|
||||
- CampaignClient: handles campaign operations
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ def get_workflow_run_order_clause(
|
|||
"""
|
||||
# Determine sort column
|
||||
if sort_by == "duration":
|
||||
sort_column = WorkflowRunModel.cost_info.op("->>")(
|
||||
sort_column = WorkflowRunModel.usage_info.op("->>")(
|
||||
"call_duration_seconds"
|
||||
).cast(Float)
|
||||
else:
|
||||
|
|
@ -43,7 +43,7 @@ def get_workflow_run_order_clause(
|
|||
ATTRIBUTE_FIELD_MAPPING = {
|
||||
"dateRange": "created_at",
|
||||
"dispositionCode": "gathered_context.mapped_call_disposition",
|
||||
"duration": "cost_info.call_duration_seconds",
|
||||
"duration": "usage_info.call_duration_seconds",
|
||||
"status": "is_completed",
|
||||
"tokenUsage": "cost_info.total_cost_usd",
|
||||
"runId": "id",
|
||||
|
|
@ -208,7 +208,7 @@ def apply_workflow_run_filters(
|
|||
min_val = value.get("min")
|
||||
max_val = value.get("max")
|
||||
|
||||
if field == "cost_info.call_duration_seconds":
|
||||
if field == "usage_info.call_duration_seconds":
|
||||
# Use ->> operator for compatibility with all PostgreSQL versions
|
||||
# (subscript [] only works in PostgreSQL 14+)
|
||||
duration_text = cast(WorkflowRunModel.usage_info, JSONB).op("->>")(
|
||||
|
|
|
|||
|
|
@ -97,22 +97,44 @@ class OrganizationModel(Base):
|
|||
provider_id = Column(String, unique=True, index=True, nullable=False)
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
|
||||
|
||||
# Quota fields
|
||||
# Deprecated: MPS owns quota and credit ledger state.
|
||||
quota_type = Column(
|
||||
Enum("monthly", "annual", name="quota_type"),
|
||||
nullable=False,
|
||||
default="monthly",
|
||||
server_default=text("'monthly'::quota_type"),
|
||||
comment="Deprecated. MPS owns quota and credit ledger state.",
|
||||
info={"deprecated": True},
|
||||
)
|
||||
quota_dograh_tokens = Column(
|
||||
Integer, nullable=False, default=0, server_default=text("0")
|
||||
Integer,
|
||||
nullable=False,
|
||||
default=0,
|
||||
server_default=text("0"),
|
||||
comment="Deprecated. MPS owns quota and credit ledger state.",
|
||||
info={"deprecated": True},
|
||||
)
|
||||
quota_reset_day = Column(
|
||||
Integer, nullable=False, default=1, server_default=text("1")
|
||||
) # 1-28, only for monthly
|
||||
quota_start_date = Column(DateTime(timezone=True), nullable=True) # Only for annual
|
||||
Integer,
|
||||
nullable=False,
|
||||
default=1,
|
||||
server_default=text("1"),
|
||||
comment="Deprecated. MPS owns quota and credit ledger state.",
|
||||
info={"deprecated": True},
|
||||
)
|
||||
quota_start_date = Column(
|
||||
DateTime(timezone=True),
|
||||
nullable=True,
|
||||
comment="Deprecated. MPS owns quota and credit ledger state.",
|
||||
info={"deprecated": True},
|
||||
)
|
||||
quota_enabled = Column(
|
||||
Boolean, nullable=False, default=False, server_default=text("false")
|
||||
Boolean,
|
||||
nullable=False,
|
||||
default=False,
|
||||
server_default=text("false"),
|
||||
comment="Deprecated. MPS owns quota and credit ledger state.",
|
||||
info={"deprecated": True},
|
||||
)
|
||||
|
||||
price_per_second_usd = Column(Float, nullable=True)
|
||||
|
|
@ -593,8 +615,9 @@ class WorkflowRunTextSessionModel(Base):
|
|||
|
||||
class OrganizationUsageCycleModel(Base):
|
||||
"""
|
||||
This model is used to track the usage of Dograh tokens for an organization for a given usage
|
||||
cycle.
|
||||
This model is used to track reporting aggregates for an organization for a given
|
||||
usage cycle. Quota fields on this model are deprecated; MPS owns quota and
|
||||
credit ledger state.
|
||||
"""
|
||||
|
||||
__tablename__ = "organization_usage_cycles"
|
||||
|
|
@ -603,14 +626,24 @@ class OrganizationUsageCycleModel(Base):
|
|||
organization_id = Column(Integer, ForeignKey("organizations.id"), nullable=False)
|
||||
period_start = Column(DateTime(timezone=True), nullable=False)
|
||||
period_end = Column(DateTime(timezone=True), nullable=False)
|
||||
quota_dograh_tokens = Column(Integer, nullable=False)
|
||||
quota_dograh_tokens = Column(
|
||||
Integer,
|
||||
nullable=False,
|
||||
comment="Deprecated. MPS owns quota and credit ledger state.",
|
||||
info={"deprecated": True},
|
||||
)
|
||||
used_dograh_tokens = Column(Float, nullable=False, default=0)
|
||||
total_duration_seconds = Column(
|
||||
Integer, nullable=False, default=0, server_default=text("0")
|
||||
)
|
||||
# New USD tracking fields
|
||||
used_amount_usd = Column(Float, nullable=True, default=0)
|
||||
quota_amount_usd = Column(Float, nullable=True)
|
||||
quota_amount_usd = Column(
|
||||
Float,
|
||||
nullable=True,
|
||||
comment="Deprecated. MPS owns quota and credit ledger state.",
|
||||
info={"deprecated": True},
|
||||
)
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True),
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from sqlalchemy.orm import joinedload
|
|||
from api.db.base_client import BaseDBClient
|
||||
from api.db.filters import apply_workflow_run_filters
|
||||
from api.db.models import (
|
||||
OrganizationConfigurationModel,
|
||||
OrganizationModel,
|
||||
OrganizationUsageCycleModel,
|
||||
UserConfigurationModel,
|
||||
|
|
@ -17,11 +18,12 @@ from api.db.models import (
|
|||
WorkflowModel,
|
||||
WorkflowRunModel,
|
||||
)
|
||||
from api.schemas.user_configuration import UserConfiguration
|
||||
from api.enums import OrganizationConfigurationKey
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
|
||||
|
||||
class OrganizationUsageClient(BaseDBClient):
|
||||
"""Client for managing organization usage and quota operations."""
|
||||
"""Client for managing organization usage reporting aggregates."""
|
||||
|
||||
async def get_or_create_current_cycle(
|
||||
self, organization_id: int, session=None
|
||||
|
|
@ -47,14 +49,7 @@ class OrganizationUsageClient(BaseDBClient):
|
|||
self, organization_id: int, session, commit: bool
|
||||
) -> OrganizationUsageCycleModel:
|
||||
"""Internal implementation for get_or_create_current_cycle."""
|
||||
# Get organization to determine quota type
|
||||
org_result = await session.execute(
|
||||
select(OrganizationModel).where(OrganizationModel.id == organization_id)
|
||||
)
|
||||
org = org_result.scalar_one()
|
||||
|
||||
# Calculate current period
|
||||
period_start, period_end = self._calculate_current_period(org)
|
||||
period_start, period_end = self._calculate_current_period()
|
||||
|
||||
# Try to get existing cycle
|
||||
cycle_result = await session.execute(
|
||||
|
|
@ -76,7 +71,8 @@ class OrganizationUsageClient(BaseDBClient):
|
|||
organization_id=organization_id,
|
||||
period_start=period_start,
|
||||
period_end=period_end,
|
||||
quota_dograh_tokens=org.quota_dograh_tokens,
|
||||
# Deprecated non-null column retained for historical schema compatibility.
|
||||
quota_dograh_tokens=0,
|
||||
)
|
||||
# Handle concurrent inserts gracefully
|
||||
stmt = stmt.on_conflict_do_nothing(
|
||||
|
|
@ -100,95 +96,9 @@ class OrganizationUsageClient(BaseDBClient):
|
|||
)
|
||||
return cycle_result.scalar_one()
|
||||
|
||||
async def check_and_reserve_quota(
|
||||
self, organization_id: int, estimated_tokens: int = 0
|
||||
) -> bool:
|
||||
"""
|
||||
Check if organization has sufficient quota and optionally reserve tokens.
|
||||
Returns True if quota is available, False otherwise.
|
||||
|
||||
This method is fully atomic and safe for concurrent access from multiple processes.
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
# Get organization
|
||||
org_result = await session.execute(
|
||||
select(OrganizationModel).where(OrganizationModel.id == organization_id)
|
||||
)
|
||||
org = org_result.scalar_one_or_none()
|
||||
|
||||
if not org or not org.quota_enabled:
|
||||
# No quota enforcement if not enabled
|
||||
return True
|
||||
|
||||
# Get or create current cycle within the same session/transaction
|
||||
cycle = await self._get_or_create_current_cycle_impl(
|
||||
organization_id, session, commit=False
|
||||
)
|
||||
|
||||
# Atomic check and update with row-level lock
|
||||
result = await session.execute(
|
||||
select(OrganizationUsageCycleModel)
|
||||
.where(
|
||||
and_(
|
||||
OrganizationUsageCycleModel.id == cycle.id,
|
||||
OrganizationUsageCycleModel.used_dograh_tokens
|
||||
+ estimated_tokens
|
||||
<= OrganizationUsageCycleModel.quota_dograh_tokens,
|
||||
)
|
||||
)
|
||||
.with_for_update(skip_locked=False)
|
||||
)
|
||||
|
||||
cycle_locked = result.scalar_one_or_none()
|
||||
if cycle_locked:
|
||||
# Update the usage atomically
|
||||
cycle_locked.used_dograh_tokens += estimated_tokens
|
||||
await session.commit()
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def update_usage_after_run(
|
||||
self,
|
||||
organization_id: int,
|
||||
actual_tokens: float,
|
||||
duration_seconds: float = 0,
|
||||
charge_usd: float | None = None,
|
||||
) -> None:
|
||||
"""Update usage after a workflow run completes with actual token count and duration.
|
||||
|
||||
This method is fully atomic and safe for concurrent access from multiple processes.
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
# Get or create current cycle within the same session/transaction
|
||||
cycle = await self._get_or_create_current_cycle_impl(
|
||||
organization_id, session, commit=False
|
||||
)
|
||||
|
||||
# Acquire a row-level lock for atomic update
|
||||
result = await session.execute(
|
||||
select(OrganizationUsageCycleModel)
|
||||
.where(OrganizationUsageCycleModel.id == cycle.id)
|
||||
.with_for_update(skip_locked=False)
|
||||
)
|
||||
cycle_locked = result.scalar_one()
|
||||
|
||||
# Update usage atomically
|
||||
cycle_locked.used_dograh_tokens += actual_tokens
|
||||
cycle_locked.total_duration_seconds += int(round(duration_seconds))
|
||||
|
||||
# Update USD amount if provided
|
||||
if charge_usd is not None:
|
||||
if cycle_locked.used_amount_usd is None:
|
||||
cycle_locked.used_amount_usd = 0
|
||||
cycle_locked.used_amount_usd += charge_usd
|
||||
|
||||
await session.commit()
|
||||
|
||||
async def get_current_usage(self, organization_id: int) -> dict:
|
||||
"""Get current period usage information."""
|
||||
"""Get current reporting-period usage information."""
|
||||
async with self.async_session() as session:
|
||||
# Get organization
|
||||
org_result = await session.execute(
|
||||
select(OrganizationModel).where(OrganizationModel.id == organization_id)
|
||||
)
|
||||
|
|
@ -199,42 +109,19 @@ class OrganizationUsageClient(BaseDBClient):
|
|||
organization_id, session, commit=False
|
||||
)
|
||||
|
||||
# Calculate next refresh date
|
||||
if org.quota_type == "monthly":
|
||||
next_refresh = cycle.period_end + relativedelta(days=1)
|
||||
else: # annual
|
||||
next_refresh = cycle.period_end + relativedelta(days=1)
|
||||
|
||||
result = {
|
||||
"period_start": cycle.period_start.isoformat(),
|
||||
"period_end": cycle.period_end.isoformat(),
|
||||
"used_dograh_tokens": cycle.used_dograh_tokens,
|
||||
"quota_dograh_tokens": cycle.quota_dograh_tokens,
|
||||
"percentage_used": (
|
||||
round(
|
||||
(cycle.used_dograh_tokens / cycle.quota_dograh_tokens) * 100, 2
|
||||
)
|
||||
if cycle.quota_dograh_tokens > 0
|
||||
else 0
|
||||
),
|
||||
"next_refresh_date": next_refresh.date().isoformat(),
|
||||
"quota_enabled": org.quota_enabled,
|
||||
"total_duration_seconds": cycle.total_duration_seconds,
|
||||
}
|
||||
|
||||
# Add USD fields if organization has pricing
|
||||
if org.price_per_second_usd is not None:
|
||||
result["used_amount_usd"] = cycle.used_amount_usd or 0
|
||||
result["quota_amount_usd"] = cycle.quota_amount_usd
|
||||
result["currency"] = "USD"
|
||||
result["price_per_second_usd"] = org.price_per_second_usd
|
||||
|
||||
# Calculate percentage based on USD if available
|
||||
if cycle.quota_amount_usd and cycle.quota_amount_usd > 0:
|
||||
result["percentage_used"] = round(
|
||||
((cycle.used_amount_usd or 0) / cycle.quota_amount_usd) * 100, 2
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
async def get_usage_history(
|
||||
|
|
@ -254,7 +141,7 @@ class OrganizationUsageClient(BaseDBClient):
|
|||
.join(UserModel, WorkflowModel.user_id == UserModel.id)
|
||||
.where(
|
||||
UserModel.selected_organization_id == organization_id,
|
||||
WorkflowRunModel.cost_info.isnot(None),
|
||||
WorkflowRunModel.usage_info.isnot(None),
|
||||
)
|
||||
.order_by(WorkflowRunModel.created_at.desc())
|
||||
)
|
||||
|
|
@ -307,19 +194,8 @@ class OrganizationUsageClient(BaseDBClient):
|
|||
total_tokens = 0
|
||||
total_duration_seconds = 0
|
||||
for run in runs:
|
||||
if run.cost_info:
|
||||
# Try to get dograh_token_usage first (new format)
|
||||
dograh_tokens = run.cost_info.get("dograh_token_usage", 0)
|
||||
# If not present, calculate from total_cost_usd (old format)
|
||||
if dograh_tokens == 0 and "total_cost_usd" in run.cost_info:
|
||||
dograh_tokens = round(
|
||||
float(run.cost_info["total_cost_usd"]) * 100, 2
|
||||
)
|
||||
# Get call duration
|
||||
call_duration = run.cost_info.get("call_duration_seconds", 0)
|
||||
else:
|
||||
dograh_tokens = 0
|
||||
call_duration = 0
|
||||
dograh_tokens = 0
|
||||
call_duration = (run.usage_info or {}).get("call_duration_seconds", 0)
|
||||
total_tokens += dograh_tokens
|
||||
total_duration_seconds += int(round(call_duration))
|
||||
|
||||
|
|
@ -393,13 +269,14 @@ class OrganizationUsageClient(BaseDBClient):
|
|||
WorkflowRunModel.initial_context,
|
||||
WorkflowRunModel.gathered_context,
|
||||
WorkflowRunModel.cost_info,
|
||||
WorkflowRunModel.usage_info,
|
||||
WorkflowRunModel.public_access_token,
|
||||
)
|
||||
.join(WorkflowModel, WorkflowRunModel.workflow_id == WorkflowModel.id)
|
||||
.join(UserModel, WorkflowModel.user_id == UserModel.id)
|
||||
.where(
|
||||
UserModel.selected_organization_id == organization_id,
|
||||
WorkflowRunModel.cost_info.isnot(None),
|
||||
WorkflowRunModel.usage_info.isnot(None),
|
||||
)
|
||||
.order_by(WorkflowRunModel.created_at.desc())
|
||||
)
|
||||
|
|
@ -440,8 +317,29 @@ class OrganizationUsageClient(BaseDBClient):
|
|||
"""Get daily usage breakdown for an organization with pricing."""
|
||||
|
||||
async with self.async_session() as session:
|
||||
# Get user timezone if user_id is provided
|
||||
# Get org timezone preference first, then fall back to legacy user config.
|
||||
user_timezone = "UTC" # Default timezone
|
||||
pref_result = await session.execute(
|
||||
select(OrganizationConfigurationModel).where(
|
||||
OrganizationConfigurationModel.organization_id == organization_id,
|
||||
OrganizationConfigurationModel.key.in_(
|
||||
[
|
||||
OrganizationConfigurationKey.ORGANIZATION_PREFERENCES.value,
|
||||
OrganizationConfigurationKey.MODEL_CONFIGURATION_PREFERENCES.value,
|
||||
]
|
||||
),
|
||||
)
|
||||
)
|
||||
pref_rows = pref_result.scalars().all()
|
||||
pref_by_key = {pref.key: pref for pref in pref_rows}
|
||||
pref_obj = pref_by_key.get(
|
||||
OrganizationConfigurationKey.ORGANIZATION_PREFERENCES.value
|
||||
) or pref_by_key.get(
|
||||
OrganizationConfigurationKey.MODEL_CONFIGURATION_PREFERENCES.value
|
||||
)
|
||||
if pref_obj and pref_obj.value:
|
||||
user_timezone = pref_obj.value.get("timezone") or user_timezone
|
||||
|
||||
if user_id:
|
||||
config_result = await session.execute(
|
||||
select(UserConfigurationModel).where(
|
||||
|
|
@ -450,11 +348,11 @@ class OrganizationUsageClient(BaseDBClient):
|
|||
)
|
||||
config_obj = config_result.scalar_one_or_none()
|
||||
if config_obj and config_obj.configuration:
|
||||
user_config = UserConfiguration.model_validate(
|
||||
effective_config = EffectiveAIModelConfiguration.model_validate(
|
||||
config_obj.configuration
|
||||
)
|
||||
if user_config.timezone:
|
||||
user_timezone = user_config.timezone
|
||||
if effective_config.timezone and user_timezone == "UTC":
|
||||
user_timezone = effective_config.timezone
|
||||
|
||||
# Validate timezone string
|
||||
try:
|
||||
|
|
@ -473,7 +371,7 @@ class OrganizationUsageClient(BaseDBClient):
|
|||
select(
|
||||
date_expr.label("date"),
|
||||
func.sum(
|
||||
WorkflowRunModel.cost_info["call_duration_seconds"].as_float()
|
||||
WorkflowRunModel.usage_info["call_duration_seconds"].as_float()
|
||||
).label("total_seconds"),
|
||||
func.count(WorkflowRunModel.id).label("call_count"),
|
||||
)
|
||||
|
|
@ -522,83 +420,11 @@ class OrganizationUsageClient(BaseDBClient):
|
|||
"currency": "USD",
|
||||
}
|
||||
|
||||
async def update_organization_quota(
|
||||
self,
|
||||
organization_id: int,
|
||||
quota_type: str,
|
||||
quota_dograh_tokens: int,
|
||||
quota_reset_day: Optional[int] = None,
|
||||
quota_start_date: Optional[datetime] = None,
|
||||
) -> OrganizationModel:
|
||||
"""Update organization quota settings."""
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(OrganizationModel).where(OrganizationModel.id == organization_id)
|
||||
)
|
||||
org = result.scalar_one()
|
||||
|
||||
org.quota_type = quota_type
|
||||
org.quota_dograh_tokens = quota_dograh_tokens
|
||||
org.quota_enabled = True
|
||||
|
||||
if quota_type == "monthly" and quota_reset_day:
|
||||
org.quota_reset_day = quota_reset_day
|
||||
elif quota_type == "annual" and quota_start_date:
|
||||
org.quota_start_date = quota_start_date
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(org)
|
||||
return org
|
||||
|
||||
def _calculate_current_period(
|
||||
self, org: OrganizationModel
|
||||
) -> tuple[datetime, datetime]:
|
||||
"""Calculate the current billing period based on organization settings."""
|
||||
def _calculate_current_period(self) -> tuple[datetime, datetime]:
|
||||
"""Calculate the current calendar-month reporting period."""
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
if org.quota_type == "monthly":
|
||||
# Find the start of the current billing month
|
||||
reset_day = org.quota_reset_day
|
||||
|
||||
# Handle month boundaries
|
||||
if now.day >= reset_day:
|
||||
period_start = now.replace(
|
||||
day=reset_day, hour=0, minute=0, second=0, microsecond=0
|
||||
)
|
||||
else:
|
||||
# Previous month
|
||||
period_start = (now - relativedelta(months=1)).replace(
|
||||
day=reset_day, hour=0, minute=0, second=0, microsecond=0
|
||||
)
|
||||
|
||||
# End is one month later minus 1 second
|
||||
period_end = (
|
||||
period_start + relativedelta(months=1) - relativedelta(seconds=1)
|
||||
)
|
||||
|
||||
else: # annual
|
||||
if not org.quota_start_date:
|
||||
# Default to calendar year
|
||||
period_start = now.replace(
|
||||
month=1, day=1, hour=0, minute=0, second=0, microsecond=0
|
||||
)
|
||||
period_end = (
|
||||
period_start + relativedelta(years=1) - relativedelta(seconds=1)
|
||||
)
|
||||
else:
|
||||
# Find current annual period
|
||||
start_date = org.quota_start_date.replace(tzinfo=timezone.utc)
|
||||
years_diff = now.year - start_date.year
|
||||
|
||||
# Adjust for whether we've passed the anniversary
|
||||
if now.month < start_date.month or (
|
||||
now.month == start_date.month and now.day < start_date.day
|
||||
):
|
||||
years_diff -= 1
|
||||
|
||||
period_start = start_date + relativedelta(years=years_diff)
|
||||
period_end = (
|
||||
period_start + relativedelta(years=1) - relativedelta(seconds=1)
|
||||
)
|
||||
period_start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
|
||||
period_end = period_start + relativedelta(months=1) - relativedelta(seconds=1)
|
||||
|
||||
return period_start, period_end
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from sqlalchemy.future import select
|
|||
|
||||
from api.db.base_client import BaseDBClient
|
||||
from api.db.models import UserConfigurationModel, UserModel
|
||||
from api.schemas.user_configuration import UserConfiguration
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
|
||||
|
||||
class UserClient(BaseDBClient):
|
||||
|
|
@ -65,7 +65,9 @@ class UserClient(BaseDBClient):
|
|||
)
|
||||
return result.scalars().first()
|
||||
|
||||
async def get_user_configurations(self, user_id: int) -> UserConfiguration:
|
||||
async def get_user_configurations(
|
||||
self, user_id: int
|
||||
) -> EffectiveAIModelConfiguration:
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(UserConfigurationModel).where(
|
||||
|
|
@ -74,10 +76,10 @@ class UserClient(BaseDBClient):
|
|||
)
|
||||
configuration_obj = result.scalars().first()
|
||||
if not configuration_obj:
|
||||
return UserConfiguration()
|
||||
return EffectiveAIModelConfiguration()
|
||||
|
||||
try:
|
||||
return UserConfiguration.model_validate(
|
||||
return EffectiveAIModelConfiguration.model_validate(
|
||||
{
|
||||
**configuration_obj.configuration,
|
||||
"last_validated_at": configuration_obj.last_validated_at,
|
||||
|
|
@ -90,11 +92,11 @@ class UserClient(BaseDBClient):
|
|||
f"Failed to validate user configuration for user {user_id}: {e}. "
|
||||
"Returning default configuration."
|
||||
)
|
||||
return UserConfiguration()
|
||||
return EffectiveAIModelConfiguration()
|
||||
|
||||
async def update_user_configuration(
|
||||
self, user_id: int, configuration: UserConfiguration
|
||||
) -> UserConfiguration:
|
||||
self, user_id: int, configuration: EffectiveAIModelConfiguration
|
||||
) -> EffectiveAIModelConfiguration:
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(UserConfigurationModel).where(
|
||||
|
|
@ -115,7 +117,9 @@ class UserClient(BaseDBClient):
|
|||
await session.rollback()
|
||||
raise e
|
||||
await session.refresh(configuration_obj)
|
||||
return UserConfiguration.model_validate(configuration_obj.configuration)
|
||||
return EffectiveAIModelConfiguration.model_validate(
|
||||
configuration_obj.configuration
|
||||
)
|
||||
|
||||
async def update_user_configuration_last_validated_at(self, user_id: int) -> None:
|
||||
async with self.async_session() as session:
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from api.db.models import (
|
|||
)
|
||||
from api.enums import CallType, StorageBackend
|
||||
from api.schemas.workflow import WorkflowRunResponseSchema
|
||||
from api.services.workflow.run_usage_response import format_public_cost_info
|
||||
|
||||
|
||||
class WorkflowRunClient(BaseDBClient):
|
||||
|
|
@ -312,26 +313,9 @@ class WorkflowRunClient(BaseDBClient):
|
|||
"is_completed": run.is_completed,
|
||||
"recording_url": run.recording_url,
|
||||
"transcript_url": run.transcript_url,
|
||||
"cost_info": {
|
||||
"dograh_token_usage": (
|
||||
run.cost_info.get("dograh_token_usage")
|
||||
if run.cost_info
|
||||
and "dograh_token_usage" in run.cost_info
|
||||
else round(
|
||||
float(run.cost_info.get("total_cost_usd", 0)) * 100,
|
||||
2,
|
||||
)
|
||||
if run.cost_info and "total_cost_usd" in run.cost_info
|
||||
else 0
|
||||
),
|
||||
"call_duration_seconds": int(
|
||||
round(run.cost_info.get("call_duration_seconds") or 0)
|
||||
)
|
||||
if run.cost_info
|
||||
else None,
|
||||
}
|
||||
if run.cost_info
|
||||
else None,
|
||||
"cost_info": format_public_cost_info(
|
||||
run.cost_info, run.usage_info
|
||||
),
|
||||
"definition_id": run.definition_id,
|
||||
"initial_context": run.initial_context,
|
||||
"gathered_context": run.gathered_context,
|
||||
|
|
|
|||
|
|
@ -89,6 +89,11 @@ class OrganizationConfigurationKey(Enum):
|
|||
LANGFUSE_CREDENTIALS = (
|
||||
"LANGFUSE_CREDENTIALS" # Org-level Langfuse tracing credentials
|
||||
)
|
||||
MODEL_CONFIGURATION_V2 = (
|
||||
"MODEL_CONFIGURATION_V2" # Org-level v2 AI model configuration
|
||||
)
|
||||
ORGANIZATION_PREFERENCES = "ORGANIZATION_PREFERENCES" # Org-level defaults such as timezone/test call number
|
||||
MODEL_CONFIGURATION_PREFERENCES = "MODEL_CONFIGURATION_PREFERENCES" # Deprecated; read fallback for old org preferences
|
||||
|
||||
|
||||
class WorkflowStatus(Enum):
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
[project]
|
||||
name = "dograh-api"
|
||||
version = "1.33.0"
|
||||
version = "1.34.0"
|
||||
description = "Backend API for Dograh voice AI platform"
|
||||
requires-python = ">=3.13,<3.14"
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ from starlette.websockets import WebSocketDisconnect
|
|||
|
||||
from api.db import db_client
|
||||
from api.enums import CallType, WorkflowRunState
|
||||
from api.services.quota_service import check_dograh_quota_by_user_id
|
||||
from api.services.quota_service import authorize_workflow_run_start
|
||||
from api.services.telephony import registry as telephony_registry
|
||||
|
||||
router = APIRouter(prefix="/agent-stream")
|
||||
|
|
@ -67,19 +67,6 @@ async def agent_stream_websocket(
|
|||
await websocket.close(code=1008, reason="Workflow not found")
|
||||
return
|
||||
|
||||
quota_result = await check_dograh_quota_by_user_id(
|
||||
workflow.user_id, workflow_id=workflow.id
|
||||
)
|
||||
if not quota_result.has_quota:
|
||||
logger.warning(
|
||||
f"agent-stream quota exceeded for user {workflow.user_id}: "
|
||||
f"{quota_result.error_message}"
|
||||
)
|
||||
await websocket.close(
|
||||
code=1008, reason=quota_result.error_message or "Quota exceeded"
|
||||
)
|
||||
return
|
||||
|
||||
numeric_suffix = int(str(uuid.uuid4()).replace("-", "")[:8], 16) % 100000000
|
||||
workflow_run_name = f"WR-AGS-{numeric_suffix:08d}"
|
||||
call_id = params.get("callId") or params.get("CallSid")
|
||||
|
|
@ -108,6 +95,20 @@ async def agent_stream_websocket(
|
|||
set_current_run_id(workflow_run.id)
|
||||
set_current_org_id(workflow.organization_id)
|
||||
|
||||
quota_result = await authorize_workflow_run_start(
|
||||
workflow_id=workflow.id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
)
|
||||
if not quota_result.has_quota:
|
||||
logger.warning(
|
||||
f"agent-stream quota exceeded for user {workflow.user_id}: "
|
||||
f"{quota_result.error_message}"
|
||||
)
|
||||
await websocket.close(
|
||||
code=1008, reason=quota_result.error_message or "Quota exceeded"
|
||||
)
|
||||
return
|
||||
|
||||
await db_client.update_workflow_run(
|
||||
run_id=workflow_run.id, state=WorkflowRunState.RUNNING.value
|
||||
)
|
||||
|
|
|
|||
|
|
@ -3,9 +3,12 @@ from loguru import logger
|
|||
|
||||
from api.db import db_client
|
||||
from api.db.models import UserModel
|
||||
from api.enums import PostHogEvent
|
||||
from api.enums import OrganizationConfigurationKey, PostHogEvent
|
||||
from api.schemas.auth import AuthResponse, LoginRequest, SignupRequest, UserResponse
|
||||
from api.services.auth.depends import create_user_configuration_with_mps_key, get_user
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
convert_legacy_ai_model_configuration_to_v2,
|
||||
)
|
||||
from api.services.posthog_client import capture_event
|
||||
from api.utils.auth import create_jwt_token, hash_password, verify_password
|
||||
|
||||
|
|
@ -47,6 +50,12 @@ async def signup(request: SignupRequest):
|
|||
)
|
||||
if mps_config:
|
||||
await db_client.update_user_configuration(user.id, mps_config)
|
||||
model_config_v2 = convert_legacy_ai_model_configuration_to_v2(mps_config)
|
||||
await db_client.upsert_configuration(
|
||||
organization.id,
|
||||
OrganizationConfigurationKey.MODEL_CONFIGURATION_V2.value,
|
||||
model_config_v2.model_dump(mode="json", exclude_none=True),
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to create default configuration for OSS user", exc_info=True
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ from api.services.auth.depends import get_user
|
|||
from api.services.campaign.runner import campaign_runner_service
|
||||
from api.services.campaign.source_sync import CampaignSourceSyncService
|
||||
from api.services.campaign.source_sync_factory import get_sync_service
|
||||
from api.services.quota_service import check_dograh_quota
|
||||
from api.services.quota_service import authorize_workflow_run_start
|
||||
from api.services.reports import generate_campaign_report_csv
|
||||
from api.services.storage import storage_fs
|
||||
|
||||
|
|
@ -550,7 +550,10 @@ async def start_campaign(
|
|||
|
||||
# Check Dograh quota before starting campaign (apply per-workflow
|
||||
# model_overrides so we evaluate the keys this campaign will use).
|
||||
quota_result = await check_dograh_quota(user, workflow_id=campaign.workflow_id)
|
||||
quota_result = await authorize_workflow_run_start(
|
||||
workflow_id=campaign.workflow_id,
|
||||
actor_user=user,
|
||||
)
|
||||
if not quota_result.has_quota:
|
||||
raise HTTPException(status_code=402, detail=quota_result.error_message)
|
||||
|
||||
|
|
@ -872,7 +875,10 @@ async def resume_campaign(
|
|||
|
||||
# Check Dograh quota before resuming campaign (apply per-workflow
|
||||
# model_overrides so we evaluate the keys this campaign will use).
|
||||
quota_result = await check_dograh_quota(user, workflow_id=campaign.workflow_id)
|
||||
quota_result = await authorize_workflow_run_start(
|
||||
workflow_id=campaign.workflow_id,
|
||||
actor_user=user,
|
||||
)
|
||||
if not quota_result.has_quota:
|
||||
raise HTTPException(status_code=402, detail=quota_result.error_message)
|
||||
|
||||
|
|
|
|||
|
|
@ -369,6 +369,10 @@ async def search_chunks(
|
|||
|
||||
try:
|
||||
# Import here to avoid circular dependency
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
apply_managed_embeddings_base_url,
|
||||
get_resolved_ai_model_configuration,
|
||||
)
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from api.services.gen_ai import (
|
||||
AzureOpenAIEmbeddingService,
|
||||
|
|
@ -376,20 +380,29 @@ async def search_chunks(
|
|||
)
|
||||
|
||||
# Try to get user's embeddings configuration
|
||||
user_config = await db_client.get_user_configurations(user.id)
|
||||
resolved_config = await get_resolved_ai_model_configuration(
|
||||
user_id=user.id,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
effective_config = resolved_config.effective
|
||||
embeddings_api_key = None
|
||||
embeddings_model = None
|
||||
embeddings_provider = None
|
||||
embeddings_base_url = None
|
||||
embeddings_endpoint = None
|
||||
embeddings_api_version = None
|
||||
|
||||
if user_config.embeddings:
|
||||
embeddings_api_key = user_config.embeddings.api_key
|
||||
embeddings_model = user_config.embeddings.model
|
||||
embeddings_provider = getattr(user_config.embeddings, "provider", None)
|
||||
embeddings_endpoint = getattr(user_config.embeddings, "endpoint", None)
|
||||
if effective_config.embeddings:
|
||||
embeddings_api_key = effective_config.embeddings.api_key
|
||||
embeddings_model = effective_config.embeddings.model
|
||||
embeddings_provider = getattr(effective_config.embeddings, "provider", None)
|
||||
embeddings_endpoint = getattr(effective_config.embeddings, "endpoint", None)
|
||||
embeddings_base_url = apply_managed_embeddings_base_url(
|
||||
provider=embeddings_provider,
|
||||
base_url=getattr(effective_config.embeddings, "base_url", None),
|
||||
)
|
||||
embeddings_api_version = getattr(
|
||||
user_config.embeddings, "api_version", None
|
||||
effective_config.embeddings, "api_version", None
|
||||
)
|
||||
|
||||
# Initialize embedding service based on provider
|
||||
|
|
@ -406,9 +419,7 @@ async def search_chunks(
|
|||
db_client=db_client,
|
||||
api_key=embeddings_api_key,
|
||||
model_id=embeddings_model or "text-embedding-3-small",
|
||||
base_url=getattr(user_config.embeddings, "base_url", None)
|
||||
if user_config.embeddings
|
||||
else None,
|
||||
base_url=embeddings_base_url,
|
||||
)
|
||||
|
||||
# Perform search
|
||||
|
|
|
|||
|
|
@ -1,15 +1,27 @@
|
|||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from api.constants import DEFAULT_CAMPAIGN_RETRY_CONFIG, DEFAULT_ORG_CONCURRENCY_LIMIT
|
||||
from api.constants import (
|
||||
DEFAULT_CAMPAIGN_RETRY_CONFIG,
|
||||
DEFAULT_ORG_CONCURRENCY_LIMIT,
|
||||
DEPLOYMENT_MODE,
|
||||
)
|
||||
from api.db import db_client
|
||||
from api.db.models import UserModel
|
||||
from api.db.telephony_configuration_client import TelephonyConfigurationInUseError
|
||||
from api.enums import OrganizationConfigurationKey, PostHogEvent
|
||||
from api.schemas.ai_model_configuration import (
|
||||
DOGRAH_DEFAULT_LANGUAGE,
|
||||
DOGRAH_DEFAULT_VOICE,
|
||||
DOGRAH_SPEED_OPTIONS,
|
||||
OrganizationAIModelConfigurationResponse,
|
||||
OrganizationAIModelConfigurationV2,
|
||||
)
|
||||
from api.schemas.organization_preferences import OrganizationPreferences
|
||||
from api.schemas.telephony_config import (
|
||||
TelephonyConfigRequest,
|
||||
TelephonyConfigurationCreateRequest,
|
||||
|
|
@ -26,8 +38,36 @@ from api.schemas.telephony_phone_number import (
|
|||
PhoneNumberUpdateRequest,
|
||||
ProviderSyncStatus,
|
||||
)
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.configuration.masking import is_mask_of, mask_key
|
||||
from api.services.auth.depends import get_user, get_user_with_selected_organization
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
check_for_masked_keys_in_ai_model_configuration_v2,
|
||||
compile_ai_model_configuration_v2,
|
||||
convert_legacy_ai_model_configuration_to_v2,
|
||||
get_organization_ai_model_configuration_v2,
|
||||
get_resolved_ai_model_configuration,
|
||||
mask_ai_model_configuration_v2,
|
||||
merge_ai_model_configuration_v2_secrets,
|
||||
migrate_workflow_model_configurations_to_v2,
|
||||
upsert_organization_ai_model_configuration_v2,
|
||||
)
|
||||
from api.services.configuration.check_validity import UserConfigurationValidator
|
||||
from api.services.configuration.defaults import DEFAULT_SERVICE_PROVIDERS
|
||||
from api.services.configuration.masking import is_mask_of, mask_key, mask_user_config
|
||||
from api.services.configuration.registry import (
|
||||
DOGRAH_STT_LANGUAGES,
|
||||
REGISTRY,
|
||||
ServiceProviders,
|
||||
ServiceType,
|
||||
)
|
||||
from api.services.mps_billing import ensure_hosted_mps_billing_account_v2
|
||||
from api.services.organization_context import (
|
||||
OrganizationContextResponse,
|
||||
get_organization_context,
|
||||
)
|
||||
from api.services.organization_preferences import (
|
||||
get_organization_preferences,
|
||||
upsert_organization_preferences,
|
||||
)
|
||||
from api.services.posthog_client import capture_event
|
||||
from api.services.telephony import registry as telephony_registry
|
||||
from api.services.telephony.factory import get_telephony_provider_by_id
|
||||
|
|
@ -98,6 +138,12 @@ class TelephonyConfigWarningsResponse(BaseModel):
|
|||
telnyx_missing_webhook_public_key_count: int
|
||||
|
||||
|
||||
@router.get("/context", response_model=OrganizationContextResponse)
async def get_current_organization_context(user: UserModel = Depends(get_user)):
    """Return the organization context for the requesting user.

    The handler adds no logic of its own: it forwards the resolved ``user``
    to ``get_organization_context`` and returns whatever that produces,
    serialized through ``OrganizationContextResponse``.
    """
    organization_context = await get_organization_context(user)
    return organization_context
|
||||
|
||||
|
||||
@router.get(
|
||||
"/telephony-providers/metadata",
|
||||
response_model=TelephonyProvidersMetadataResponse,
|
||||
|
|
@ -159,6 +205,239 @@ async def get_telephony_config_warnings(user: UserModel = Depends(get_user)):
|
|||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AI model configurations v2
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _byok_provider_schemas(service_type: ServiceType) -> dict[str, dict]:
    """Collect the JSON schema of every non-Dograh provider for a service type.

    Args:
        service_type: Key into ``REGISTRY`` selecting which provider table
            to walk.

    Returns:
        Mapping of provider name to the result of that provider model's
        ``model_json_schema()``; the Dograh provider entry is left out.
    """
    schemas: dict[str, dict] = {}
    for provider_name, config_model in REGISTRY[service_type].items():
        # The Dograh-managed provider is excluded from the BYOK listing.
        if provider_name == ServiceProviders.DOGRAH.value:
            continue
        schemas[provider_name] = config_model.model_json_schema()
    return schemas
|
||||
|
||||
|
||||
async def _model_configuration_v2_response(
    *,
    user: UserModel,
    configuration: OrganizationAIModelConfigurationV2 | None = None,
) -> OrganizationAIModelConfigurationResponse:
    """Build the masked API response for the organization's v2 model config.

    Args:
        user: Requesting user; ``user.id`` and
            ``user.selected_organization_id`` drive the resolution.
        configuration: Configuration to echo back. When ``None`` the
            organization configuration from the resolved result is used.

    Returns:
        Response carrying the masked configuration, the masked effective
        configuration and the ``source`` reported by the resolver.
    """
    resolved = await get_resolved_ai_model_configuration(
        user_id=user.id,
        organization_id=user.selected_organization_id,
    )

    # An explicitly supplied configuration wins over the resolved one.
    if configuration is None:
        shown_configuration = resolved.organization_configuration
    else:
        shown_configuration = configuration

    masked_configuration = mask_ai_model_configuration_v2(shown_configuration)
    masked_effective = mask_user_config(resolved.effective)
    return OrganizationAIModelConfigurationResponse(
        configuration=masked_configuration,
        effective_configuration=masked_effective,
        source=resolved.source,
    )
|
||||
|
||||
|
||||
@router.get("/model-configurations/v2/defaults")
async def get_model_configuration_v2_defaults(
    user: UserModel = Depends(get_user_with_selected_organization),
):
    """Describe the options available when editing a v2 model configuration.

    The payload has two sections: ``dograh`` (voice, speed and language
    choices plus their defaults) and ``byok`` (per-service provider JSON
    schemas for the ``pipeline`` and ``realtime`` modes, each with the
    default provider per service).

    ``user`` is not read in the body; it is declared only so the route runs
    the ``get_user_with_selected_organization`` dependency.
    """
    # Default provider per service, skipping services that default to Dograh.
    byok_default_providers = {}
    for service, provider in DEFAULT_SERVICE_PROVIDERS.items():
        if provider != ServiceProviders.DOGRAH.value:
            byok_default_providers[service] = provider

    dograh_options = {
        "voices": [DOGRAH_DEFAULT_VOICE],
        "speeds": list(DOGRAH_SPEED_OPTIONS),
        "languages": DOGRAH_STT_LANGUAGES,
        "defaults": {
            "voice": DOGRAH_DEFAULT_VOICE,
            "speed": 1.0,
            "language": DOGRAH_DEFAULT_LANGUAGE,
        },
    }
    pipeline_options = {
        "llm": _byok_provider_schemas(ServiceType.LLM),
        "tts": _byok_provider_schemas(ServiceType.TTS),
        "stt": _byok_provider_schemas(ServiceType.STT),
        "embeddings": _byok_provider_schemas(ServiceType.EMBEDDINGS),
        "default_providers": byok_default_providers,
    }
    realtime_options = {
        "realtime": _byok_provider_schemas(ServiceType.REALTIME),
        "llm": _byok_provider_schemas(ServiceType.LLM),
        "embeddings": _byok_provider_schemas(ServiceType.EMBEDDINGS),
        "default_providers": byok_default_providers,
    }
    return {
        "dograh": dograh_options,
        "byok": {
            "pipeline": pipeline_options,
            "realtime": realtime_options,
        },
    }
|
||||
|
||||
|
||||
@router.get(
    "/model-configurations/v2",
    response_model=OrganizationAIModelConfigurationResponse,
)
async def get_model_configuration_v2(
    user: UserModel = Depends(get_user_with_selected_organization),
):
    """Return the masked v2 model configuration for the user's organization.

    No explicit configuration is passed on, so the response helper falls
    back to the organization configuration it resolves itself.
    """
    response = await _model_configuration_v2_response(user=user)
    return response
|
||||
|
||||
|
||||
@router.put(
    "/model-configurations/v2",
    response_model=OrganizationAIModelConfigurationResponse,
)
async def save_model_configuration_v2(
    request: OrganizationAIModelConfigurationV2,
    user: UserModel = Depends(get_user_with_selected_organization),
):
    """Validate and persist the organization's v2 AI model configuration.

    Steps, in order:
      1. Load the stored configuration and merge it with the request via
         ``merge_ai_model_configuration_v2_secrets``.
      2. Run the masked-key check, compile the effective configuration and
         validate it with ``UserConfigurationValidator``.
      3. Upsert the merged configuration and return the masked response.

    Args:
        request: Configuration submitted by the client.
        user: Requesting user; ``selected_organization_id`` identifies the
            organization being updated.

    Raises:
        HTTPException: 422 when step 2 raises ``ValueError``. Nothing is
            persisted in that case because the upsert happens afterwards.
    """
    organization_id = user.selected_organization_id
    existing = await get_organization_ai_model_configuration_v2(organization_id)
    configuration = merge_ai_model_configuration_v2_secrets(request, existing)
    try:
        check_for_masked_keys_in_ai_model_configuration_v2(configuration)
        effective = compile_ai_model_configuration_v2(configuration)
        await UserConfigurationValidator().validate(
            effective,
            organization_id=organization_id,
            created_by=user.provider_id,
        )
    except ValueError as exc:
        # args[0] preserves a structured detail when one was supplied; a
        # ValueError (or subclass) raised without positional args would make
        # args[0] itself fail with IndexError, so fall back to str(exc).
        detail = exc.args[0] if exc.args else str(exc)
        raise HTTPException(status_code=422, detail=detail) from exc

    await upsert_organization_ai_model_configuration_v2(
        organization_id,
        configuration,
    )
    return await _model_configuration_v2_response(
        user=user,
        configuration=configuration,
    )
|
||||
|
||||
|
||||
@router.get("/model-configurations/v2/migration-preview")
async def preview_model_configuration_v2_migration(
    user: UserModel = Depends(get_user_with_selected_organization),
):
    """Preview the v2 configuration a legacy-to-v2 migration would produce.

    Reads the user's legacy configuration, converts it to v2 and compiles
    the effective configuration. Nothing is written by this handler itself.

    Returns:
        Dict with the masked v2 ``configuration`` and the masked
        ``effective_configuration`` compiled from it.

    Raises:
        HTTPException: 422 when conversion or compilation raises
            ``ValueError``.
    """
    legacy = await db_client.get_user_configurations(user.id)
    try:
        configuration = convert_legacy_ai_model_configuration_to_v2(legacy)
        # Compiled inside the try so a ValueError from compilation is
        # reported as 422, matching the save and migrate handlers.
        effective = compile_ai_model_configuration_v2(configuration)
    except ValueError as exc:
        raise HTTPException(status_code=422, detail=str(exc)) from exc
    return {
        "configuration": mask_ai_model_configuration_v2(configuration),
        "effective_configuration": mask_user_config(effective),
    }
|
||||
|
||||
|
||||
@router.post(
    "/model-configurations/v2/migrate",
    response_model=OrganizationAIModelConfigurationResponse,
)
async def migrate_model_configuration_v2(
    force: bool = Query(default=False),
    user: UserModel = Depends(get_user_with_selected_organization),
):
    """Migrate the user's legacy model configuration to an org-level v2 one.

    Steps, in order:
      1. Refuse with 409 if the organization already has a v2 configuration,
         unless ``force`` is set.
      2. Convert the user's legacy configuration, compile it and validate it.
      3. Outside OSS deployments, ensure the hosted MPS billing v2 account.
      4. Upsert the v2 configuration, then migrate workflow-level model
         configurations with the legacy configuration as fallback.

    Args:
        force: Overwrite an existing v2 configuration when ``True``.
        user: Requesting user; ``selected_organization_id`` identifies the
            organization being migrated.

    Raises:
        HTTPException: 409 when a v2 configuration exists and ``force`` is
            false; 422 when conversion, compilation or validation raises
            ``ValueError``; 502 when the billing account step fails.
    """
    organization_id = user.selected_organization_id
    existing = await get_organization_ai_model_configuration_v2(organization_id)
    if existing is not None and not force:
        raise HTTPException(
            status_code=409,
            detail="Organization already has a v2 model configuration",
        )

    legacy = await db_client.get_user_configurations(user.id)
    try:
        configuration = convert_legacy_ai_model_configuration_to_v2(legacy)
        effective = compile_ai_model_configuration_v2(configuration)
        await UserConfigurationValidator().validate(
            effective,
            organization_id=organization_id,
            created_by=user.provider_id,
        )
    except ValueError as exc:
        # args[0] preserves a structured detail when one was supplied; a
        # ValueError (or subclass) raised without positional args would make
        # args[0] itself fail with IndexError, so fall back to str(exc).
        detail = exc.args[0] if exc.args else str(exc)
        raise HTTPException(status_code=422, detail=detail) from exc

    if DEPLOYMENT_MODE != "oss":
        try:
            await ensure_hosted_mps_billing_account_v2(
                organization_id,
                created_by=str(user.provider_id),
            )
        except Exception as exc:
            # Boundary handler: any failure here is logged and surfaced as a
            # 502 before anything has been persisted.
            logger.error(
                "Failed to initialize MPS billing v2 account for organization {}: {}",
                organization_id,
                exc,
            )
            raise HTTPException(
                status_code=502,
                detail="Failed to initialize MPS billing v2 account",
            ) from exc

    await upsert_organization_ai_model_configuration_v2(
        organization_id,
        configuration,
    )
    await migrate_workflow_model_configurations_to_v2(
        organization_id=organization_id,
        fallback_user_config=legacy,
    )
    return await _model_configuration_v2_response(
        user=user,
        configuration=configuration,
    )
|
||||
|
||||
|
||||
@router.get("/preferences", response_model=OrganizationPreferences)
async def get_preferences(
    user: UserModel = Depends(get_user_with_selected_organization),
):
    """Return the preferences of the user's selected organization."""
    preferences = await get_organization_preferences(user.selected_organization_id)
    return preferences
|
||||
|
||||
|
||||
@router.put("/preferences", response_model=OrganizationPreferences)
async def save_preferences(
    request: OrganizationPreferences,
    user: UserModel = Depends(get_user_with_selected_organization),
):
    """Store preferences for the user's selected organization.

    Args:
        request: Preferences payload to upsert.
        user: Requesting user; ``selected_organization_id`` picks the target
            organization.

    Returns:
        Whatever ``upsert_organization_preferences`` returns, serialized as
        ``OrganizationPreferences``.
    """
    saved = await upsert_organization_preferences(
        user.selected_organization_id,
        request,
    )
    return saved
|
||||
|
||||
|
||||
@router.get(
    "/model-configurations/preferences",
    response_model=OrganizationPreferences,
    include_in_schema=False,
)
async def get_model_configuration_preferences_legacy(
    user: UserModel = Depends(get_user_with_selected_organization),
):
    """Serve preferences on the older ``/model-configurations/preferences`` path.

    Hidden from the generated schema and implemented purely by calling
    ``get_preferences`` with the same user.
    """
    preferences = await get_preferences(user=user)
    return preferences
|
||||
|
||||
|
||||
@router.put(
    "/model-configurations/preferences",
    response_model=OrganizationPreferences,
    include_in_schema=False,
)
async def save_model_configuration_preferences_legacy(
    request: OrganizationPreferences,
    user: UserModel = Depends(get_user_with_selected_organization),
):
    """Accept preference updates on the older ``/model-configurations/preferences`` path.

    Hidden from the generated schema and implemented purely by calling
    ``save_preferences`` with the same request and user.
    """
    saved = await save_preferences(request=request, user=user)
    return saved
|
||||
|
||||
|
||||
def preserve_masked_fields(provider: str, request_dict: dict, existing: dict):
|
||||
"""If the client re-submitted a masked sensitive field, restore the original."""
|
||||
for field_name in _sensitive_fields(provider):
|
||||
|
|
|
|||
|
|
@ -1,16 +1,16 @@
|
|||
import json
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from api.constants import DEPLOYMENT_MODE
|
||||
from api.constants import DEPLOYMENT_MODE, UI_APP_URL
|
||||
from api.db import db_client
|
||||
from api.db.models import UserModel
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.auth.depends import get_user, get_user_with_selected_organization
|
||||
from api.services.mps_service_key_client import mps_service_key_client
|
||||
from api.services.reports import generate_usage_runs_report_csv
|
||||
from api.utils.artifacts import artifact_url
|
||||
|
|
@ -22,14 +22,8 @@ class CurrentUsageResponse(BaseModel):
|
|||
period_start: str
|
||||
period_end: str
|
||||
used_dograh_tokens: float
|
||||
quota_dograh_tokens: int
|
||||
percentage_used: float
|
||||
next_refresh_date: str
|
||||
quota_enabled: bool
|
||||
total_duration_seconds: int
|
||||
# New USD fields
|
||||
used_amount_usd: Optional[float] = None
|
||||
quota_amount_usd: Optional[float] = None
|
||||
currency: Optional[str] = None
|
||||
price_per_second_usd: Optional[float] = None
|
||||
|
||||
|
|
@ -40,6 +34,61 @@ class MPSCreditsResponse(BaseModel):
|
|||
total_quota: float
|
||||
|
||||
|
||||
class MPSCreditPurchaseUrlResponse(BaseModel):
    """Response body carrying a single checkout URL string."""

    checkout_url: str
|
||||
|
||||
|
||||
class MPSBillingAccountResponse(BaseModel):
    """Billing account summary embedded in ``MPSBillingCreditsResponse.account``."""

    # Identity of the account and the organization it belongs to.
    id: int
    organization_id: int
    # Billing mode string; the usage routes compare it against "v2".
    billing_mode: str
    # Balance and currency as reported by the billing service.
    cached_balance_credits: float
    currency: str
|
||||
|
||||
|
||||
class MPSCreditLedgerEntryResponse(BaseModel):
    """One credit-ledger entry as exposed by the billing credits endpoint.

    Only ``id``, ``entry_type``, ``credits_delta``, ``balance_after`` and
    ``created_at`` are required; every other field defaults to ``None``
    (``metadata`` defaults to an empty dict).
    """

    # Required core of the entry.
    id: int
    entry_type: str
    origin: Optional[str] = None
    credits_delta: float
    balance_after: float

    # Optional monetary / payment details.
    amount_minor: Optional[int] = None
    amount_currency: Optional[str] = None
    payment_order_id: Optional[int] = None

    # Optional usage attribution.
    metric_code: Optional[str] = None
    correlation_id: Optional[str] = None
    aggregation_key: Optional[str] = None
    usage_event_id: Optional[int] = None
    workflow_run_id: Optional[int] = None
    workflow_id: Optional[int] = None
    billable_quantity: Optional[float] = None
    quantity_unit: Optional[str] = None

    # default_factory gives each instance its own dict.
    metadata: Dict[str, Any] = Field(default_factory=dict)
    created_at: str
|
||||
|
||||
|
||||
class MPSBillingCreditsResponse(BaseModel):
    """Combined credits response covering both legacy and v2 billing.

    ``billing_version`` is the only required field. ``_legacy_mps_credits_response``
    fills just the three aggregate totals and leaves the account, ledger
    and pagination fields at their defaults.
    """

    billing_version: Literal["legacy", "v2"]

    # Aggregate credit totals.
    total_credits_used: float = 0.0
    remaining_credits: float = 0.0
    total_quota: float = 0.0

    # Account and ledger details; empty unless explicitly provided.
    account: Optional[MPSBillingAccountResponse] = None
    ledger_entries: List[MPSCreditLedgerEntryResponse] = Field(default_factory=list)

    # Pagination of ``ledger_entries``.
    total_count: int = 0
    page: int = 1
    limit: int = 50
    total_pages: int = 0
|
||||
|
||||
|
||||
def _optional_int(value: Any) -> Optional[int]:
|
||||
if value is None:
|
||||
return None
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
|
||||
class WorkflowRunUsageResponse(BaseModel):
|
||||
id: int
|
||||
workflow_id: int
|
||||
|
|
@ -97,7 +146,7 @@ class DailyUsageBreakdownResponse(BaseModel):
|
|||
|
||||
@router.get("/usage/current-period", response_model=CurrentUsageResponse)
|
||||
async def get_current_period_usage(user: UserModel = Depends(get_user)):
|
||||
"""Get current billing period usage for the user's organization."""
|
||||
"""Get current reporting-period usage for the user's organization."""
|
||||
if not user.selected_organization_id:
|
||||
raise HTTPException(status_code=400, detail="No organization selected")
|
||||
|
||||
|
|
@ -142,6 +191,202 @@ async def get_mps_credits(user: UserModel = Depends(get_user)):
|
|||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
async def _get_mps_billing_account_status(
    user: UserModel, organization_id: int
) -> Optional[dict]:
    """Fetch the billing account status for an organization.

    Args:
        user: Requesting user; ``provider_id`` is sent, stringified, as
            ``created_by``.
        organization_id: Organization whose billing account is queried.

    Returns:
        The result of ``mps_service_key_client.get_billing_account_status``,
        which may be ``None``.
    """
    created_by = str(user.provider_id)
    status = await mps_service_key_client.get_billing_account_status(
        organization_id=organization_id,
        created_by=created_by,
    )
    return status
|
||||
|
||||
|
||||
def _is_mps_billing_v2(account: Optional[dict]) -> bool:
|
||||
return bool(account and account.get("billing_mode") == "v2")
|
||||
|
||||
|
||||
async def _legacy_mps_credits_response(user: UserModel) -> MPSBillingCreditsResponse:
    """Build the legacy (pre-v2) credits response for the user.

    In OSS deployments usage is looked up by the user's ``provider_id``;
    otherwise it is looked up by the selected organization.

    Args:
        user: Requesting user.

    Returns:
        A response with ``billing_version="legacy"`` whose ``total_quota``
        is the sum of used and remaining credits.

    Raises:
        HTTPException: 400 outside OSS mode when no organization is selected.
    """
    if DEPLOYMENT_MODE == "oss":
        usage = await mps_service_key_client.get_usage_by_created_by(
            str(user.provider_id)
        )
    else:
        if not user.selected_organization_id:
            raise HTTPException(status_code=400, detail="No organization selected")
        usage = await mps_service_key_client.get_usage_by_organization(
            user.selected_organization_id
        )

    # ``or 0.0`` also covers a key that is present with a None value, which
    # a .get() default alone would pass straight into float().
    total_used = float(usage.get("total_credits_used") or 0.0)
    total_remaining = float(usage.get("remaining_credits") or 0.0)
    return MPSBillingCreditsResponse(
        billing_version="legacy",
        total_credits_used=total_used,
        remaining_credits=total_remaining,
        total_quota=total_used + total_remaining,
    )
|
||||
|
||||
|
||||
@router.get("/billing/credits", response_model=MPSBillingCreditsResponse)
async def get_billing_credits(
    page: int = Query(1, ge=1),
    limit: int = Query(50, ge=1, le=100),
    user: UserModel = Depends(get_user),
):
    """Return legacy MPS credits or paginated v2 billing ledger details for the org.

    The legacy summary is returned when running in OSS mode, when the user has
    no selected organization, or when the organization's billing account is
    not on billing v2. Otherwise one page of the credit ledger is fetched and
    returned together with the account balance.

    Args:
        page: 1-based ledger page to request.
        limit: Requested page size (1-100).
        user: The authenticated user.

    Raises:
        HTTPException: re-raised unchanged if raised inside the handler; any
            other exception is logged and converted to a 500 whose detail is
            the exception text.
    """
    try:
        if DEPLOYMENT_MODE == "oss" or not user.selected_organization_id:
            return await _legacy_mps_credits_response(user)

        organization_id = user.selected_organization_id
        account_status = await _get_mps_billing_account_status(user, organization_id)
        if not _is_mps_billing_v2(account_status):
            return await _legacy_mps_credits_response(user)

        ledger = await mps_service_key_client.get_credit_ledger(
            organization_id=organization_id,
            page=page,
            limit=limit,
            created_by=str(user.provider_id),
        )
        account = ledger.get("account") or {}
        ledger_entries = ledger.get("ledger_entries") or []
        # Pagination values reported by the ledger win; missing/falsy values
        # fall back to figures derived from this page and the request.
        total_count = int(ledger.get("total_count") or len(ledger_entries))
        response_limit = int(ledger.get("limit") or limit)
        total_pages = int(
            ledger.get("total_pages")
            or ((total_count + response_limit - 1) // response_limit)
        )
        workflow_ids_by_run_id: dict[int, int] = {}
        # Distinct workflow run ids referenced by the entries on this page.
        workflow_run_ids = {
            workflow_run_id
            for entry in ledger_entries
            if (workflow_run_id := _optional_int(entry.get("workflow_run_id")))
            is not None
        }
        for workflow_run_id in workflow_run_ids:
            workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
            # Only map runs whose workflow belongs to the caller's organization;
            # other run ids get no workflow_id in the response.
            if (
                workflow_run
                and workflow_run.workflow
                and workflow_run.workflow.organization_id == organization_id
            ):
                workflow_ids_by_run_id[workflow_run_id] = workflow_run.workflow_id

        balance = float(account.get("cached_balance_credits") or 0.0)
        # Sum of the magnitudes of negative deltas on this page only ...
        total_debits = sum(
            abs(float(entry.get("credits_delta") or 0.0))
            for entry in ledger_entries
            if float(entry.get("credits_delta") or 0.0) < 0
        )
        # ... replaced by the ledger-reported total when one is provided.
        if ledger.get("total_debits_credits") is not None:
            total_debits = float(ledger["total_debits_credits"])

        return MPSBillingCreditsResponse(
            billing_version="v2",
            total_credits_used=total_debits,
            remaining_credits=balance,
            total_quota=balance + total_debits,
            # NOTE(review): "id", "organization_id" and "billing_mode" are read
            # with [] and raise KeyError (surfaced as a 500 below) if the
            # ledger response carries no account — confirm upstream always
            # includes it.
            account=MPSBillingAccountResponse(
                id=int(account["id"]),
                organization_id=int(account["organization_id"]),
                billing_mode=str(account["billing_mode"]),
                cached_balance_credits=balance,
                currency=str(account.get("currency") or "USD"),
            ),
            ledger_entries=[
                MPSCreditLedgerEntryResponse(
                    id=int(entry["id"]),
                    entry_type=str(entry["entry_type"]),
                    origin=entry.get("origin"),
                    credits_delta=float(entry.get("credits_delta") or 0.0),
                    balance_after=float(entry.get("balance_after") or 0.0),
                    amount_minor=entry.get("amount_minor"),
                    amount_currency=entry.get("amount_currency"),
                    payment_order_id=entry.get("payment_order_id"),
                    metric_code=entry.get("metric_code"),
                    correlation_id=entry.get("correlation_id"),
                    aggregation_key=entry.get("aggregation_key"),
                    usage_event_id=_optional_int(entry.get("usage_event_id")),
                    workflow_run_id=_optional_int(entry.get("workflow_run_id")),
                    workflow_id=workflow_ids_by_run_id.get(
                        _optional_int(entry.get("workflow_run_id"))
                    )
                    if entry.get("workflow_run_id") is not None
                    else None,
                    billable_quantity=float(entry["billable_quantity"])
                    if entry.get("billable_quantity") is not None
                    else None,
                    quantity_unit=entry.get("quantity_unit"),
                    metadata=entry.get("metadata") or {},
                    created_at=str(entry["created_at"]),
                )
                for entry in ledger_entries
            ],
            total_count=total_count,
            page=int(ledger.get("page") or page),
            limit=response_limit,
            total_pages=total_pages,
        )
    except HTTPException:
        # Deliberate HTTP errors (e.g. the 400 from the legacy helper) pass through.
        raise
    except Exception as exc:
        logger.error(f"Failed to fetch billing credits: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.post(
    "/usage/mps-credits/purchase-url",
    response_model=MPSCreditPurchaseUrlResponse,
)
async def create_mps_credit_purchase_url(
    user: UserModel = Depends(get_user_with_selected_organization),
):
    """Create a checkout URL for organizations using Dograh-managed MPS v2.

    Returns:
        ``MPSCreditPurchaseUrlResponse`` carrying the checkout URL reported by MPS.

    Raises:
        HTTPException: 404 in OSS mode; 400 when no organization is selected;
            403 when the organization is not on billing v2; 502 when MPS fails
            to create the session or returns no ``checkout_url``.
    """
    if DEPLOYMENT_MODE == "oss":
        raise HTTPException(
            status_code=404,
            detail="Credit purchases are not available in OSS mode",
        )

    organization_id = user.selected_organization_id
    # Explicit check instead of ``assert``: asserts are stripped under
    # ``python -O`` and would otherwise surface as an unhandled AssertionError.
    if organization_id is None:
        raise HTTPException(status_code=400, detail="No organization selected")

    account_status = await _get_mps_billing_account_status(user, organization_id)
    if not _is_mps_billing_v2(account_status):
        raise HTTPException(
            status_code=403,
            detail=(
                "Credit purchases are available only for organizations using billing v2"
            ),
        )

    try:
        session = await mps_service_key_client.create_credit_purchase_url(
            organization_id=organization_id,
            created_by=str(user.provider_id),
            # The return URL sent to MPS is the UI's billing page.
            return_url=f"{UI_APP_URL.rstrip('/')}/billing",
            billing_details={
                "source": "dograh_billing",
                "dograh_user_id": str(user.id),
                "dograh_provider_id": str(user.provider_id),
            },
        )
    except Exception as exc:
        # Upstream failure: log the cause, return a generic 502 to the client.
        logger.error(f"Failed to create MPS credit purchase URL: {exc}")
        raise HTTPException(
            status_code=502,
            detail="Failed to create credit purchase URL",
        )

    checkout_url = session.get("checkout_url")
    if not checkout_url:
        logger.error(f"MPS checkout session response missing checkout_url: {session}")
        raise HTTPException(
            status_code=502,
            detail="MPS checkout session response missing checkout_url",
        )
    return MPSCreditPurchaseUrlResponse(checkout_url=checkout_url)
|
||||
|
||||
|
||||
FILTERS_DESCRIPTION = """\
|
||||
JSON-encoded array of filter objects. Each object has the shape:
|
||||
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from pydantic import BaseModel
|
|||
|
||||
from api.db import db_client
|
||||
from api.enums import TriggerState, WorkflowStatus
|
||||
from api.services.quota_service import check_dograh_quota_by_user_id
|
||||
from api.services.quota_service import authorize_workflow_run_start
|
||||
from api.services.telephony.factory import (
|
||||
get_default_telephony_provider,
|
||||
get_telephony_provider_by_id,
|
||||
|
|
@ -179,14 +179,6 @@ async def _execute_resolved_target(
|
|||
"""Shared execution path once the target workflow has been resolved."""
|
||||
execution_user_id = _get_execution_user_id(target.workflow)
|
||||
|
||||
# Check Dograh quota using the workflow owner's config and model overrides.
|
||||
quota_result = await check_dograh_quota_by_user_id(
|
||||
execution_user_id,
|
||||
workflow_id=target.workflow.id,
|
||||
)
|
||||
if not quota_result.has_quota:
|
||||
raise HTTPException(status_code=402, detail=quota_result.error_message)
|
||||
|
||||
# Get telephony provider — either the caller-specified config (validated
|
||||
# against the workflow's org) or the org's default config.
|
||||
if request.telephony_configuration_id is not None:
|
||||
|
|
@ -268,6 +260,15 @@ async def _execute_resolved_target(
|
|||
f"to phone number {request.phone_number}"
|
||||
)
|
||||
|
||||
# Check Dograh quota after the run exists so hosted v2 can mint and store
|
||||
# the MPS correlation id before the provider starts the call.
|
||||
quota_result = await authorize_workflow_run_start(
|
||||
workflow_id=target.workflow.id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
)
|
||||
if not quota_result.has_quota:
|
||||
raise HTTPException(status_code=402, detail=quota_result.error_message)
|
||||
|
||||
# 9. Construct webhook URL for telephony provider callback
|
||||
backend_endpoint, _ = await get_backend_endpoints()
|
||||
webhook_endpoint = provider.WEBHOOK_ENDPOINT
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ They handle CORS, domain validation, and session management for embedded workflo
|
|||
import secrets
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Optional
|
||||
from urllib.parse import urlsplit
|
||||
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
|
|
@ -16,6 +17,8 @@ from fastapi import (
|
|||
)
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
from starlette.datastructures import Headers
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
from api.db import db_client
|
||||
from api.enums import WorkflowRunMode
|
||||
|
|
@ -27,6 +30,9 @@ from api.routes.turn_credentials import (
|
|||
|
||||
router = APIRouter(prefix="/public/embed")
|
||||
|
||||
EMBED_CORS_ALLOW_HEADERS = "Content-Type, Origin"
|
||||
EMBED_CORS_MAX_AGE = "86400"
|
||||
|
||||
|
||||
class InitEmbedRequest(BaseModel):
|
||||
"""Request model for initializing an embed session"""
|
||||
|
|
@ -70,11 +76,9 @@ def validate_origin(origin: str, allowed_domains: list) -> bool:
|
|||
# If no domains specified, allow all origins
|
||||
return True
|
||||
|
||||
# Extract domain from origin (remove protocol)
|
||||
if "://" in origin:
|
||||
domain = origin.split("://")[1].split("/")[0].split(":")[0]
|
||||
else:
|
||||
domain = origin
|
||||
domain, origin_port = _parse_origin_host_port(origin)
|
||||
if not domain:
|
||||
return False
|
||||
|
||||
# Normalize domain for www matching
|
||||
def normalize_www(d: str) -> tuple[str, str]:
|
||||
|
|
@ -87,16 +91,23 @@ def validate_origin(origin: str, allowed_domains: list) -> bool:
|
|||
domain_variants = normalize_www(domain)
|
||||
|
||||
for allowed in allowed_domains:
|
||||
allowed = str(allowed).strip().lower()
|
||||
if allowed == "*":
|
||||
return True
|
||||
elif allowed.startswith("*."):
|
||||
allowed_domain, allowed_port = _parse_origin_host_port(allowed)
|
||||
if not allowed_domain:
|
||||
continue
|
||||
if allowed_port is not None and allowed_port != origin_port:
|
||||
continue
|
||||
|
||||
if allowed_domain.startswith("*."):
|
||||
# Wildcard subdomain matching
|
||||
base_domain = allowed[2:]
|
||||
base_domain = allowed_domain[2:]
|
||||
if domain == base_domain or domain.endswith("." + base_domain):
|
||||
return True
|
||||
else:
|
||||
# Check both www and non-www versions
|
||||
allowed_variants = normalize_www(allowed)
|
||||
allowed_variants = normalize_www(allowed_domain)
|
||||
# If any variant of domain matches any variant of allowed, it's valid
|
||||
if any(
|
||||
dv in allowed_variants or av in domain_variants
|
||||
|
|
@ -108,6 +119,24 @@ def validate_origin(origin: str, allowed_domains: list) -> bool:
|
|||
return False
|
||||
|
||||
|
||||
def _parse_origin_host_port(value: str) -> tuple[str, str | None]:
|
||||
candidate = value.strip().lower()
|
||||
if not candidate:
|
||||
return "", None
|
||||
|
||||
if "://" not in candidate and not candidate.startswith("//"):
|
||||
candidate = f"//{candidate}"
|
||||
|
||||
parsed = urlsplit(candidate)
|
||||
try:
|
||||
parsed_port = parsed.port
|
||||
except ValueError:
|
||||
parsed_port = None
|
||||
|
||||
port = str(parsed_port) if parsed_port is not None else None
|
||||
return (parsed.hostname or "").rstrip("."), port
|
||||
|
||||
|
||||
def generate_session_token() -> str:
    """Return a fresh embed session token.

    The token is the literal prefix ``emb_session_`` followed by
    ``secrets.token_urlsafe(32)``.
    """
    random_part = secrets.token_urlsafe(32)
    return "emb_session_" + random_part
|
||||
|
|
@ -121,8 +150,120 @@ def get_request_origin(request: Request) -> str:
|
|||
return origin
|
||||
|
||||
|
||||
def _cors_response(origin: str, methods: str) -> Response:
    """Build a body-less response carrying the embed CORS headers.

    ``origin`` is echoed as the allowed origin and ``methods`` as the allowed
    methods. The allowed request headers and max-age come from the module
    constants, and ``Vary: Origin`` is always set.
    """
    cors_headers = {
        "Access-Control-Allow-Origin": origin,
        "Access-Control-Allow-Methods": methods,
        "Access-Control-Allow-Headers": EMBED_CORS_ALLOW_HEADERS,
        "Access-Control-Max-Age": EMBED_CORS_MAX_AGE,
        "Vary": "Origin",
    }
    return Response(headers=cors_headers)
|
||||
|
||||
|
||||
def _allow_embed_origin(response: Response, origin: str) -> None:
    """Mark *origin* as CORS-allowed on *response*, in place.

    Sets ``Access-Control-Allow-Origin`` and makes sure ``Vary`` lists
    ``Origin``: it is created when absent, appended to when present without
    it (case-insensitive check), and left alone otherwise.
    """
    headers = response.headers
    headers["Access-Control-Allow-Origin"] = origin

    existing_vary = headers.get("Vary")
    if existing_vary:
        listed = [part.strip().lower() for part in existing_vary.split(",")]
        if "origin" not in listed:
            headers["Vary"] = f"{existing_vary}, Origin"
    else:
        headers["Vary"] = "Origin"
|
||||
|
||||
|
||||
async def _config_preflight_response(token: str, origin: str) -> Response:
    """Answer a preflight for the embed config endpoint.

    Returns 403 when the embed token is unknown or inactive, or when *origin*
    fails ``validate_origin`` against the token's allowed domains. Otherwise
    returns the CORS response allowing ``GET, OPTIONS``.
    """
    embed_token = await db_client.get_embed_token_by_token(token)
    if not (embed_token and embed_token.is_active):
        return Response(status_code=403)

    allowed_domains = embed_token.allowed_domains or []
    if validate_origin(origin, allowed_domains):
        return _cors_response(origin, "GET, OPTIONS")
    return Response(status_code=403)
|
||||
|
||||
|
||||
async def _turn_credentials_preflight_response(
    session_token: str, origin: str
) -> Response:
    """Answer a preflight for the TURN credentials endpoint.

    Returns 403 when the embed session is unknown or already expired, when its
    embed token cannot be loaded, or when *origin* fails ``validate_origin``
    against the token's allowed domains. Otherwise returns the CORS response
    allowing ``GET, OPTIONS``.
    """
    embed_session = await db_client.get_embed_session_by_token(session_token)
    if not embed_session:
        return Response(status_code=403)

    expiry = embed_session.expires_at
    if expiry and expiry < datetime.now(UTC):
        return Response(status_code=403)

    embed_token = await db_client.get_embed_token_by_id(embed_session.embed_token_id)
    origin_allowed = bool(embed_token) and validate_origin(
        origin, embed_token.allowed_domains or []
    )
    if not origin_allowed:
        return Response(status_code=403)

    return _cors_response(origin, "GET, OPTIONS")
|
||||
|
||||
|
||||
async def build_public_embed_preflight_response(
    path: str, origin: str, requested_method: str, api_prefix: str = "/api/v1"
) -> Response | None:
    """Handle embed preflights before global CORSMiddleware rejects external sites.

    Routes by request path under ``<api_prefix>/public/embed``:

    * ``/init`` only accepts a requested method of POST and is answered
      without token validation;
    * ``/config/<token>`` and ``/turn-credentials/<session_token>`` only
      accept GET and are validated by their dedicated helpers.

    Returns 405 for a path that matches with the wrong requested method, and
    ``None`` when the path is not one of the public embed endpoints.
    """
    embed_base = f"{api_prefix.rstrip('/')}/public/embed"
    method = requested_method.upper()

    if path == f"{embed_base}/init":
        if method == "POST":
            return _cors_response(origin, "POST, OPTIONS")
        return Response(status_code=405)

    # Token-gated endpoints: (path prefix, validator taking (token, origin)).
    token_routes = (
        (f"{embed_base}/config/", _config_preflight_response),
        (f"{embed_base}/turn-credentials/", _turn_credentials_preflight_response),
    )
    for route_prefix, validator in token_routes:
        if not path.startswith(route_prefix):
            continue
        if method != "GET":
            return Response(status_code=405)
        # The token is the first path segment after the prefix.
        credential = path[len(route_prefix) :].split("/", 1)[0]
        return await validator(credential, origin)

    return None
|
||||
|
||||
|
||||
class PublicEmbedCORSMiddleware:
    """Allow token-gated embed CORS before global SaaS CORS rejects preflights.

    ASGI middleware: HTTP OPTIONS requests that carry both ``Origin`` and
    ``Access-Control-Request-Method`` are offered to
    ``build_public_embed_preflight_response``; if it produces a response, that
    response is sent. Everything else goes to the wrapped app.
    """

    def __init__(self, app: ASGIApp, api_prefix: str = "/api/v1"):
        self.app = app
        self.api_prefix = api_prefix

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        is_http_options = scope["type"] == "http" and scope.get("method") == "OPTIONS"
        if is_http_options:
            request_headers = Headers(scope=scope)
            origin = request_headers.get("origin")
            requested_method = request_headers.get("access-control-request-method")

            if origin and requested_method:
                preflight = await build_public_embed_preflight_response(
                    scope.get("path", ""), origin, requested_method, self.api_prefix
                )
                if preflight is not None:
                    await preflight(scope, receive, send)
                    return

        # Not an embed preflight: pass through to the wrapped app.
        await self.app(scope, receive, send)
|
||||
|
||||
|
||||
@router.post("/init", response_model=InitEmbedResponse)
|
||||
async def initialize_embed_session(request: Request, init_request: InitEmbedRequest):
|
||||
async def initialize_embed_session(
|
||||
request: Request, init_request: InitEmbedRequest, response: Response
|
||||
):
|
||||
"""Initialize an embed session with token validation and domain checking.
|
||||
|
||||
This endpoint:
|
||||
|
|
@ -158,6 +299,9 @@ async def initialize_embed_session(request: Request, init_request: InitEmbedRequ
|
|||
)
|
||||
raise HTTPException(status_code=403, detail=f"Domain not allowed: {origin}")
|
||||
|
||||
if origin:
|
||||
_allow_embed_origin(response, origin)
|
||||
|
||||
# Create workflow run
|
||||
try:
|
||||
workflow_run = await db_client.create_workflow_run(
|
||||
|
|
@ -204,8 +348,19 @@ async def initialize_embed_session(request: Request, init_request: InitEmbedRequ
|
|||
)
|
||||
|
||||
|
||||
@router.options("/config/{token}")
async def options_embed_config(token: str, request: Request):
    """Fallback OPTIONS handler for the embed config endpoint.

    Browser preflights carry Access-Control-Request-Method and are answered by
    PublicEmbedCORSMiddleware ahead of the global CORS layer. OPTIONS requests
    that reach this route go through the same token/origin validation, with a
    missing Origin header treated as an empty string.
    """
    origin = request.headers.get("origin", "")
    return await _config_preflight_response(token, origin)
|
||||
|
||||
|
||||
@router.get("/config/{token}", response_model=EmbedConfigResponse)
|
||||
async def get_embed_config(token: str, request: Request):
|
||||
async def get_embed_config(token: str, request: Request, response: Response):
|
||||
"""Get embed configuration without creating a session.
|
||||
|
||||
This endpoint is used to fetch widget configuration for display purposes
|
||||
|
|
@ -226,6 +381,11 @@ async def get_embed_config(token: str, request: Request):
|
|||
if not validate_origin(origin, embed_token.allowed_domains or []):
|
||||
raise HTTPException(status_code=403, detail=f"Domain not allowed: {origin}")
|
||||
|
||||
# Set CORS header explicitly; the global CORSMiddleware covers only
|
||||
# first-party origins; this endpoint is fetched by external embed sites.
|
||||
if origin:
|
||||
_allow_embed_origin(response, origin)
|
||||
|
||||
# Extract settings with defaults
|
||||
settings = embed_token.settings or {}
|
||||
|
||||
|
|
@ -243,24 +403,20 @@ async def get_embed_config(token: str, request: Request):
|
|||
|
||||
@router.options("/init")
|
||||
async def options_init(request: Request):
|
||||
"""Handle CORS preflight for init endpoint"""
|
||||
"""Fallback OPTIONS handler for init endpoint."""
|
||||
# Browser preflights are handled by PublicEmbedCORSMiddleware before global CORS.
|
||||
# For init endpoint, we need to check the token in the request body
|
||||
# But OPTIONS requests don't have body, so we'll be permissive
|
||||
# The actual validation happens in the POST request
|
||||
origin = request.headers.get("origin", "*")
|
||||
|
||||
return Response(
|
||||
headers={
|
||||
"Access-Control-Allow-Origin": origin,
|
||||
"Access-Control-Allow-Methods": "POST, OPTIONS",
|
||||
"Access-Control-Allow-Headers": "Content-Type, Origin",
|
||||
"Access-Control-Max-Age": "86400",
|
||||
}
|
||||
)
|
||||
return _cors_response(origin, "POST, OPTIONS")
|
||||
|
||||
|
||||
@router.get("/turn-credentials/{session_token}", response_model=TurnCredentialsResponse)
|
||||
async def get_public_turn_credentials(session_token: str, request: Request):
|
||||
async def get_public_turn_credentials(
|
||||
session_token: str, request: Request, response: Response
|
||||
):
|
||||
"""Get TURN credentials for an embed session.
|
||||
|
||||
This endpoint allows embedded widgets to obtain TURN server credentials
|
||||
|
|
@ -295,6 +451,9 @@ async def get_public_turn_credentials(session_token: str, request: Request):
|
|||
)
|
||||
raise HTTPException(status_code=403, detail=f"Domain not allowed: {origin}")
|
||||
|
||||
if origin:
|
||||
_allow_embed_origin(response, origin)
|
||||
|
||||
# Check if TURN is configured
|
||||
if not TURN_SECRET:
|
||||
raise HTTPException(
|
||||
|
|
@ -316,63 +475,8 @@ async def get_public_turn_credentials(session_token: str, request: Request):
|
|||
|
||||
@router.options("/turn-credentials/{session_token}")
|
||||
async def options_turn_credentials(request: Request, session_token: str):
|
||||
"""Handle CORS preflight for TURN credentials endpoint"""
|
||||
origin = request.headers.get("origin", "*")
|
||||
|
||||
# Try to validate the session token and get allowed domains
|
||||
allowed_origin = origin
|
||||
try:
|
||||
embed_session = await db_client.get_embed_session_by_token(session_token)
|
||||
if embed_session:
|
||||
embed_token = await db_client.get_embed_token_by_id(
|
||||
embed_session.embed_token_id
|
||||
)
|
||||
if embed_token:
|
||||
# Check if origin is in allowed domains (empty means allow all)
|
||||
if validate_origin(origin, embed_token.allowed_domains or []):
|
||||
allowed_origin = origin
|
||||
else:
|
||||
allowed_origin = ""
|
||||
except Exception:
|
||||
# On error, be permissive for OPTIONS
|
||||
pass
|
||||
|
||||
return Response(
|
||||
headers={
|
||||
"Access-Control-Allow-Origin": allowed_origin,
|
||||
"Access-Control-Allow-Methods": "GET, OPTIONS",
|
||||
"Access-Control-Allow-Headers": "Content-Type",
|
||||
"Access-Control-Max-Age": "86400",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@router.options("/config/{token}")
|
||||
async def options_config(request: Request, token: str):
|
||||
"""Handle CORS preflight for config endpoint"""
|
||||
# Get origin header
|
||||
origin = request.headers.get("origin", "*")
|
||||
|
||||
# Try to validate the token and get allowed domains
|
||||
allowed_origin = origin
|
||||
try:
|
||||
embed_token = await db_client.get_embed_token_by_token(token)
|
||||
if embed_token and embed_token.is_active:
|
||||
# Check if origin is in allowed domains
|
||||
if validate_origin(origin, embed_token.allowed_domains or []):
|
||||
allowed_origin = origin
|
||||
else:
|
||||
# If not allowed, don't include the origin
|
||||
allowed_origin = ""
|
||||
except Exception:
|
||||
# On error, be permissive for OPTIONS
|
||||
pass
|
||||
|
||||
return Response(
|
||||
headers={
|
||||
"Access-Control-Allow-Origin": allowed_origin,
|
||||
"Access-Control-Allow-Methods": "GET, OPTIONS",
|
||||
"Access-Control-Allow-Headers": "Content-Type",
|
||||
"Access-Control-Max-Age": "86400",
|
||||
}
|
||||
"""Fallback OPTIONS handler for TURN credentials endpoint."""
|
||||
# Browser preflights are handled by PublicEmbedCORSMiddleware before global CORS.
|
||||
return await _turn_credentials_preflight_response(
|
||||
session_token, request.headers.get("origin", "")
|
||||
)
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ from api.enums import CallType, WorkflowRunState
|
|||
from api.errors.telephony_errors import TelephonyError
|
||||
from api.sdk_expose import sdk_expose
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.quota_service import check_dograh_quota_by_user_id
|
||||
from api.services.quota_service import authorize_workflow_run_start
|
||||
from api.services.telephony.call_transfer_manager import get_call_transfer_manager
|
||||
from api.services.telephony.factory import (
|
||||
get_all_telephony_providers,
|
||||
|
|
@ -53,7 +53,7 @@ class InitiateCallRequest(BaseModel):
|
|||
workflow_run_id: int | None = None
|
||||
phone_number: str | None = None
|
||||
# Optional explicit telephony config to use for the test call. If omitted,
|
||||
# falls back to the user's per-user default (when set), then the org default.
|
||||
# falls back to the org default.
|
||||
telephony_configuration_id: int | None = None
|
||||
# Optional caller-ID phone number to dial out from. Must belong to the
|
||||
# resolved telephony configuration; otherwise the provider picks one.
|
||||
|
|
@ -82,7 +82,12 @@ async def initiate_call(
|
|||
"""Initiate a call using the configured telephony provider from web browser. This is
|
||||
supposed to be a test call method for the draft version of the agent."""
|
||||
|
||||
user_configuration = await db_client.get_user_configurations(user.id)
|
||||
from api.services.organization_preferences import get_organization_preferences
|
||||
|
||||
preferences = await get_organization_preferences(
|
||||
user.selected_organization_id,
|
||||
db=db_client,
|
||||
)
|
||||
|
||||
# Resolve which telephony config to use: explicit request value, otherwise
|
||||
# the org's default outbound config.
|
||||
|
|
@ -116,13 +121,12 @@ async def initiate_call(
|
|||
detail="telephony_not_configured",
|
||||
)
|
||||
|
||||
phone_number = request.phone_number or user_configuration.test_phone_number
|
||||
phone_number = request.phone_number or preferences.test_phone_number
|
||||
|
||||
if not phone_number:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Phone number must be provided in request or set in user "
|
||||
"configuration",
|
||||
detail="Phone number must be provided in request or set in organization preferences",
|
||||
)
|
||||
|
||||
workflow = await db_client.get_workflow(
|
||||
|
|
@ -132,14 +136,6 @@ async def initiate_call(
|
|||
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||
execution_user_id = _get_execution_user_id(workflow)
|
||||
|
||||
# Check Dograh quota before initiating the call (apply per-workflow
|
||||
# model_overrides so the keys we will actually use are the ones checked).
|
||||
quota_result = await check_dograh_quota_by_user_id(
|
||||
execution_user_id, workflow_id=workflow.id
|
||||
)
|
||||
if not quota_result.has_quota:
|
||||
raise HTTPException(status_code=402, detail=quota_result.error_message)
|
||||
|
||||
# Determine the workflow run mode based on provider type
|
||||
workflow_run_mode = provider.PROVIDER_NAME
|
||||
|
||||
|
|
@ -182,6 +178,16 @@ async def initiate_call(
|
|||
)
|
||||
workflow_run_name = workflow_run.name
|
||||
|
||||
# Check Dograh quota after the run exists so hosted v2 can mint and store
|
||||
# the MPS correlation id before initiating the call.
|
||||
quota_result = await authorize_workflow_run_start(
|
||||
workflow_id=workflow.id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
actor_user=user,
|
||||
)
|
||||
if not quota_result.has_quota:
|
||||
raise HTTPException(status_code=402, detail=quota_result.error_message)
|
||||
|
||||
# Construct webhook URL based on provider type
|
||||
backend_endpoint, _ = await get_backend_endpoints()
|
||||
|
||||
|
|
@ -735,19 +741,8 @@ async def handle_inbound_run(request: Request):
|
|||
TelephonyError.SIGNATURE_VALIDATION_FAILED
|
||||
)
|
||||
|
||||
# 4. Quota check (use the workflow's model_overrides if set).
|
||||
quota_result = await check_dograh_quota_by_user_id(
|
||||
user_id, workflow_id=workflow_id
|
||||
)
|
||||
if not quota_result.has_quota:
|
||||
logger.warning(
|
||||
f"User {user_id} has exceeded quota: {quota_result.error_message}"
|
||||
)
|
||||
return provider_class.generate_validation_error_response(
|
||||
TelephonyError.QUOTA_EXCEEDED
|
||||
)
|
||||
|
||||
# 5. Create workflow run + return provider-shaped response.
|
||||
# 5. Create workflow run + authorize quota before returning provider
|
||||
# stream instructions.
|
||||
workflow_run_id = await _create_inbound_workflow_run(
|
||||
workflow_id,
|
||||
user_id,
|
||||
|
|
@ -756,6 +751,17 @@ async def handle_inbound_run(request: Request):
|
|||
telephony_configuration_id=telephony_configuration_id,
|
||||
from_phone_number_id=phone_row.id,
|
||||
)
|
||||
quota_result = await authorize_workflow_run_start(
|
||||
workflow_id=workflow_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
)
|
||||
if not quota_result.has_quota:
|
||||
logger.warning(
|
||||
f"User {user_id} has exceeded quota: {quota_result.error_message}"
|
||||
)
|
||||
return provider_class.generate_validation_error_response(
|
||||
TelephonyError.QUOTA_EXCEEDED
|
||||
)
|
||||
|
||||
backend_endpoint, wss_backend_endpoint = await get_backend_endpoints()
|
||||
websocket_url = (
|
||||
|
|
@ -870,20 +876,8 @@ async def handle_inbound_telephony(
|
|||
logger.error(f"Request validation failed: {error_type}")
|
||||
return provider_class.generate_validation_error_response(error_type)
|
||||
|
||||
# Check quota before processing (apply per-workflow model_overrides).
|
||||
# Create workflow run.
|
||||
user_id = workflow_context["user_id"]
|
||||
quota_result = await check_dograh_quota_by_user_id(
|
||||
user_id, workflow_id=workflow_id
|
||||
)
|
||||
if not quota_result.has_quota:
|
||||
logger.warning(
|
||||
f"User {user_id} has exceeded quota for inbound calls: {quota_result.error_message}"
|
||||
)
|
||||
return provider_class.generate_validation_error_response(
|
||||
TelephonyError.QUOTA_EXCEEDED
|
||||
)
|
||||
|
||||
# Create workflow run
|
||||
workflow_run_id = await _create_inbound_workflow_run(
|
||||
workflow_id,
|
||||
workflow_context["user_id"],
|
||||
|
|
@ -892,6 +886,17 @@ async def handle_inbound_telephony(
|
|||
telephony_configuration_id=workflow_context["telephony_configuration_id"],
|
||||
from_phone_number_id=workflow_context.get("from_phone_number_id"),
|
||||
)
|
||||
quota_result = await authorize_workflow_run_start(
|
||||
workflow_id=workflow_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
)
|
||||
if not quota_result.has_quota:
|
||||
logger.warning(
|
||||
f"User {user_id} has exceeded quota for inbound calls: {quota_result.error_message}"
|
||||
)
|
||||
return provider_class.generate_validation_error_response(
|
||||
TelephonyError.QUOTA_EXCEEDED
|
||||
)
|
||||
|
||||
# Generate response URLs
|
||||
backend_endpoint, wss_backend_endpoint = await get_backend_endpoints()
|
||||
|
|
|
|||
|
|
@ -10,6 +10,9 @@ from api.db.models import (
|
|||
UserModel,
|
||||
)
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
get_resolved_ai_model_configuration,
|
||||
)
|
||||
from api.services.configuration.check_validity import (
|
||||
APIKeyStatusResponse,
|
||||
UserConfigurationValidator,
|
||||
|
|
@ -19,6 +22,10 @@ from api.services.configuration.masking import check_for_masked_keys, mask_user_
|
|||
from api.services.configuration.merge import merge_user_configurations
|
||||
from api.services.configuration.registry import REGISTRY, ServiceType
|
||||
from api.services.mps_service_key_client import mps_service_key_client
|
||||
from api.services.organization_preferences import (
|
||||
get_organization_preferences,
|
||||
upsert_organization_preferences,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/user")
|
||||
|
||||
|
|
@ -94,8 +101,17 @@ class UserConfigurationRequestResponseSchema(BaseModel):
|
|||
async def get_user_configurations(
|
||||
user: UserModel = Depends(get_user),
|
||||
) -> UserConfigurationRequestResponseSchema:
|
||||
user_configurations = await db_client.get_user_configurations(user.id)
|
||||
masked_config = mask_user_config(user_configurations)
|
||||
resolved_config = await get_resolved_ai_model_configuration(
|
||||
user_id=user.id,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
masked_config = mask_user_config(resolved_config.effective)
|
||||
if user.selected_organization_id:
|
||||
preferences = await get_organization_preferences(user.selected_organization_id)
|
||||
if preferences.test_phone_number is not None:
|
||||
masked_config["test_phone_number"] = preferences.test_phone_number
|
||||
if preferences.timezone is not None:
|
||||
masked_config["timezone"] = preferences.timezone
|
||||
|
||||
# Add organization pricing info if available
|
||||
if user.selected_organization_id:
|
||||
|
|
@ -121,34 +137,61 @@ async def update_user_configurations(
|
|||
|
||||
# Remove organization_pricing from incoming dict as it's read-only
|
||||
incoming_dict.pop("organization_pricing", None)
|
||||
preferences_update = {
|
||||
key: incoming_dict.pop(key)
|
||||
for key in ("test_phone_number", "timezone")
|
||||
if key in incoming_dict
|
||||
}
|
||||
|
||||
# Merge via helper
|
||||
try:
|
||||
user_configurations = merge_user_configurations(existing_config, incoming_dict)
|
||||
except ValidationError as e:
|
||||
raise HTTPException(status_code=422, detail=str(e))
|
||||
if incoming_dict:
|
||||
# Merge via helper
|
||||
try:
|
||||
user_configurations = merge_user_configurations(
|
||||
existing_config, incoming_dict
|
||||
)
|
||||
except ValidationError as e:
|
||||
raise HTTPException(status_code=422, detail=str(e))
|
||||
|
||||
try:
|
||||
check_for_masked_keys(user_configurations)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
try:
|
||||
check_for_masked_keys(user_configurations)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
try:
|
||||
validator = UserConfigurationValidator()
|
||||
await validator.validate(
|
||||
user_configurations,
|
||||
organization_id=user.selected_organization_id,
|
||||
created_by=user.provider_id,
|
||||
try:
|
||||
validator = UserConfigurationValidator()
|
||||
await validator.validate(
|
||||
user_configurations,
|
||||
organization_id=user.selected_organization_id,
|
||||
created_by=user.provider_id,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=422, detail=e.args[0])
|
||||
|
||||
user_configurations = await db_client.update_user_configuration(
|
||||
user.id, user_configurations
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=422, detail=e.args[0])
|
||||
else:
|
||||
user_configurations = existing_config
|
||||
|
||||
user_configurations = await db_client.update_user_configuration(
|
||||
user.id, user_configurations
|
||||
)
|
||||
if user.selected_organization_id and preferences_update:
|
||||
preferences = await get_organization_preferences(user.selected_organization_id)
|
||||
if "test_phone_number" in preferences_update:
|
||||
preferences.test_phone_number = preferences_update["test_phone_number"]
|
||||
if "timezone" in preferences_update:
|
||||
preferences.timezone = preferences_update["timezone"]
|
||||
await upsert_organization_preferences(
|
||||
user.selected_organization_id,
|
||||
preferences,
|
||||
)
|
||||
|
||||
# Return masked version of updated config
|
||||
masked_config = mask_user_config(user_configurations)
|
||||
if user.selected_organization_id:
|
||||
preferences = await get_organization_preferences(user.selected_organization_id)
|
||||
if preferences.test_phone_number is not None:
|
||||
masked_config["test_phone_number"] = preferences.test_phone_number
|
||||
if preferences.timezone is not None:
|
||||
masked_config["timezone"] = preferences.timezone
|
||||
|
||||
# Add organization pricing info if available
|
||||
if user.selected_organization_id:
|
||||
|
|
@ -168,7 +211,11 @@ async def validate_user_configurations(
|
|||
validity_ttl_seconds: int = Query(default=60, ge=0, le=86400),
|
||||
user: UserModel = Depends(get_user),
|
||||
) -> APIKeyStatusResponse:
|
||||
configurations = await db_client.get_user_configurations(user.id)
|
||||
resolved_config = await get_resolved_ai_model_configuration(
|
||||
user_id=user.id,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
configurations = resolved_config.effective
|
||||
|
||||
if (
|
||||
configurations.last_validated_at
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ from api.services.pipecat.ws_sender_registry import (
|
|||
register_ws_sender,
|
||||
unregister_ws_sender,
|
||||
)
|
||||
from api.services.quota_service import check_dograh_quota
|
||||
from api.services.quota_service import authorize_workflow_run_start
|
||||
|
||||
router = APIRouter(prefix="/ws")
|
||||
|
||||
|
|
@ -329,7 +329,11 @@ class SignalingManager:
|
|||
|
||||
# Check Dograh quota before initiating the call (apply per-workflow
|
||||
# model_overrides so we evaluate the keys this workflow will use).
|
||||
quota_result = await check_dograh_quota(user, workflow_id=workflow_id)
|
||||
quota_result = await authorize_workflow_run_start(
|
||||
workflow_id=workflow_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
actor_user=user,
|
||||
)
|
||||
if not quota_result.has_quota:
|
||||
# Send error response for quota issues
|
||||
await ws.send_json(
|
||||
|
|
|
|||
|
|
@ -16,9 +16,18 @@ from api.db.agent_trigger_client import TriggerPathConflictError
|
|||
from api.db.models import UserModel
|
||||
from api.db.workflow_template_client import WorkflowTemplateClient
|
||||
from api.enums import CallType, PostHogEvent, StorageBackend
|
||||
from api.schemas.ai_model_configuration import OrganizationAIModelConfigurationV2
|
||||
from api.schemas.workflow import WorkflowRunResponseSchema
|
||||
from api.sdk_expose import sdk_expose
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY,
|
||||
check_for_masked_keys_in_ai_model_configuration_v2,
|
||||
compile_ai_model_configuration_v2,
|
||||
convert_legacy_ai_model_configuration_to_v2,
|
||||
get_resolved_ai_model_configuration,
|
||||
merge_ai_model_configuration_v2_secrets,
|
||||
)
|
||||
from api.services.configuration.check_validity import UserConfigurationValidator
|
||||
from api.services.configuration.masking import (
|
||||
mask_workflow_configurations,
|
||||
|
|
@ -32,12 +41,15 @@ from api.services.configuration.resolve import (
|
|||
)
|
||||
from api.services.mps_service_key_client import mps_service_key_client
|
||||
from api.services.posthog_client import capture_event
|
||||
from api.services.pricing.run_usage_response import format_public_usage_info
|
||||
from api.services.reports import generate_workflow_report_csv
|
||||
from api.services.storage import storage_fs
|
||||
from api.services.workflow.dto import ReactFlowDTO, sanitize_workflow_definition
|
||||
from api.services.workflow.duplicate import duplicate_workflow
|
||||
from api.services.workflow.errors import ItemKind, WorkflowError
|
||||
from api.services.workflow.run_usage_response import (
|
||||
format_public_cost_info,
|
||||
format_public_usage_info,
|
||||
)
|
||||
from api.services.workflow.trigger_paths import (
|
||||
TriggerPathIssue,
|
||||
ensure_trigger_paths,
|
||||
|
|
@ -955,12 +967,74 @@ async def update_workflow(
|
|||
existing_def,
|
||||
)
|
||||
|
||||
# Validate model_overrides: resolve onto global config, then
|
||||
# run the same validator used by the user-configurations endpoint.
|
||||
# Also stamp the current global API key into the override so the override
|
||||
# remains functional if the global config later switches to a different provider.
|
||||
# Validate model overrides. v2 uses a complete workflow-level model
|
||||
# configuration; legacy v1 uses partial service overlays.
|
||||
workflow_configurations = request.workflow_configurations
|
||||
if workflow_configurations and workflow_configurations.get("model_overrides"):
|
||||
if workflow_configurations and workflow_configurations.get(
|
||||
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY
|
||||
):
|
||||
existing_workflow = await db_client.get_workflow(
|
||||
workflow_id, organization_id=user.selected_organization_id
|
||||
)
|
||||
if existing_workflow is None:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Workflow with id {workflow_id} not found"
|
||||
)
|
||||
existing_draft = await db_client.get_draft_version(workflow_id)
|
||||
existing_configs = (
|
||||
existing_draft.workflow_configurations
|
||||
if existing_draft
|
||||
else existing_workflow.released_definition.workflow_configurations
|
||||
)
|
||||
existing_v2_override = (existing_configs or {}).get(
|
||||
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY
|
||||
)
|
||||
try:
|
||||
incoming_v2_override = (
|
||||
OrganizationAIModelConfigurationV2.model_validate(
|
||||
workflow_configurations[
|
||||
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY
|
||||
]
|
||||
)
|
||||
)
|
||||
existing_v2_override_config = (
|
||||
OrganizationAIModelConfigurationV2.model_validate(
|
||||
existing_v2_override
|
||||
)
|
||||
if existing_v2_override
|
||||
else None
|
||||
)
|
||||
v2_override = merge_ai_model_configuration_v2_secrets(
|
||||
incoming_v2_override,
|
||||
existing_v2_override_config,
|
||||
)
|
||||
if existing_v2_override_config is None:
|
||||
resolved_config = await get_resolved_ai_model_configuration(
|
||||
user_id=user.id,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
v2_override = merge_ai_model_configuration_v2_secrets(
|
||||
v2_override,
|
||||
resolved_config.organization_configuration,
|
||||
)
|
||||
check_for_masked_keys_in_ai_model_configuration_v2(v2_override)
|
||||
effective = compile_ai_model_configuration_v2(v2_override)
|
||||
await UserConfigurationValidator().validate(
|
||||
effective,
|
||||
organization_id=user.selected_organization_id,
|
||||
created_by=user.provider_id,
|
||||
)
|
||||
except (ValidationError, ValueError) as e:
|
||||
raise HTTPException(status_code=422, detail=str(e))
|
||||
workflow_configurations = {
|
||||
**workflow_configurations,
|
||||
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY: v2_override.model_dump(
|
||||
mode="json",
|
||||
exclude_none=True,
|
||||
),
|
||||
}
|
||||
workflow_configurations.pop("model_overrides", None)
|
||||
elif workflow_configurations and workflow_configurations.get("model_overrides"):
|
||||
existing_workflow = await db_client.get_workflow(
|
||||
workflow_id, organization_id=user.selected_organization_id
|
||||
)
|
||||
|
|
@ -978,24 +1052,48 @@ async def update_workflow(
|
|||
workflow_configurations,
|
||||
existing_configs,
|
||||
)
|
||||
user_config = await db_client.get_user_configurations(user.id)
|
||||
resolved_config = await get_resolved_ai_model_configuration(
|
||||
user_id=user.id,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
effective_config = resolved_config.effective
|
||||
try:
|
||||
enriched_overrides = enrich_overrides_with_api_keys(
|
||||
workflow_configurations["model_overrides"],
|
||||
user_config,
|
||||
effective_config,
|
||||
)
|
||||
effective = resolve_effective_config(user_config, enriched_overrides)
|
||||
await UserConfigurationValidator().validate(
|
||||
effective,
|
||||
organization_id=user.selected_organization_id,
|
||||
created_by=user.provider_id,
|
||||
effective = resolve_effective_config(
|
||||
effective_config, enriched_overrides
|
||||
)
|
||||
if resolved_config.source == "organization_v2":
|
||||
v2_override = convert_legacy_ai_model_configuration_to_v2(effective)
|
||||
await UserConfigurationValidator().validate(
|
||||
compile_ai_model_configuration_v2(v2_override),
|
||||
organization_id=user.selected_organization_id,
|
||||
created_by=user.provider_id,
|
||||
)
|
||||
else:
|
||||
await UserConfigurationValidator().validate(
|
||||
effective,
|
||||
organization_id=user.selected_organization_id,
|
||||
created_by=user.provider_id,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=422, detail=str(e))
|
||||
workflow_configurations = {
|
||||
**workflow_configurations,
|
||||
"model_overrides": enriched_overrides,
|
||||
}
|
||||
if resolved_config.source == "organization_v2":
|
||||
workflow_configurations = {
|
||||
**workflow_configurations,
|
||||
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY: v2_override.model_dump(
|
||||
mode="json",
|
||||
exclude_none=True,
|
||||
),
|
||||
}
|
||||
workflow_configurations.pop("model_overrides", None)
|
||||
else:
|
||||
workflow_configurations = {
|
||||
**workflow_configurations,
|
||||
"model_overrides": enriched_overrides,
|
||||
}
|
||||
|
||||
# Reject upfront if any new trigger path collides with another
|
||||
# workflow's trigger — keeps the workflow record from
|
||||
|
|
@ -1171,22 +1269,7 @@ async def get_workflow_run(
|
|||
"transcript_public_url": artifact_url(public_access_token, "transcript"),
|
||||
"recording_public_url": artifact_url(public_access_token, "recording"),
|
||||
"public_access_token": public_access_token,
|
||||
"cost_info": {
|
||||
"dograh_token_usage": (
|
||||
run.cost_info.get("dograh_token_usage")
|
||||
if run.cost_info and "dograh_token_usage" in run.cost_info
|
||||
else round(float(run.cost_info.get("total_cost_usd", 0)) * 100, 2)
|
||||
if run.cost_info and "total_cost_usd" in run.cost_info
|
||||
else 0
|
||||
),
|
||||
"call_duration_seconds": int(
|
||||
round(run.cost_info.get("call_duration_seconds"))
|
||||
)
|
||||
if run.cost_info and run.cost_info.get("call_duration_seconds") is not None
|
||||
else None,
|
||||
}
|
||||
if run.cost_info
|
||||
else None,
|
||||
"cost_info": format_public_cost_info(run.cost_info, run.usage_info),
|
||||
"usage_info": format_public_usage_info(run.usage_info),
|
||||
"created_at": run.created_at,
|
||||
"definition_id": run.definition_id,
|
||||
|
|
|
|||
|
|
@ -9,8 +9,8 @@ from pydantic import BaseModel, Field
|
|||
from api.db import db_client
|
||||
from api.db.models import UserModel, WorkflowRunTextSessionModel
|
||||
from api.enums import WorkflowRunMode
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.quota_service import check_dograh_quota
|
||||
from api.services.auth.depends import get_user_with_selected_organization
|
||||
from api.services.quota_service import authorize_workflow_run_start
|
||||
from api.services.workflow.text_chat_session_service import (
|
||||
TextChatPendingTurnLostError,
|
||||
TextChatSessionExecutionError,
|
||||
|
|
@ -96,14 +96,16 @@ def _revision_conflict_detail(e: Any) -> dict[str, Any]:
|
|||
}
|
||||
|
||||
|
||||
def _require_selected_organization_id(user: UserModel) -> int:
|
||||
if user.selected_organization_id is None:
|
||||
raise HTTPException(status_code=403, detail="Organization context is required")
|
||||
return user.selected_organization_id
|
||||
|
||||
|
||||
async def _ensure_text_chat_quota(user: UserModel, workflow_id: int) -> None:
|
||||
quota_result = await check_dograh_quota(user, workflow_id=workflow_id)
|
||||
async def _ensure_text_chat_quota(
|
||||
user: UserModel,
|
||||
workflow_id: int,
|
||||
workflow_run_id: int,
|
||||
) -> None:
|
||||
quota_result = await authorize_workflow_run_start(
|
||||
workflow_id=workflow_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
actor_user=user,
|
||||
)
|
||||
if not quota_result.has_quota:
|
||||
raise HTTPException(status_code=402, detail=quota_result.error_message)
|
||||
|
||||
|
|
@ -114,9 +116,8 @@ async def _load_text_session_or_404(
|
|||
user: UserModel,
|
||||
) -> WorkflowRunTextSessionModel:
|
||||
set_current_run_id(run_id)
|
||||
organization_id = _require_selected_organization_id(user)
|
||||
text_session = await db_client.get_workflow_run_text_session(
|
||||
run_id, organization_id=organization_id
|
||||
run_id, organization_id=user.selected_organization_id
|
||||
)
|
||||
if not text_session or not text_session.workflow_run:
|
||||
raise HTTPException(status_code=404, detail="Text chat session not found")
|
||||
|
|
@ -158,11 +159,8 @@ async def _execute_pending_turn_response(
|
|||
async def create_text_chat_session(
|
||||
workflow_id: int,
|
||||
request: CreateTextChatSessionRequest,
|
||||
user: UserModel = Depends(get_user),
|
||||
user: UserModel = Depends(get_user_with_selected_organization),
|
||||
) -> WorkflowRunTextSessionResponse:
|
||||
organization_id = _require_selected_organization_id(user)
|
||||
await _ensure_text_chat_quota(user, workflow_id)
|
||||
|
||||
session_name = request.name or f"WR-TEXT-{uuid4().hex[:6].upper()}"
|
||||
try:
|
||||
workflow_run = await db_client.create_workflow_run(
|
||||
|
|
@ -172,12 +170,13 @@ async def create_text_chat_session(
|
|||
user_id=user.id,
|
||||
initial_context=request.initial_context,
|
||||
use_draft=True,
|
||||
organization_id=organization_id,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
set_current_run_id(workflow_run.id)
|
||||
await _ensure_text_chat_quota(user, workflow_id, workflow_run.id)
|
||||
|
||||
annotations = {
|
||||
"tester": {
|
||||
|
|
@ -220,7 +219,7 @@ async def create_text_chat_session(
|
|||
async def get_text_chat_session(
|
||||
workflow_id: int,
|
||||
run_id: int,
|
||||
user: UserModel = Depends(get_user),
|
||||
user: UserModel = Depends(get_user_with_selected_organization),
|
||||
) -> WorkflowRunTextSessionResponse:
|
||||
text_session = await _load_text_session_or_404(workflow_id, run_id, user)
|
||||
return _build_response(text_session)
|
||||
|
|
@ -234,10 +233,10 @@ async def append_text_chat_message(
|
|||
workflow_id: int,
|
||||
run_id: int,
|
||||
request: AppendTextChatMessageRequest,
|
||||
user: UserModel = Depends(get_user),
|
||||
user: UserModel = Depends(get_user_with_selected_organization),
|
||||
) -> WorkflowRunTextSessionResponse:
|
||||
text_session = await _load_text_session_or_404(workflow_id, run_id, user)
|
||||
await _ensure_text_chat_quota(user, workflow_id)
|
||||
await _ensure_text_chat_quota(user, workflow_id, run_id)
|
||||
|
||||
try:
|
||||
text_session = await append_text_chat_user_message(
|
||||
|
|
@ -264,7 +263,7 @@ async def rewind_text_chat_session(
|
|||
workflow_id: int,
|
||||
run_id: int,
|
||||
request: RewindTextChatSessionRequest,
|
||||
user: UserModel = Depends(get_user),
|
||||
user: UserModel = Depends(get_user_with_selected_organization),
|
||||
) -> WorkflowRunTextSessionResponse:
|
||||
text_session = await _load_text_session_or_404(workflow_id, run_id, user)
|
||||
try:
|
||||
|
|
|
|||
198
api/schemas/ai_model_configuration.py
Normal file
198
api/schemas/ai_model_configuration.py
Normal file
|
|
@ -0,0 +1,198 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from api.services.configuration.registry import (
|
||||
DograhEmbeddingsConfiguration,
|
||||
DograhLLMService,
|
||||
DograhSTTService,
|
||||
DograhTTSService,
|
||||
EmbeddingsConfig,
|
||||
LLMConfig,
|
||||
RealtimeConfig,
|
||||
ServiceProviders,
|
||||
STTConfig,
|
||||
TTSConfig,
|
||||
)
|
||||
|
||||
# Speeds accepted by DograhManagedAIModelConfiguration.validate_speed.
DOGRAH_SPEED_OPTIONS: tuple[float, ...] = (0.8, 1.0, 1.2)
# Default value of DograhManagedAIModelConfiguration.voice.
DOGRAH_DEFAULT_VOICE = "default"
# Default value of DograhManagedAIModelConfiguration.language.
DOGRAH_DEFAULT_LANGUAGE = "multi"
|
||||
|
||||
|
||||
class EffectiveAIModelConfiguration(BaseModel):
    """Flattened, per-service AI model configuration.

    This is the shape returned by ``compile_ai_model_configuration_v2``:
    one optional section per service plus a few account-level fields.
    Every field is optional so a partially filled configuration validates.
    """

    llm: LLMConfig | None = None
    stt: STTConfig | None = None
    tts: TTSConfig | None = None
    embeddings: EmbeddingsConfig | None = None
    realtime: RealtimeConfig | None = None
    is_realtime: bool = False
    # Set to 2 when the configuration is compiled from the Dograh-managed mode.
    managed_service_version: int | None = None
    test_phone_number: str | None = None
    timezone: str | None = None
    last_validated_at: datetime | None = None
    # Post-signup onboarding gate: set once the user submits or skips the
    # onboarding form, so it shows only once per user.
    onboarding_completed_at: datetime | None = None
    onboarding_skipped: bool = False

    @model_validator(mode="before")
    @classmethod
    def strip_incomplete_realtime_when_disabled(cls, data):
        """Skip realtime validation when is_realtime is False and api_key is missing.

        Args:
            data: Raw validation input. Only plain dicts are inspected; any
                other input is passed through untouched.

        Returns:
            ``data`` itself, or a shallow copy without the ``realtime`` key
            when realtime is disabled and its section carries no ``api_key``.
        """
        if isinstance(data, dict) and not data.get("is_realtime", False):
            realtime = data.get("realtime")
            if isinstance(realtime, dict) and not realtime.get("api_key"):
                # Build a filtered copy instead of popping from ``data``: the
                # dict belongs to the caller and validation must not alter it.
                return {
                    key: value for key, value in data.items() if key != "realtime"
                }
        return data
|
||||
|
||||
|
||||
class DograhManagedAIModelConfiguration(BaseModel):
    """Settings for the Dograh-managed mode: one API key plus voice, speed and language."""

    api_key: str
    voice: str = DOGRAH_DEFAULT_VOICE
    speed: float = Field(default=1.0)
    language: str = DOGRAH_DEFAULT_LANGUAGE

    @model_validator(mode="after")
    def validate_speed(self):
        """Accept only speeds listed in ``DOGRAH_SPEED_OPTIONS``.

        Raises:
            ValueError: If ``speed`` is not one of the allowed options.
        """
        if self.speed in DOGRAH_SPEED_OPTIONS:
            return self
        allowed = ", ".join(map(str, DOGRAH_SPEED_OPTIONS))
        raise ValueError(f"Dograh speed must be one of: {allowed}")
|
||||
|
||||
|
||||
class BYOKPipelineAIModelConfiguration(BaseModel):
    """Bring-your-own-key pipeline: required LLM, TTS and STT, optional embeddings."""

    llm: LLMConfig
    tts: TTSConfig
    stt: STTConfig
    embeddings: EmbeddingsConfig | None = None

    @model_validator(mode="after")
    def reject_dograh_providers(self):
        """Run the Dograh-provider check on every section, in declaration order."""
        for section in ("llm", "tts", "stt", "embeddings"):
            _reject_dograh_provider(section, getattr(self, section))
        return self
|
||||
|
||||
|
||||
class BYOKRealtimeAIModelConfiguration(BaseModel):
    """Bring-your-own-key realtime setup: realtime service, LLM, optional embeddings."""

    realtime: RealtimeConfig
    llm: LLMConfig
    embeddings: EmbeddingsConfig | None = None

    @model_validator(mode="after")
    def reject_dograh_providers(self):
        """Run the Dograh-provider check on the llm and embeddings sections."""
        for section in ("llm", "embeddings"):
            _reject_dograh_provider(section, getattr(self, section))
        return self
|
||||
|
||||
|
||||
class BYOKAIModelConfiguration(BaseModel):
    """BYOK configuration: a ``mode`` selector plus the matching sub-configuration."""

    mode: Literal["pipeline", "realtime"]
    pipeline: BYOKPipelineAIModelConfiguration | None = None
    realtime: BYOKRealtimeAIModelConfiguration | None = None

    @model_validator(mode="after")
    def validate_selected_mode(self):
        """Require the sub-configuration that ``mode`` points at.

        Raises:
            ValueError: If the section for the selected mode is missing.
        """
        if self.mode == "pipeline":
            if self.pipeline is None:
                raise ValueError(
                    "byok.pipeline is required when byok.mode is pipeline"
                )
        elif self.mode == "realtime":
            if self.realtime is None:
                raise ValueError(
                    "byok.realtime is required when byok.mode is realtime"
                )
        return self
|
||||
|
||||
|
||||
class OrganizationAIModelConfigurationV2(BaseModel):
    """Version-2 model configuration: either Dograh-managed or BYOK."""

    version: Literal[2] = 2
    mode: Literal["dograh", "byok"]
    dograh: DograhManagedAIModelConfiguration | None = None
    byok: BYOKAIModelConfiguration | None = None

    @model_validator(mode="after")
    def validate_selected_mode(self):
        """Require the section that ``mode`` points at.

        Raises:
            ValueError: If the section for the selected mode is missing.
        """
        if self.mode == "dograh":
            if self.dograh is None:
                raise ValueError(
                    "dograh configuration is required when mode is dograh"
                )
        elif self.mode == "byok":
            if self.byok is None:
                raise ValueError("byok configuration is required when mode is byok")
        return self
|
||||
|
||||
|
||||
class OrganizationAIModelConfigurationResponse(BaseModel):
    """Response payload pairing a stored configuration with its effective form.

    NOTE(review): the field meanings below are inferred from their names;
    confirm against the endpoint that builds this response.
    """

    # Stored configuration as a plain dict; may be None.
    configuration: dict | None
    # Effective configuration as a plain dict; always present.
    effective_configuration: dict
    # Which store the configuration was resolved from.
    source: Literal["organization_v2", "legacy_user_v1", "empty"]
|
||||
|
||||
|
||||
def compile_ai_model_configuration_v2(
    configuration: OrganizationAIModelConfigurationV2,
) -> EffectiveAIModelConfiguration:
    """Flatten a v2 configuration into an ``EffectiveAIModelConfiguration``.

    Args:
        configuration: The v2 configuration to compile.

    Returns:
        The Dograh-managed compilation when ``mode`` is ``"dograh"``;
        otherwise the BYOK pipeline or realtime sections copied across, with
        ``is_realtime`` set to match.

    Raises:
        ValueError: If the section required by the selected mode is missing.
    """
    if configuration.mode == "dograh":
        managed = configuration.dograh
        if managed is None:
            raise ValueError("dograh configuration is required")
        return _compile_dograh_configuration(managed)

    byok = configuration.byok
    if byok is None:
        raise ValueError("byok configuration is required")

    if byok.mode == "pipeline":
        pipeline = byok.pipeline
        if pipeline is None:
            raise ValueError("byok.pipeline is required")
        sections = {
            "llm": pipeline.llm,
            "tts": pipeline.tts,
            "stt": pipeline.stt,
            "embeddings": pipeline.embeddings,
            "is_realtime": False,
        }
    else:
        # Any non-pipeline BYOK mode is handled as realtime.
        realtime = byok.realtime
        if realtime is None:
            raise ValueError("byok.realtime is required")
        sections = {
            "llm": realtime.llm,
            "realtime": realtime.realtime,
            "embeddings": realtime.embeddings,
            "is_realtime": True,
        }

    return EffectiveAIModelConfiguration(**sections)
|
||||
|
||||
|
||||
def _compile_dograh_configuration(
    configuration: DograhManagedAIModelConfiguration,
) -> EffectiveAIModelConfiguration:
    """Expand Dograh-managed settings into per-service Dograh sections.

    Every service gets the same provider, API key and ``"default"`` model.
    TTS also receives the voice and speed, and STT the language.
    """
    # Fields shared by all four Dograh service sections.
    shared = {
        "provider": ServiceProviders.DOGRAH,
        "api_key": configuration.api_key,
        "model": "default",
    }
    llm = DograhLLMService(**shared)
    tts = DograhTTSService(
        **shared,
        voice=configuration.voice,
        speed=configuration.speed,
    )
    stt = DograhSTTService(**shared, language=configuration.language)
    embeddings = DograhEmbeddingsConfiguration(**shared)
    return EffectiveAIModelConfiguration(
        llm=llm,
        tts=tts,
        stt=stt,
        embeddings=embeddings,
        is_realtime=False,
        managed_service_version=2,
    )
|
||||
|
||||
|
||||
def _reject_dograh_provider(section: str, service) -> None:
    """Raise if ``service`` is configured with the Dograh provider.

    Args:
        section: Section name used in the error message.
        service: Service configuration to inspect; ``None`` is accepted as-is.

    Raises:
        ValueError: If ``service.provider`` equals ``ServiceProviders.DOGRAH``.
    """
    uses_dograh = (
        service is not None
        and getattr(service, "provider", None) == ServiceProviders.DOGRAH
    )
    if uses_dograh:
        raise ValueError(f"BYOK {section} cannot use Dograh provider")
|
||||
6
api/schemas/organization_preferences.py
Normal file
6
api/schemas/organization_preferences.py
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class OrganizationPreferences(BaseModel):
    """Organization-level preferences; both fields are optional."""

    # Optional test phone number; None when not set.
    test_phone_number: str | None = None
    # Optional timezone string; None when not set.
    timezone: str | None = None
|
||||
|
|
@ -1,37 +0,0 @@
|
|||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from api.services.configuration.registry import (
|
||||
EmbeddingsConfig,
|
||||
LLMConfig,
|
||||
RealtimeConfig,
|
||||
STTConfig,
|
||||
TTSConfig,
|
||||
)
|
||||
|
||||
|
||||
class UserConfiguration(BaseModel):
|
||||
llm: LLMConfig | None = None
|
||||
stt: STTConfig | None = None
|
||||
tts: TTSConfig | None = None
|
||||
embeddings: EmbeddingsConfig | None = None
|
||||
realtime: RealtimeConfig | None = None
|
||||
is_realtime: bool = False
|
||||
test_phone_number: str | None = None
|
||||
timezone: str | None = None
|
||||
last_validated_at: datetime | None = None
|
||||
# Post-signup onboarding gate: set once the user submits or skips the
|
||||
# onboarding form, so it shows only once per user (server-side, cross-device).
|
||||
onboarding_completed_at: datetime | None = None
|
||||
onboarding_skipped: bool = False
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def strip_incomplete_realtime_when_disabled(cls, data):
|
||||
"""Skip realtime validation when is_realtime is False and api_key is missing."""
|
||||
if isinstance(data, dict) and not data.get("is_realtime", False):
|
||||
realtime = data.get("realtime")
|
||||
if isinstance(realtime, dict) and not realtime.get("api_key"):
|
||||
data.pop("realtime", None)
|
||||
return data
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
from typing import Annotated, Optional
|
||||
|
||||
import httpx
|
||||
from fastapi import Header, HTTPException, Query, WebSocket
|
||||
from fastapi import Depends, Header, HTTPException, Query, WebSocket
|
||||
from loguru import logger
|
||||
from pydantic import ValidationError
|
||||
|
||||
|
|
@ -9,9 +9,10 @@ from api.constants import AUTH_PROVIDER, DOGRAH_MPS_SECRET_KEY, MPS_API_URL
|
|||
from api.db import db_client
|
||||
from api.db.models import UserModel
|
||||
from api.enums import PostHogEvent
|
||||
from api.schemas.user_configuration import UserConfiguration
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.auth.stack_auth import stackauth
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from api.services.mps_billing import ensure_hosted_mps_billing_account_v2
|
||||
from api.services.posthog_client import capture_event
|
||||
from api.utils.auth import decode_jwt_token
|
||||
|
||||
|
|
@ -110,6 +111,19 @@ async def get_user(
|
|||
# This prevents race conditions where multiple concurrent requests
|
||||
# might try to create configurations
|
||||
if org_was_created:
|
||||
try:
|
||||
await ensure_hosted_mps_billing_account_v2(
|
||||
organization.id,
|
||||
created_by=str(stack_user["id"]),
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to initialize hosted MPS billing account for "
|
||||
"organization {}",
|
||||
organization.id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
existing_cfg = await db_client.get_user_configurations(user_model.id)
|
||||
if not (existing_cfg.llm or existing_cfg.tts or existing_cfg.stt):
|
||||
mps_config = await create_user_configuration_with_mps_key(
|
||||
|
|
@ -119,6 +133,19 @@ async def get_user(
|
|||
await db_client.update_user_configuration(
|
||||
user_model.id, mps_config
|
||||
)
|
||||
from api.enums import OrganizationConfigurationKey
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
convert_legacy_ai_model_configuration_to_v2,
|
||||
)
|
||||
|
||||
model_config_v2 = convert_legacy_ai_model_configuration_to_v2(
|
||||
mps_config
|
||||
)
|
||||
await db_client.upsert_configuration(
|
||||
organization.id,
|
||||
OrganizationConfigurationKey.MODEL_CONFIGURATION_V2.value,
|
||||
model_config_v2.model_dump(mode="json", exclude_none=True),
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
raise HTTPException(
|
||||
|
|
@ -129,6 +156,14 @@ async def get_user(
|
|||
return user_model
|
||||
|
||||
|
||||
async def get_user_with_selected_organization(
    user: Annotated[UserModel, Depends(get_user)],
) -> UserModel:
    """Dependency that returns the user only when an organization is selected.

    Args:
        user: The user resolved by the ``get_user`` dependency.

    Returns:
        The same user, unchanged.

    Raises:
        HTTPException: 400 when ``selected_organization_id`` is falsy.
    """
    if user.selected_organization_id:
        return user
    raise HTTPException(status_code=400, detail="No organization selected")
|
||||
|
||||
|
||||
async def _handle_oss_auth(authorization: str | None) -> UserModel:
|
||||
"""
|
||||
Handle authentication for OSS deployment mode.
|
||||
|
|
@ -192,7 +227,7 @@ async def _handle_api_key_auth(api_key: str) -> UserModel:
|
|||
|
||||
async def create_user_configuration_with_mps_key(
|
||||
user_id: int, organization_id: int, user_provider_id: str
|
||||
) -> Optional[UserConfiguration]:
|
||||
) -> Optional[EffectiveAIModelConfiguration]:
|
||||
"""Create user configuration using MPS service key.
|
||||
|
||||
Args:
|
||||
|
|
@ -201,7 +236,7 @@ async def create_user_configuration_with_mps_key(
|
|||
user_provider_id: The user's provider ID (for created_by field)
|
||||
|
||||
Returns:
|
||||
UserConfiguration with MPS-provided API keys or None if failed
|
||||
EffectiveAIModelConfiguration with MPS-provided API keys or None if failed
|
||||
"""
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
|
|
@ -211,7 +246,7 @@ async def create_user_configuration_with_mps_key(
|
|||
response = await client.post(
|
||||
f"{MPS_API_URL}/api/v1/service-keys/",
|
||||
json={
|
||||
"name": f"Default Dograh Model Service Key",
|
||||
"name": "Default Dograh Model Service Key",
|
||||
"description": "Auto-generated key for OSS user",
|
||||
"expires_in_days": 7, # Short-lived for OSS
|
||||
"created_by": user_provider_id,
|
||||
|
|
@ -229,7 +264,7 @@ async def create_user_configuration_with_mps_key(
|
|||
response = await client.post(
|
||||
f"{MPS_API_URL}/api/v1/service-keys/",
|
||||
json={
|
||||
"name": f"Default Dograh Model Service Key",
|
||||
"name": "Default Dograh Model Service Key",
|
||||
"description": f"Auto-generated key for organization {organization_id}",
|
||||
"organization_id": organization_id,
|
||||
"expires_in_days": 90, # Longer-lived for authenticated users
|
||||
|
|
@ -264,8 +299,8 @@ async def create_user_configuration_with_mps_key(
|
|||
"model": "default",
|
||||
},
|
||||
}
|
||||
user_config = UserConfiguration(**configuration)
|
||||
return user_config
|
||||
effective_config = EffectiveAIModelConfiguration(**configuration)
|
||||
return effective_config
|
||||
else:
|
||||
logger.warning(
|
||||
f"Failed to get MPS service key: {response.status_code} - {response.text}"
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from api.services.campaign.errors import (
|
|||
PhoneNumberPoolExhaustedError,
|
||||
)
|
||||
from api.services.campaign.rate_limiter import rate_limiter
|
||||
from api.services.quota_service import authorize_workflow_run_start
|
||||
from api.utils.common import get_backend_endpoints
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -339,6 +340,41 @@ class CampaignCallDispatcher:
|
|||
},
|
||||
)
|
||||
|
||||
quota_result = await authorize_workflow_run_start(
|
||||
workflow_id=campaign.workflow_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
)
|
||||
if not quota_result.has_quota:
|
||||
error_message = quota_result.error_message or "Quota exceeded"
|
||||
logger.warning(
|
||||
f"Campaign {campaign.id} quota check failed for workflow run "
|
||||
f"{workflow_run.id}: {error_message}"
|
||||
)
|
||||
await db_client.update_workflow_run(
|
||||
run_id=workflow_run.id,
|
||||
is_completed=True,
|
||||
state=WorkflowRunState.COMPLETED.value,
|
||||
gathered_context={"error": error_message},
|
||||
)
|
||||
|
||||
mapping = await rate_limiter.get_workflow_slot_mapping(workflow_run.id)
|
||||
if mapping:
|
||||
org_id, mapped_slot_id = mapping
|
||||
await rate_limiter.release_concurrent_slot(org_id, mapped_slot_id)
|
||||
await rate_limiter.delete_workflow_slot_mapping(workflow_run.id)
|
||||
|
||||
from_number_mapping = await rate_limiter.get_workflow_from_number_mapping(
|
||||
workflow_run.id
|
||||
)
|
||||
if from_number_mapping:
|
||||
fn_org_id, fn_number, fn_tcid = from_number_mapping
|
||||
await rate_limiter.release_from_number(
|
||||
fn_org_id, fn_number, telephony_configuration_id=fn_tcid
|
||||
)
|
||||
await rate_limiter.delete_workflow_from_number_mapping(workflow_run.id)
|
||||
|
||||
raise ValueError(error_message)
|
||||
|
||||
# Initiate call via telephony provider
|
||||
try:
|
||||
# Construct webhook URL with parameters
|
||||
|
|
|
|||
484
api/services/configuration/ai_model_configuration.py
Normal file
484
api/services/configuration/ai_model_configuration.py
Normal file
|
|
@ -0,0 +1,484 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from api.constants import MPS_API_URL
|
||||
from api.db import db_client
|
||||
from api.db.models import WorkflowDefinitionModel, WorkflowModel
|
||||
from api.enums import OrganizationConfigurationKey
|
||||
from api.schemas.ai_model_configuration import (
|
||||
DOGRAH_DEFAULT_LANGUAGE,
|
||||
DOGRAH_DEFAULT_VOICE,
|
||||
DOGRAH_SPEED_OPTIONS,
|
||||
BYOKAIModelConfiguration,
|
||||
BYOKPipelineAIModelConfiguration,
|
||||
BYOKRealtimeAIModelConfiguration,
|
||||
DograhManagedAIModelConfiguration,
|
||||
EffectiveAIModelConfiguration,
|
||||
OrganizationAIModelConfigurationV2,
|
||||
compile_ai_model_configuration_v2,
|
||||
)
|
||||
from api.services.configuration.masking import (
|
||||
SERVICE_SECRET_FIELDS,
|
||||
contains_masked_key,
|
||||
mask_key,
|
||||
resolve_masked_api_keys,
|
||||
)
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from api.services.configuration.resolve import resolve_effective_config
|
||||
|
||||
AIModelConfigurationSource = Literal["organization_v2", "legacy_user_v1", "empty"]
|
||||
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY = "model_configuration_v2_override"
|
||||
|
||||
|
||||
@dataclass
class ResolvedAIModelConfiguration:
    """Outcome of resolving which AI model configuration applies.

    Bundles the effective configuration with a tag describing where it
    came from.
    """

    # Configuration that callers should actually use.
    effective: EffectiveAIModelConfiguration
    # Origin tag: "organization_v2", "legacy_user_v1" or "empty".
    source: AIModelConfigurationSource
    # The organization-level v2 configuration, when one was resolved.
    organization_configuration: OrganizationAIModelConfigurationV2 | None = None
|
||||
|
||||
|
||||
@dataclass
class WorkflowAIModelConfigurationMigrationResult:
    """Summary of a workflow model-configuration migration run."""

    # Number of workflows affected by the migration.
    workflow_count: int = 0
    # Number of workflow definitions that were rewritten.
    definition_count: int = 0
    # Ids of the affected workflows; None when not populated.
    workflow_ids: list[int] | None = None
|
||||
|
||||
|
||||
async def get_resolved_ai_model_configuration(
    *,
    user_id: int | None,
    organization_id: int | None,
) -> ResolvedAIModelConfiguration:
    """Resolve the AI model configuration for a user / organization pair.

    Resolution order:

    1. A valid organization-level v2 configuration, compiled into its
       effective form (source ``"organization_v2"``).
    2. With no organization configuration and no ``user_id``, a default
       ``EffectiveAIModelConfiguration`` (source ``"empty"``).
    3. Otherwise the user's legacy configuration, tagged
       ``"legacy_user_v1"`` when it has at least one service configured
       and ``"empty"`` when it has none.
    """
    org_config = await get_organization_ai_model_configuration_v2(organization_id)
    if org_config is not None:
        compiled = compile_ai_model_configuration_v2(org_config)
        return ResolvedAIModelConfiguration(
            effective=compiled,
            source="organization_v2",
            organization_configuration=org_config,
        )

    if user_id is None:
        return ResolvedAIModelConfiguration(
            effective=EffectiveAIModelConfiguration(),
            source="empty",
        )

    legacy_config = await db_client.get_user_configurations(user_id)
    if _has_model_services(legacy_config):
        origin = "legacy_user_v1"
    else:
        origin = "empty"
    return ResolvedAIModelConfiguration(effective=legacy_config, source=origin)
|
||||
|
||||
|
||||
async def get_effective_ai_model_configuration_for_workflow(
    *,
    user_id: int | None,
    organization_id: int | None,
    workflow_configurations: dict | None,
) -> EffectiveAIModelConfiguration:
    """Return the effective AI model configuration for one workflow.

    A truthy v2 override stored in the workflow configurations wins: it is
    validated and compiled on its own. Otherwise the user/organization
    configuration is resolved and the workflow's legacy ``model_overrides``
    (if any) are applied on top of it.
    """
    configs = workflow_configurations or {}
    override_payload = configs.get(WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY)

    if not override_payload:
        resolved = await get_resolved_ai_model_configuration(
            user_id=user_id,
            organization_id=organization_id,
        )
        legacy_overrides = configs.get("model_overrides")
        return resolve_effective_config(resolved.effective, legacy_overrides)

    parsed_override = OrganizationAIModelConfigurationV2.model_validate(
        override_payload
    )
    return compile_ai_model_configuration_v2(parsed_override)
|
||||
|
||||
|
||||
async def get_organization_ai_model_configuration_v2(
    organization_id: int | None,
) -> OrganizationAIModelConfigurationV2 | None:
    """Load and validate the stored v2 model configuration of an organization.

    Returns ``None`` when no organization id is given, when nothing (or an
    empty value) is stored, or when the stored value fails validation. The
    validation failure is logged as a warning rather than raised.
    """
    if organization_id is None:
        return None

    stored = await db_client.get_configuration(
        organization_id,
        OrganizationConfigurationKey.MODEL_CONFIGURATION_V2.value,
    )
    if not (stored is not None and stored.value):
        return None

    try:
        parsed = OrganizationAIModelConfigurationV2.model_validate(stored.value)
    except ValidationError as exc:
        logger.warning(
            "Invalid org AI model configuration v2 for organization "
            f"{organization_id}: {exc}. Falling back to legacy configuration."
        )
        return None
    return parsed
|
||||
|
||||
|
||||
async def upsert_organization_ai_model_configuration_v2(
    organization_id: int,
    configuration: OrganizationAIModelConfigurationV2,
) -> OrganizationAIModelConfigurationV2:
    """Persist the organization's v2 model configuration and return it.

    The configuration is stored as its JSON-mode dump with ``None`` values
    excluded; the object passed in is returned unchanged.
    """
    config_key = OrganizationConfigurationKey.MODEL_CONFIGURATION_V2.value
    payload = configuration.model_dump(mode="json", exclude_none=True)
    await db_client.upsert_configuration(organization_id, config_key, payload)
    return configuration
|
||||
|
||||
|
||||
async def migrate_workflow_model_configurations_to_v2(
    *,
    organization_id: int,
    fallback_user_config: EffectiveAIModelConfiguration,
) -> WorkflowAIModelConfigurationMigrationResult:
    """Migrate legacy model overrides of an organization's workflows to v2.

    Every workflow of the organization, and each of its definitions, has
    its ``workflow_configurations`` run through
    ``migrate_workflow_configuration_model_override_to_v2``. The base
    configuration is the workflow owner's user configuration (fetched once
    per owner), or ``fallback_user_config`` when the workflow has no owner.
    All changed rows are written in a single session and committed together.

    Returns:
        Counts of affected workflows and rewritten definitions, plus the
        sorted ids of the affected workflows.
    """
    workflows = await _list_workflows_for_model_configuration_migration(organization_id)

    configs_by_owner: dict[int, EffectiveAIModelConfiguration] = {}
    pending_workflows: list[tuple[int, dict]] = []
    pending_definitions: list[tuple[int, dict]] = []
    touched_workflow_ids: set[int] = set()

    for workflow in workflows:
        owner_id = workflow.user_id
        if owner_id is None:
            base_config = fallback_user_config
        else:
            # Cache owner configurations so each user is only loaded once.
            if owner_id not in configs_by_owner:
                configs_by_owner[owner_id] = await db_client.get_user_configurations(
                    owner_id
                )
            base_config = configs_by_owner[owner_id]

        new_workflow_configs, workflow_changed = (
            migrate_workflow_configuration_model_override_to_v2(
                workflow.workflow_configurations,
                base_config,
            )
        )
        if workflow_changed:
            pending_workflows.append((workflow.id, new_workflow_configs))
            touched_workflow_ids.add(workflow.id)

        for definition in workflow.definitions:
            new_definition_configs, definition_changed = (
                migrate_workflow_configuration_model_override_to_v2(
                    definition.workflow_configurations,
                    base_config,
                )
            )
            if not definition_changed:
                continue
            pending_definitions.append((definition.id, new_definition_configs))
            # A changed definition also counts its parent workflow as migrated.
            touched_workflow_ids.add(workflow.id)

    if pending_workflows or pending_definitions:
        async with db_client.async_session() as session:
            for workflow_id, configs in pending_workflows:
                statement = (
                    update(WorkflowModel)
                    .where(WorkflowModel.id == workflow_id)
                    .values(workflow_configurations=configs)
                )
                await session.execute(statement)
            for definition_id, configs in pending_definitions:
                statement = (
                    update(WorkflowDefinitionModel)
                    .where(WorkflowDefinitionModel.id == definition_id)
                    .values(workflow_configurations=configs)
                )
                await session.execute(statement)
            await session.commit()

    return WorkflowAIModelConfigurationMigrationResult(
        workflow_count=len(touched_workflow_ids),
        definition_count=len(pending_definitions),
        workflow_ids=sorted(touched_workflow_ids),
    )
|
||||
|
||||
|
||||
def migrate_workflow_configuration_model_override_to_v2(
    workflow_configurations: dict | None,
    base_config: EffectiveAIModelConfiguration,
) -> tuple[dict, bool]:
    """Convert a legacy ``model_overrides`` entry into a v2 override.

    Works on a deep copy; the input dict is never mutated.

    Returns:
        ``(configurations, changed)``:

        * non-dict input yields ``({}, False)``;
        * a ``model_overrides`` value that is not a dict is dropped
          (``changed`` is True only if the key was present);
        * a dict ``model_overrides`` is resolved against ``base_config`` and
          stored as the v2 override unless a truthy v2 override already
          exists, and the legacy key is removed either way.
    """
    if not isinstance(workflow_configurations, dict):
        return {}, False

    result = copy.deepcopy(workflow_configurations)
    legacy_overrides = result.get("model_overrides")

    if not isinstance(legacy_overrides, dict):
        had_legacy_key = "model_overrides" in result
        if had_legacy_key:
            result.pop("model_overrides", None)
        return result, had_legacy_key

    if not result.get(WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY):
        effective = resolve_effective_config(base_config, legacy_overrides)
        converted = convert_legacy_ai_model_configuration_to_v2(effective)
        result[WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY] = converted.model_dump(
            mode="json", exclude_none=True
        )

    result.pop("model_overrides", None)
    return result, True
|
||||
|
||||
|
||||
def merge_ai_model_configuration_v2_secrets(
    incoming: OrganizationAIModelConfigurationV2,
    existing: OrganizationAIModelConfigurationV2 | None,
) -> OrganizationAIModelConfigurationV2:
    """Restore masked secrets in ``incoming`` from ``existing``.

    With no existing configuration, ``incoming`` is returned as-is. Secrets
    are only merged when both configurations use the same top-level mode:
    the Dograh API key when both are ``"dograh"``, the per-service secret
    fields when both are ``"byok"``. The merged data is re-validated before
    being returned.
    """
    if existing is None:
        return incoming

    incoming_data = incoming.model_dump(mode="json", exclude_none=True)
    existing_data = existing.model_dump(mode="json", exclude_none=True)
    modes = (incoming_data.get("mode"), existing_data.get("mode"))

    if modes == ("dograh", "dograh"):
        new_dograh = incoming_data.get("dograh") or {}
        old_dograh = existing_data.get("dograh") or {}
        new_key = new_dograh.get("api_key")
        old_key = old_dograh.get("api_key")
        if new_key and old_key and contains_masked_key(new_key):
            new_dograh["api_key"] = resolve_masked_api_keys(new_key, old_key)
    elif modes == ("byok", "byok"):
        _merge_byok_secret_fields(
            incoming_data.get("byok"),
            existing_data.get("byok"),
        )

    return OrganizationAIModelConfigurationV2.model_validate(incoming_data)
|
||||
|
||||
|
||||
def check_for_masked_keys_in_ai_model_configuration_v2(
    configuration: OrganizationAIModelConfigurationV2,
) -> None:
    """Raise ``ValueError`` if any secret in ``configuration`` is still masked.

    The check runs over the JSON-mode dump of the configuration with
    ``None`` values excluded.
    """
    _raise_if_masked_secret(
        configuration.model_dump(mode="json", exclude_none=True)
    )
|
||||
|
||||
|
||||
def mask_ai_model_configuration_v2(
    configuration: OrganizationAIModelConfigurationV2 | None,
) -> dict | None:
    """Return a JSON-mode dict of ``configuration`` with its secrets masked.

    ``None`` is passed through unchanged.
    """
    if configuration is not None:
        dumped = configuration.model_dump(mode="json", exclude_none=True)
        _mask_secret_fields(dumped)
        return dumped
    return None
|
||||
|
||||
|
||||
def convert_legacy_ai_model_configuration_to_v2(
    configuration: EffectiveAIModelConfiguration,
) -> OrganizationAIModelConfigurationV2:
    """Translate a legacy effective configuration into the v2 shape.

    If any service yields a single Dograh API key, the result is a
    ``"dograh"`` mode configuration. Otherwise a ``"byok"`` configuration is
    built, in realtime or pipeline form depending on
    ``configuration.is_realtime``.

    Raises:
        ValueError: if the services required for the selected BYOK form are
            missing (realtime needs ``realtime`` and ``llm``; pipeline needs
            ``llm``, ``tts`` and ``stt``).
    """
    dograh_key = _first_dograh_api_key(configuration)
    if dograh_key:
        return _convert_any_dograh_legacy_configuration(configuration, dograh_key)

    if configuration.is_realtime:
        if configuration.realtime is None or configuration.llm is None:
            raise ValueError("Realtime legacy configuration is incomplete")
        byok = BYOKAIModelConfiguration(
            mode="realtime",
            realtime=BYOKRealtimeAIModelConfiguration(
                realtime=configuration.realtime,
                llm=configuration.llm,
                embeddings=configuration.embeddings,
            ),
        )
    else:
        required_services = (configuration.llm, configuration.tts, configuration.stt)
        if any(service is None for service in required_services):
            raise ValueError("Pipeline legacy configuration is incomplete")
        byok = BYOKAIModelConfiguration(
            mode="pipeline",
            pipeline=BYOKPipelineAIModelConfiguration(
                llm=configuration.llm,
                tts=configuration.tts,
                stt=configuration.stt,
                embeddings=configuration.embeddings,
            ),
        )

    return OrganizationAIModelConfigurationV2(mode="byok", byok=byok)
|
||||
|
||||
|
||||
def dograh_embeddings_base_url() -> str:
    """Return the base URL used for Dograh-managed embeddings."""
    base_url = f"{MPS_API_URL}/api/v1/llm"
    return base_url
|
||||
|
||||
|
||||
def apply_managed_embeddings_base_url(
    *,
    provider: str | None,
    base_url: str | None,
) -> str | None:
    """Pick the embeddings base URL for ``provider``.

    The Dograh provider (given as the enum member or its value) always gets
    the managed embeddings URL; any other provider keeps ``base_url``.
    """
    dograh_markers = (ServiceProviders.DOGRAH.value, ServiceProviders.DOGRAH)
    if provider in dograh_markers:
        return dograh_embeddings_base_url()
    return base_url
|
||||
|
||||
|
||||
def _merge_byok_secret_fields(incoming_byok: dict | None, existing_byok: dict | None):
|
||||
if not isinstance(incoming_byok, dict) or not isinstance(existing_byok, dict):
|
||||
return
|
||||
incoming_mode = incoming_byok.get("mode")
|
||||
existing_mode = existing_byok.get("mode")
|
||||
if incoming_mode != existing_mode:
|
||||
return
|
||||
section_names = (
|
||||
("llm", "tts", "stt", "embeddings")
|
||||
if incoming_mode == "pipeline"
|
||||
else ("realtime", "llm", "embeddings")
|
||||
)
|
||||
incoming_container = incoming_byok.get(incoming_mode)
|
||||
existing_container = existing_byok.get(existing_mode)
|
||||
if not isinstance(incoming_container, dict) or not isinstance(
|
||||
existing_container, dict
|
||||
):
|
||||
return
|
||||
for section_name in section_names:
|
||||
incoming_section = incoming_container.get(section_name)
|
||||
existing_section = existing_container.get(section_name)
|
||||
if isinstance(incoming_section, dict) and isinstance(existing_section, dict):
|
||||
_merge_service_secret_fields(incoming_section, existing_section)
|
||||
|
||||
|
||||
async def _list_workflows_for_model_configuration_migration(
    organization_id: int,
) -> list[WorkflowModel]:
    """Fetch all workflows of an organization with their definitions loaded.

    Definitions are loaded through ``selectinload`` in the same session as
    the workflows.
    """
    query = (
        select(WorkflowModel)
        .options(selectinload(WorkflowModel.definitions))
        .where(WorkflowModel.organization_id == organization_id)
    )
    async with db_client.async_session() as session:
        rows = await session.execute(query)
        return list(rows.scalars().unique().all())
|
||||
|
||||
|
||||
def _merge_service_secret_fields(incoming: dict, existing: dict):
|
||||
if (
|
||||
incoming.get("provider") is not None
|
||||
and existing.get("provider") is not None
|
||||
and incoming.get("provider") != existing.get("provider")
|
||||
):
|
||||
return
|
||||
for secret_field in SERVICE_SECRET_FIELDS:
|
||||
if secret_field not in existing:
|
||||
continue
|
||||
incoming_secret = incoming.get(secret_field)
|
||||
existing_secret = existing[secret_field]
|
||||
if incoming_secret is None:
|
||||
incoming[secret_field] = existing_secret
|
||||
elif contains_masked_key(incoming_secret):
|
||||
incoming[secret_field] = resolve_masked_api_keys(
|
||||
incoming_secret,
|
||||
existing_secret,
|
||||
)
|
||||
|
||||
|
||||
def _raise_if_masked_secret(value):
|
||||
if isinstance(value, dict):
|
||||
for key, nested in value.items():
|
||||
if key in SERVICE_SECRET_FIELDS and contains_masked_key(nested):
|
||||
raise ValueError(
|
||||
f"The {key} appears to be masked. Please provide the actual "
|
||||
"value, not the masked value."
|
||||
)
|
||||
_raise_if_masked_secret(nested)
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
_raise_if_masked_secret(item)
|
||||
|
||||
|
||||
def _mask_secret_fields(value):
|
||||
if isinstance(value, dict):
|
||||
for key, nested in list(value.items()):
|
||||
if key in SERVICE_SECRET_FIELDS and nested:
|
||||
value[key] = _mask_secret_value(nested)
|
||||
else:
|
||||
_mask_secret_fields(nested)
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
_mask_secret_fields(item)
|
||||
|
||||
|
||||
def _mask_secret_value(value):
|
||||
if isinstance(value, list):
|
||||
return [mask_key(item) for item in value]
|
||||
return mask_key(value)
|
||||
|
||||
|
||||
def _has_model_services(configuration: EffectiveAIModelConfiguration) -> bool:
|
||||
return any(
|
||||
service is not None
|
||||
for service in (
|
||||
configuration.llm,
|
||||
configuration.tts,
|
||||
configuration.stt,
|
||||
configuration.embeddings,
|
||||
configuration.realtime,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _convert_any_dograh_legacy_configuration(
    configuration: EffectiveAIModelConfiguration,
    dograh_key: str,
) -> OrganizationAIModelConfigurationV2:
    """Build a ``"dograh"`` mode v2 configuration from a legacy one.

    Voice and speed are read from the legacy TTS service and language from
    the legacy STT service. A speed outside ``DOGRAH_SPEED_OPTIONS`` becomes
    ``1.0``; a missing or falsy voice/language falls back to the Dograh
    default.
    """
    tts_service = configuration.tts
    stt_service = configuration.stt

    requested_speed = getattr(tts_service, "speed", 1.0)
    speed = requested_speed if requested_speed in DOGRAH_SPEED_OPTIONS else 1.0

    voice = getattr(tts_service, "voice", DOGRAH_DEFAULT_VOICE) or DOGRAH_DEFAULT_VOICE
    language = (
        getattr(stt_service, "language", DOGRAH_DEFAULT_LANGUAGE)
        or DOGRAH_DEFAULT_LANGUAGE
    )

    managed = DograhManagedAIModelConfiguration(
        api_key=dograh_key,
        voice=voice,
        speed=speed,
        language=language,
    )
    return OrganizationAIModelConfigurationV2(mode="dograh", dograh=managed)
|
||||
|
||||
|
||||
def _first_dograh_api_key(configuration: EffectiveAIModelConfiguration) -> str | None:
|
||||
for service in (
|
||||
configuration.llm,
|
||||
configuration.tts,
|
||||
configuration.stt,
|
||||
configuration.embeddings,
|
||||
configuration.realtime,
|
||||
):
|
||||
if service is None or _provider(service) != ServiceProviders.DOGRAH:
|
||||
continue
|
||||
try:
|
||||
return _single_api_key(service)
|
||||
except ValueError:
|
||||
continue
|
||||
return None
|
||||
|
||||
|
||||
def _provider(service):
|
||||
return getattr(service, "provider", None)
|
||||
|
||||
|
||||
def _single_api_key(service) -> str:
|
||||
if hasattr(service, "get_all_api_keys"):
|
||||
keys = service.get_all_api_keys()
|
||||
if len(keys) != 1:
|
||||
raise ValueError("Expected exactly one API key")
|
||||
return keys[0]
|
||||
key = getattr(service, "api_key", None)
|
||||
if not key:
|
||||
raise ValueError("Expected an API key")
|
||||
return key
|
||||
|
|
@ -8,8 +8,8 @@ from groq import Groq
|
|||
# from pyneuphonic import Neuphonic
|
||||
# except ImportError:
|
||||
# Neuphonic = None
|
||||
from api.schemas.user_configuration import (
|
||||
UserConfiguration,
|
||||
from api.schemas.ai_model_configuration import (
|
||||
EffectiveAIModelConfiguration,
|
||||
)
|
||||
from api.services.configuration.registry import ServiceConfig, ServiceProviders
|
||||
from api.services.mps_service_key_client import mps_service_key_client
|
||||
|
|
@ -64,7 +64,7 @@ class UserConfigurationValidator:
|
|||
|
||||
async def validate(
|
||||
self,
|
||||
configuration: UserConfiguration,
|
||||
configuration: EffectiveAIModelConfiguration,
|
||||
organization_id: Optional[int] = None,
|
||||
created_by: Optional[str] = None,
|
||||
) -> APIKeyStatusResponse:
|
||||
|
|
@ -75,21 +75,21 @@ class UserConfigurationValidator:
|
|||
status_list = []
|
||||
|
||||
status_list.extend(self._validate_service(configuration.llm, "llm"))
|
||||
status_list.extend(self._validate_service(configuration.stt, "stt"))
|
||||
status_list.extend(self._validate_service(configuration.tts, "tts"))
|
||||
# Embeddings is optional - only validate if configured
|
||||
status_list.extend(
|
||||
self._validate_service(
|
||||
configuration.embeddings, "embeddings", required=False
|
||||
)
|
||||
)
|
||||
# Realtime is optional - only validate if is_realtime is enabled
|
||||
if configuration.is_realtime:
|
||||
status_list.extend(
|
||||
self._validate_service(
|
||||
configuration.realtime, "realtime", required=True
|
||||
)
|
||||
)
|
||||
else:
|
||||
status_list.extend(self._validate_service(configuration.stt, "stt"))
|
||||
status_list.extend(self._validate_service(configuration.tts, "tts"))
|
||||
# Embeddings is optional - only validate if configured
|
||||
status_list.extend(
|
||||
self._validate_service(
|
||||
configuration.embeddings, "embeddings", required=False
|
||||
)
|
||||
)
|
||||
|
||||
if status_list:
|
||||
raise ValueError(status_list)
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ The rules are simple:
|
|||
import copy
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from api.schemas.user_configuration import UserConfiguration
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.configuration.registry import ServiceConfig
|
||||
from api.services.integrations import get_node_secret_fields
|
||||
|
||||
|
|
@ -31,7 +31,7 @@ def contains_masked_key(value: str | list[str] | None) -> bool:
|
|||
return any(MASK_MARKER in k for k in keys)
|
||||
|
||||
|
||||
def check_for_masked_keys(config: "UserConfiguration") -> None:
|
||||
def check_for_masked_keys(config: "EffectiveAIModelConfiguration") -> None:
|
||||
"""Raise ValueError if any service in *config* still has a masked secret."""
|
||||
for field in ("llm", "tts", "stt", "embeddings", "realtime"):
|
||||
service = getattr(config, field, None)
|
||||
|
|
@ -111,7 +111,7 @@ def resolve_masked_api_keys(
|
|||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# High-level helpers for UserConfiguration objects
|
||||
# High-level helpers for EffectiveAIModelConfiguration objects
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
|
|
@ -129,7 +129,7 @@ def _mask_service(service_cfg: Optional[ServiceConfig]) -> Optional[Dict[str, An
|
|||
return data
|
||||
|
||||
|
||||
def mask_user_config(config: UserConfiguration) -> Dict[str, Any]:
|
||||
def mask_user_config(config: EffectiveAIModelConfiguration) -> Dict[str, Any]:
|
||||
"""Return a JSON-serialisable dict of *config* with every api_key masked."""
|
||||
|
||||
return {
|
||||
|
|
@ -155,21 +155,35 @@ def mask_workflow_configurations(config: Optional[Dict]) -> Optional[Dict]:
|
|||
|
||||
masked = copy.deepcopy(config)
|
||||
model_overrides = masked.get("model_overrides")
|
||||
if not isinstance(model_overrides, dict):
|
||||
return masked
|
||||
if isinstance(model_overrides, dict):
|
||||
for section in MODEL_OVERRIDE_FIELDS:
|
||||
override = model_overrides.get(section)
|
||||
if not isinstance(override, dict):
|
||||
continue
|
||||
for secret_field in SERVICE_SECRET_FIELDS:
|
||||
raw = override.get(secret_field)
|
||||
if raw:
|
||||
override[secret_field] = _mask_secret_value(raw)
|
||||
|
||||
for section in MODEL_OVERRIDE_FIELDS:
|
||||
override = model_overrides.get(section)
|
||||
if not isinstance(override, dict):
|
||||
continue
|
||||
for secret_field in SERVICE_SECRET_FIELDS:
|
||||
raw = override.get(secret_field)
|
||||
if raw:
|
||||
override[secret_field] = _mask_secret_value(raw)
|
||||
v2_override = masked.get("model_configuration_v2_override")
|
||||
if isinstance(v2_override, dict):
|
||||
_mask_nested_service_secrets(v2_override)
|
||||
|
||||
return masked
|
||||
|
||||
|
||||
def _mask_nested_service_secrets(value):
|
||||
if isinstance(value, dict):
|
||||
for key, nested in list(value.items()):
|
||||
if key in SERVICE_SECRET_FIELDS and nested:
|
||||
value[key] = _mask_secret_value(nested)
|
||||
else:
|
||||
_mask_nested_service_secrets(nested)
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
_mask_nested_service_secrets(item)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Workflow definition helpers – mask / merge node API keys
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ stored, while honouring masked API keys.
|
|||
import copy
|
||||
from typing import Dict
|
||||
|
||||
from api.schemas.user_configuration import UserConfiguration
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.configuration.masking import (
|
||||
MODEL_OVERRIDE_FIELDS,
|
||||
SERVICE_SECRET_FIELDS,
|
||||
|
|
@ -66,9 +66,9 @@ def _merge_service_secret_fields(
|
|||
|
||||
|
||||
def merge_user_configurations(
|
||||
existing: UserConfiguration, incoming_partial: Dict[str, dict]
|
||||
) -> UserConfiguration:
|
||||
"""Merge *incoming_partial* onto *existing* and return a new UserConfiguration.
|
||||
existing: EffectiveAIModelConfiguration, incoming_partial: Dict[str, dict]
|
||||
) -> EffectiveAIModelConfiguration:
|
||||
"""Merge *incoming_partial* onto *existing* and return a new EffectiveAIModelConfiguration.
|
||||
|
||||
*incoming_partial* is the body of the PUT request (already `model_dump()`ed or
|
||||
extracted via Pydantic `model_dump`).
|
||||
|
|
@ -113,14 +113,14 @@ def merge_user_configurations(
|
|||
if "timezone" in incoming_partial:
|
||||
merged["timezone"] = incoming_partial["timezone"]
|
||||
|
||||
# Onboarding gate flags — overwrite only when supplied (set once on submit/skip).
|
||||
# Onboarding gate flags: overwrite only when supplied.
|
||||
if "onboarding_completed_at" in incoming_partial:
|
||||
merged["onboarding_completed_at"] = incoming_partial["onboarding_completed_at"]
|
||||
|
||||
if "onboarding_skipped" in incoming_partial:
|
||||
merged["onboarding_skipped"] = incoming_partial["onboarding_skipped"]
|
||||
|
||||
return UserConfiguration.model_validate(merged)
|
||||
return EffectiveAIModelConfiguration.model_validate(merged)
|
||||
|
||||
|
||||
def merge_workflow_configuration_secrets(
|
||||
|
|
|
|||
|
|
@ -911,7 +911,7 @@ class DograhTTSService(BaseTTSConfiguration):
|
|||
speed: float = Field(default=1.0, ge=0.5, le=2.0, description="Speed of the voice.")
|
||||
|
||||
|
||||
CARTESIA_TTS_MODELS = ["sonic-3"]
|
||||
CARTESIA_TTS_MODELS = ["sonic-3.5", "sonic-3"]
|
||||
|
||||
|
||||
@register_tts
|
||||
|
|
@ -919,7 +919,7 @@ class CartesiaTTSConfiguration(BaseTTSConfiguration):
|
|||
model_config = CARTESIA_PROVIDER_MODEL_CONFIG
|
||||
provider: Literal[ServiceProviders.CARTESIA] = ServiceProviders.CARTESIA
|
||||
model: str = Field(
|
||||
default="sonic-3",
|
||||
default="sonic-3.5",
|
||||
description="Cartesia TTS model.",
|
||||
json_schema_extra={"examples": CARTESIA_TTS_MODELS},
|
||||
)
|
||||
|
|
@ -1472,11 +1472,26 @@ class AzureOpenAIEmbeddingsConfiguration(BaseEmbeddingsConfiguration):
|
|||
)
|
||||
|
||||
|
||||
DOGRAH_EMBEDDING_MODELS = ["default"]
|
||||
|
||||
|
||||
@register_embeddings
class DograhEmbeddingsConfiguration(BaseEmbeddingsConfiguration):
    # Embeddings configuration for the Dograh provider.
    # NOTE: documented with comments on purpose; a class docstring would
    # become the description in the generated Pydantic schema.
    model_config = DOGRAH_PROVIDER_MODEL_CONFIG
    # Discriminator value that selects this configuration.
    provider: Literal[ServiceProviders.DOGRAH] = ServiceProviders.DOGRAH
    # Model name; DOGRAH_EMBEDDING_MODELS supplies the schema examples.
    model: str = Field(
        description="Dograh-managed embedding model.",
        json_schema_extra={"examples": DOGRAH_EMBEDDING_MODELS},
        default="default",
    )
|
||||
|
||||
|
||||
EmbeddingsConfig = Annotated[
|
||||
Union[
|
||||
OpenAIEmbeddingsConfiguration,
|
||||
OpenRouterEmbeddingsConfiguration,
|
||||
AzureOpenAIEmbeddingsConfiguration,
|
||||
DograhEmbeddingsConfiguration,
|
||||
],
|
||||
Field(discriminator="provider"),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -4,13 +4,13 @@ from __future__ import annotations
|
|||
|
||||
import copy
|
||||
|
||||
from api.schemas.user_configuration import UserConfiguration
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.configuration.registry import (
|
||||
REGISTRY,
|
||||
ServiceType,
|
||||
)
|
||||
|
||||
# Maps override key → (UserConfiguration field, ServiceType for registry lookup)
|
||||
# Maps override key → (EffectiveAIModelConfiguration field, ServiceType for registry lookup)
|
||||
_SECTION_MAP: dict[str, ServiceType] = {
|
||||
"llm": ServiceType.LLM,
|
||||
"tts": ServiceType.TTS,
|
||||
|
|
@ -36,7 +36,7 @@ _SECRET_FIELDS = ("api_key", "credentials", "aws_access_key", "aws_secret_key")
|
|||
|
||||
def enrich_overrides_with_api_keys(
|
||||
model_overrides: dict,
|
||||
user_config: UserConfiguration,
|
||||
user_config: EffectiveAIModelConfiguration,
|
||||
) -> dict:
|
||||
"""Copy API keys from the global config into model_overrides where missing.
|
||||
|
||||
|
|
@ -74,9 +74,9 @@ def enrich_overrides_with_api_keys(
|
|||
|
||||
|
||||
def resolve_effective_config(
|
||||
user_config: UserConfiguration,
|
||||
user_config: EffectiveAIModelConfiguration,
|
||||
model_overrides: dict | None,
|
||||
) -> UserConfiguration:
|
||||
) -> EffectiveAIModelConfiguration:
|
||||
"""Deep-merge workflow model_overrides onto global user config.
|
||||
|
||||
- If model_overrides is None or empty, returns a copy of user_config unchanged.
|
||||
|
|
|
|||
|
|
@ -38,6 +38,7 @@ class OpenAIEmbeddingService(BaseEmbeddingService):
|
|||
api_key: Optional[str] = None,
|
||||
model_id: str = DEFAULT_MODEL_ID,
|
||||
base_url: Optional[str] = None,
|
||||
default_headers: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
"""Initialize the OpenAI embedding service.
|
||||
|
||||
|
|
@ -60,6 +61,8 @@ class OpenAIEmbeddingService(BaseEmbeddingService):
|
|||
field_name="base_url",
|
||||
)
|
||||
client_kwargs["base_url"] = base_url
|
||||
if default_headers:
|
||||
client_kwargs["default_headers"] = default_headers
|
||||
self.client = AsyncOpenAI(**client_kwargs)
|
||||
logger.info(f"OpenAI embedding service initialized with model: {model_id}")
|
||||
else:
|
||||
|
|
|
|||
78
api/services/managed_model_services.py
Normal file
78
api/services/managed_model_services.py
Normal file
|
|
@ -0,0 +1,78 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
|
||||
MPS_CORRELATION_ID_CONTEXT_KEY = "mps_correlation_id"
|
||||
|
||||
|
||||
def uses_managed_model_services_v2(
    ai_model_config: EffectiveAIModelConfiguration | None,
) -> bool:
    """Tell whether the configuration relies on managed model services v2.

    True only when ``managed_service_version`` equals 2 and at least one of
    the ``llm``, ``tts``, ``stt`` or ``embeddings`` services uses the Dograh
    provider. A missing configuration yields False.
    """
    if ai_model_config is None:
        return False
    if getattr(ai_model_config, "managed_service_version", None) != 2:
        return False

    for section_name in ("llm", "tts", "stt", "embeddings"):
        if _is_dograh_service(getattr(ai_model_config, section_name, None)):
            return True
    return False
|
||||
|
||||
|
||||
def get_mps_correlation_id(initial_context: dict[str, Any] | None) -> str | None:
    """Read the MPS correlation id from a run's initial context.

    Returns the stored value converted to ``str``, or ``None`` when the
    context is empty/missing or holds no correlation id.
    """
    if not initial_context:
        return None
    raw_id = initial_context.get(MPS_CORRELATION_ID_CONTEXT_KEY)
    return None if raw_id is None else str(raw_id)
|
||||
|
||||
|
||||
async def ensure_mps_correlation_id(
    *,
    ai_model_config: EffectiveAIModelConfiguration,
    workflow_run_id: int,
    initial_context: dict[str, Any] | None,
) -> str | None:
    """Return the run's correlation id, failing when one is required but absent.

    Nothing is minted here: the id must already be present in
    ``initial_context``. The function is a coroutine although it awaits
    nothing itself, so callers must still ``await`` it.

    Args:
        ai_model_config: Effective AI model configuration for the run.
        workflow_run_id: Run identifier, used only in the error message.
        initial_context: Context mapping that may already carry the id.

    Returns:
        The existing correlation id, or None when the configuration does not
        use managed model services v2.

    Raises:
        ValueError: If managed model services v2 is in use and no non-empty
            correlation id is stored in ``initial_context``.
    """
    correlation_id = get_mps_correlation_id(initial_context)
    if correlation_id:
        return correlation_id

    if uses_managed_model_services_v2(ai_model_config):
        raise ValueError(
            "Managed model services v2 requires workflow run authorization before "
            f"the run starts. Missing correlation id for workflow_run_id={workflow_run_id}."
        )

    return None
|
||||
|
||||
|
||||
def _is_dograh_service(service: Any) -> bool:
    """Check whether a service section is provided by Dograh.

    The section's ``provider`` attribute is compared against both the
    ``ServiceProviders.DOGRAH`` member and its ``.value``. An object
    without a ``provider`` attribute (including None) is treated as
    provider None.
    """
    provider = getattr(service, "provider", None)
    dograh = ServiceProviders.DOGRAH
    if provider == dograh:
        return True
    return provider == dograh.value
|
||||
|
||||
|
||||
def get_dograh_service_api_key(
    ai_model_config: EffectiveAIModelConfiguration,
) -> str | None:
    """Find the first usable API key among the Dograh-provided services.

    Sections are inspected in the order ``llm``, ``tts``, ``stt``,
    ``embeddings``. For each Dograh section, the first entry returned by
    ``get_all_api_keys()`` wins when that method exists and yields a
    non-empty result; otherwise a non-empty string ``api_key`` attribute is
    used.

    Args:
        ai_model_config: Effective AI model configuration to search.

    Returns:
        The first key found, or None when no Dograh section exposes one.
    """
    for section_name in ("llm", "tts", "stt", "embeddings"):
        service = getattr(ai_model_config, section_name, None)
        if _is_dograh_service(service):
            if hasattr(service, "get_all_api_keys"):
                candidate_keys = service.get_all_api_keys()
                if candidate_keys:
                    return candidate_keys[0]

            single_key = getattr(service, "api_key", None)
            if isinstance(single_key, str) and single_key:
                return single_key

    return None
|
||||
23
api/services/mps_billing.py
Normal file
23
api/services/mps_billing.py
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
from typing import Optional
|
||||
|
||||
from api.constants import DEPLOYMENT_MODE
|
||||
from api.services.mps_service_key_client import mps_service_key_client
|
||||
|
||||
|
||||
async def ensure_hosted_mps_billing_account_v2(
    organization_id: int,
    *,
    created_by: Optional[str] = None,
) -> Optional[dict]:
    """Ensure hosted orgs have an MPS billing v2 account.

    OSS deployments use legacy per-key quota accounting and do not create MPS
    billing accounts.

    Args:
        organization_id: Organization to ensure an account for.
        created_by: Optional actor identifier forwarded to the MPS client.

    Returns:
        None when ``DEPLOYMENT_MODE`` is ``"oss"``; otherwise whatever
        ``mps_service_key_client.ensure_billing_account_v2`` returns.
    """
    if DEPLOYMENT_MODE != "oss":
        return await mps_service_key_client.ensure_billing_account_v2(
            organization_id=organization_id,
            created_by=created_by,
        )
    return None
|
||||
|
|
@ -4,6 +4,7 @@ This client communicates with the Model Proxy Service (MPS) for service key mana
|
|||
Service keys are stored and managed entirely in MPS, not in the local database.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import List, Optional
|
||||
|
||||
import httpx
|
||||
|
|
@ -353,6 +354,278 @@ class MPSServiceKeyClient:
|
|||
response=response,
|
||||
)
|
||||
|
||||
async def create_credit_purchase_url(
    self,
    organization_id: int,
    created_by: Optional[str] = None,
    return_url: Optional[str] = None,
    billing_details: Optional[dict] = None,
) -> dict:
    """Create a short-lived MPS checkout URL for adding organization credits.

    Args:
        organization_id: Organization whose billing account is targeted.
        created_by: Optional actor identifier; sent in the payload and passed
            to the header builder.
        return_url: Optional URL forwarded to MPS in the payload.
        billing_details: Optional billing details; an empty dict is sent when
            omitted or falsy.

    Returns:
        The decoded JSON body of the MPS response.

    Raises:
        httpx.HTTPStatusError: If MPS answers with any status other than 200.
    """
    payload = {
        "created_by": created_by,
        "return_url": return_url,
        "billing_details": billing_details or {},
    }

    async with httpx.AsyncClient(timeout=self.timeout) as http:
        endpoint = (
            f"{self.base_url}/api/v1/billing/accounts/{organization_id}/checkout-sessions"
        )
        request_headers = self._get_headers(
            organization_id=organization_id,
            created_by=created_by,
        )
        response = await http.post(endpoint, json=payload, headers=request_headers)

        if response.status_code != 200:
            logger.error(
                "Failed to create MPS credit purchase URL: "
                f"{response.status_code} - {response.text}"
            )
            raise httpx.HTTPStatusError(
                f"Failed to create MPS credit purchase URL: {response.text}",
                request=response.request,
                response=response,
            )

        return response.json()
|
||||
|
||||
async def get_credit_ledger(
    self,
    organization_id: int,
    page: int = 1,
    limit: int = 50,
    created_by: Optional[str] = None,
) -> dict:
    """Get the MPS v2 billing account balance and recent credit ledger.

    Args:
        organization_id: Organization whose ledger is requested.
        page: Page number sent as the ``page`` query parameter.
        limit: Page size sent as the ``limit`` query parameter.
        created_by: Optional actor identifier passed to the header builder.

    Returns:
        The decoded JSON body of the MPS response.

    Raises:
        httpx.HTTPStatusError: If MPS answers with any status other than 200.
    """
    async with httpx.AsyncClient(timeout=self.timeout) as http:
        endpoint = f"{self.base_url}/api/v1/billing/accounts/{organization_id}/ledger"
        query = {"page": page, "limit": limit}
        request_headers = self._get_headers(
            organization_id=organization_id,
            created_by=created_by,
        )
        response = await http.get(endpoint, params=query, headers=request_headers)

        if response.status_code != 200:
            logger.error(
                "Failed to get MPS credit ledger: "
                f"{response.status_code} - {response.text}"
            )
            raise httpx.HTTPStatusError(
                f"Failed to get MPS credit ledger: {response.text}",
                request=response.request,
                response=response,
            )

        return response.json()
|
||||
|
||||
async def get_billing_account_status(
    self,
    organization_id: int,
    created_by: Optional[str] = None,
) -> Optional[dict]:
    """Get an existing MPS v2 billing account without creating one.

    Args:
        organization_id: Organization whose account status is requested.
        created_by: Optional actor identifier passed to the header builder.

    Returns:
        The decoded JSON body of the MPS response. Although the annotation is
        ``Optional[dict]``, this implementation never returns None itself:
        every non-200 status raises.

    Raises:
        httpx.HTTPStatusError: If MPS answers with any status other than 200.
    """
    async with httpx.AsyncClient(timeout=self.timeout) as http:
        endpoint = f"{self.base_url}/api/v1/billing/accounts/{organization_id}/status"
        request_headers = self._get_headers(
            organization_id=organization_id,
            created_by=created_by,
        )
        response = await http.get(endpoint, headers=request_headers)

        if response.status_code != 200:
            logger.error(
                "Failed to get MPS billing account status: "
                f"{response.status_code} - {response.text}"
            )
            raise httpx.HTTPStatusError(
                f"Failed to get MPS billing account status: {response.text}",
                request=response.request,
                response=response,
            )

        return response.json()
|
||||
|
||||
async def ensure_billing_account_v2(
    self,
    organization_id: int,
    created_by: Optional[str] = None,
) -> dict:
    """Create or return the MPS v2 billing account for an organization.

    The client only issues a GET against the account's ``balance`` endpoint;
    any account creation would happen on the MPS side and is not visible
    from this code.

    Args:
        organization_id: Organization whose billing account is ensured.
        created_by: Optional actor identifier passed to the header builder.

    Returns:
        The decoded JSON body of the MPS response.

    Raises:
        httpx.HTTPStatusError: If MPS answers with any status other than 200.
    """
    async with httpx.AsyncClient(timeout=self.timeout) as http:
        endpoint = f"{self.base_url}/api/v1/billing/accounts/{organization_id}/balance"
        request_headers = self._get_headers(
            organization_id=organization_id,
            created_by=created_by,
        )
        response = await http.get(endpoint, headers=request_headers)

        if response.status_code != 200:
            logger.error(
                "Failed to ensure MPS billing account v2: "
                f"{response.status_code} - {response.text}"
            )
            raise httpx.HTTPStatusError(
                f"Failed to ensure MPS billing account v2: {response.text}",
                request=response.request,
                response=response,
            )

        return response.json()
|
||||
|
||||
async def authorize_workflow_run_start(
    self,
    *,
    organization_id: int,
    workflow_run_id: int | None = None,
    service_key: Optional[str] = None,
    require_correlation_id: bool = False,
    minimum_credits: float | None = None,
    metadata: Optional[dict] = None,
    created_by: Optional[str] = None,
) -> dict:
    """Authorize a hosted workflow run and optionally mint its MPS correlation.

    Args:
        organization_id: Organization whose billing account authorizes the run.
        workflow_run_id: Optional run identifier sent in the payload.
        service_key: Optional service key sent in the payload.
        require_correlation_id: Flag forwarded to MPS in the payload.
        minimum_credits: Included in the payload only when not None.
        metadata: Optional metadata; an empty dict is sent when omitted.
        created_by: Optional actor identifier passed to the header builder.

    Returns:
        The decoded JSON body of the MPS response.

    Raises:
        httpx.HTTPStatusError: If MPS answers with any status other than 200.
    """
    payload = {
        "workflow_run_id": workflow_run_id,
        "service_key": service_key,
        "require_correlation_id": require_correlation_id,
        "metadata": metadata or {},
    }
    # Omit the key entirely when no minimum was requested.
    if minimum_credits is not None:
        payload["minimum_credits"] = minimum_credits

    async with httpx.AsyncClient(timeout=self.timeout) as http:
        endpoint = (
            f"{self.base_url}/api/v1/billing/accounts/{organization_id}/run-authorization"
        )
        request_headers = self._get_headers(
            organization_id=organization_id,
            created_by=created_by,
        )
        response = await http.post(endpoint, json=payload, headers=request_headers)

        if response.status_code != 200:
            logger.error(
                "Failed to authorize MPS workflow run start: "
                f"{response.status_code} - {response.text}"
            )
            raise httpx.HTTPStatusError(
                f"Failed to authorize MPS workflow run start: {response.text}",
                request=response.request,
                response=response,
            )

        return response.json()
|
||||
|
||||
async def create_correlation_id(
    self,
    *,
    service_key: str,
    workflow_run_id: int | None = None,
) -> dict:
    """Mint a server-generated correlation ID for managed model services.

    Unlike the billing calls, this request authenticates with the service
    key itself as a bearer token rather than the shared header builder.

    Args:
        service_key: Service key used as the bearer token.
        workflow_run_id: Included in the payload only when not None.

    Returns:
        The decoded JSON body of the MPS response.

    Raises:
        httpx.HTTPStatusError: If MPS answers with any status other than 200.
    """
    payload: dict[str, int] = {}
    if workflow_run_id is not None:
        payload["workflow_run_id"] = workflow_run_id

    async with httpx.AsyncClient(timeout=self.timeout) as http:
        endpoint = f"{self.base_url}/api/v1/service-keys/correlation-id/self"
        request_headers = {
            "Authorization": f"Bearer {service_key}",
            "Content-Type": "application/json",
        }
        response = await http.post(endpoint, json=payload, headers=request_headers)

        if response.status_code != 200:
            logger.error(
                "Failed to create correlation ID: "
                f"{response.status_code} - {response.text}"
            )
            raise httpx.HTTPStatusError(
                f"Failed to create correlation ID: {response.text}",
                request=response.request,
                response=response,
            )

        return response.json()
|
||||
|
||||
async def report_platform_usage(
    self,
    *,
    organization_id: int,
    correlation_id: Optional[str] = None,
    duration_seconds: Optional[float] = None,
    workflow_run_id: int | None = None,
    metadata: Optional[dict] = None,
    max_attempts: int = 3,
) -> dict:
    """Report hosted Dograh platform usage for a completed workflow run.

    The request is retried only when MPS answers 409 with ``usage_not_ready``
    in the body, sleeping ``attempt`` seconds (1, 2, ...) between tries. Any
    other failure, or a retryable failure on the last attempt, raises.

    Args:
        organization_id: Organization whose billing account is charged.
        correlation_id: Optional MPS correlation id for the run.
        duration_seconds: Optional run duration; required when no
            correlation id is given.
        workflow_run_id: Included in the payload only when not None.
        metadata: Optional metadata; an empty dict is sent when omitted.
        max_attempts: Upper bound on POST attempts; values below 1 are
            treated as 1.

    Returns:
        The decoded JSON body of the first 200 response.

    Raises:
        ValueError: In OSS deployments, or when neither ``correlation_id``
            nor ``duration_seconds`` is supplied.
        httpx.HTTPStatusError: If no attempt ends with a 200 response.
    """
    if DEPLOYMENT_MODE == "oss":
        raise ValueError("OSS deployments must not report platform usage to MPS")
    if not correlation_id and duration_seconds is None:
        raise ValueError(
            "Platform usage reports require correlation_id or duration_seconds"
        )

    payload: dict = {"metadata": metadata or {}}
    if correlation_id:
        payload["correlation_id"] = correlation_id
    if duration_seconds is not None:
        payload["duration_seconds"] = duration_seconds
    if workflow_run_id is not None:
        payload["workflow_run_id"] = workflow_run_id

    max_attempts = max(1, max_attempts)
    last_response: httpx.Response | None = None
    endpoint = (
        f"{self.base_url}/api/v1/billing/accounts/"
        f"{organization_id}/platform-usage"
    )
    async with httpx.AsyncClient(timeout=self.timeout) as http:
        for attempt in range(1, max_attempts + 1):
            response = await http.post(
                endpoint,
                json=payload,
                headers=self._get_headers(organization_id=organization_id),
            )
            last_response = response

            if response.status_code == 200:
                return response.json()

            # MPS signals "usage not aggregated yet" with a 409; only that
            # case is worth waiting for, and only while attempts remain.
            usage_pending = (
                response.status_code == 409 and "usage_not_ready" in response.text
            )
            if usage_pending and attempt < max_attempts:
                await asyncio.sleep(attempt)
                continue

            logger.error(
                "Failed to report platform usage: "
                f"{response.status_code} - {response.text}"
            )
            raise httpx.HTTPStatusError(
                f"Failed to report platform usage: {response.text}",
                request=response.request,
                response=response,
            )

    # Every loop iteration returns, retries or raises, so this is only an
    # explicit fall-through for the end of the function.
    raise httpx.HTTPStatusError(
        "Failed to report platform usage",
        request=last_response.request,
        response=last_response,
    )
|
||||
|
||||
async def transcribe_audio(
|
||||
self,
|
||||
audio_data: bytes,
|
||||
|
|
|
|||
50
api/services/organization_context.py
Normal file
50
api/services/organization_context.py
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from api.db import db_client
|
||||
from api.db.models import UserModel
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
get_resolved_ai_model_configuration,
|
||||
)
|
||||
|
||||
|
||||
class OrganizationModelServicesContext(BaseModel):
    # Summary of how an organization's model services are configured.
    # Documented with comments rather than a class docstring so the model's
    # generated schema stays untouched.

    # Which configuration layer supplied the settings.
    config_source: Literal["organization_v2", "legacy_user_v1", "empty"]
    # Whether a v2 model configuration is present.
    has_model_configuration_v2: bool
    # Managed service version of the configuration, when one is set.
    managed_service_version: Optional[int] = None
    # Whether managed service v2 is in use.
    uses_managed_service_v2: bool
|
||||
|
||||
|
||||
class OrganizationContextResponse(BaseModel):
    # Organization context payload: the organization plus its model-services
    # summary. Documented with comments rather than a class docstring so the
    # model's generated schema stays untouched.

    # Internal organization id, when one is set.
    organization_id: Optional[int] = None
    # Provider-side identifier of the organization, when known.
    organization_provider_id: Optional[str] = None
    # Model-services summary for this organization.
    model_services: OrganizationModelServicesContext
|
||||
|
||||
|
||||
async def get_organization_context(user: UserModel) -> OrganizationContextResponse:
    """Build the organization context for a user's selected organization.

    Args:
        user: The user whose ``selected_organization_id`` and ``id`` drive
            the lookup.

    Returns:
        A response carrying the organization id, its provider id (None when
        no organization is selected or found) and a summary of the resolved
        AI model configuration.
    """
    organization_id = user.selected_organization_id

    # Skip the database lookup entirely when no organization is selected.
    organization = None
    if organization_id:
        organization = await db_client.get_organization_by_id(organization_id)

    resolved = await get_resolved_ai_model_configuration(
        user_id=user.id,
        organization_id=organization_id,
    )
    managed_service_version = resolved.effective.managed_service_version

    provider_id = organization.provider_id if organization else None
    model_services = OrganizationModelServicesContext(
        config_source=resolved.source,
        has_model_configuration_v2=resolved.source == "organization_v2",
        managed_service_version=managed_service_version,
        # Managed v2 only counts when it comes from an organization-level config.
        uses_managed_service_v2=(
            resolved.source == "organization_v2" and managed_service_version == 2
        ),
    )

    return OrganizationContextResponse(
        organization_id=organization_id,
        organization_provider_id=provider_id,
        model_services=model_services,
    )
|
||||
62
api/services/organization_preferences.py
Normal file
62
api/services/organization_preferences.py
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
from inspect import isawaitable
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import ValidationError
|
||||
|
||||
from api.db import db_client
|
||||
from api.enums import OrganizationConfigurationKey
|
||||
from api.schemas.organization_preferences import OrganizationPreferences
|
||||
|
||||
|
||||
async def get_organization_preferences(
    organization_id: int | None,
    db=None,
) -> OrganizationPreferences:
    """Load an organization's preferences, falling back to defaults.

    The ``ORGANIZATION_PREFERENCES`` configuration row is tried first; when
    it is absent, the ``MODEL_CONFIGURATION_PREFERENCES`` row is used instead.

    Args:
        organization_id: Organization to load preferences for; None yields
            default preferences without touching the database.
        db: Optional database client override; the module-level ``db_client``
            is used when omitted or falsy.

    Returns:
        The parsed preferences, or defaults when nothing valid is stored.
    """
    if organization_id is None:
        return OrganizationPreferences()

    db = db or db_client

    row = None
    for config_key in (
        OrganizationConfigurationKey.ORGANIZATION_PREFERENCES,
        OrganizationConfigurationKey.MODEL_CONFIGURATION_PREFERENCES,
    ):
        row = await _get_configuration(db, organization_id, config_key.value)
        if row is not None:
            break

    stored_value = None if row is None else row.value
    return _parse_preferences(stored_value, organization_id)
|
||||
|
||||
|
||||
async def upsert_organization_preferences(
    organization_id: int,
    preferences: OrganizationPreferences,
) -> OrganizationPreferences:
    """Persist an organization's preferences and hand them back.

    Args:
        organization_id: Organization the preferences belong to.
        preferences: Preferences to store; serialized in JSON mode with
            None-valued fields left out.

    Returns:
        The same ``preferences`` object that was passed in.
    """
    serialized = preferences.model_dump(mode="json", exclude_none=True)
    await db_client.upsert_configuration(
        organization_id,
        OrganizationConfigurationKey.ORGANIZATION_PREFERENCES.value,
        serialized,
    )
    return preferences
|
||||
|
||||
|
||||
async def _get_configuration(db, organization_id: int, key: str):
|
||||
row = db.get_configuration(organization_id, key)
|
||||
if isawaitable(row):
|
||||
row = await row
|
||||
return row
|
||||
|
||||
|
||||
def _parse_preferences(value, organization_id: int) -> OrganizationPreferences:
    """Turn a stored configuration value into ``OrganizationPreferences``.

    Args:
        value: Raw stored value; anything falsy or not a dict yields defaults.
        organization_id: Used only in the warning logged for invalid data.

    Returns:
        The validated preferences, or default preferences when the value is
        missing, not a dict, or fails validation (a warning is logged in the
        last case).
    """
    if not (value and isinstance(value, dict)):
        return OrganizationPreferences()

    try:
        parsed = OrganizationPreferences.model_validate(value)
    except ValidationError as exc:
        logger.warning(
            "Invalid organization preferences for organization "
            f"{organization_id}: {exc}. Returning defaults."
        )
        return OrganizationPreferences()
    return parsed
|
||||
|
|
@ -15,6 +15,29 @@ from api.utils.credential_auth import build_auth_header
|
|||
PRE_CALL_FETCH_TIMEOUT_SECONDS = 10
|
||||
|
||||
|
||||
def _extract_initial_context(response_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Pull the context variables out of a pre-call fetch response.
|
||||
|
||||
The canonical key is ``initial_context``. The legacy ``dynamic_variables``
|
||||
key is still accepted for backward compatibility, so existing endpoints
|
||||
keep working; ``initial_context`` takes precedence when both are present.
|
||||
|
||||
Either key may appear at the top level or nested under ``call_inbound``:
|
||||
{"call_inbound": {"initial_context": {...}}} | {"initial_context": {...}}
|
||||
{"call_inbound": {"dynamic_variables": {...}}} | {"dynamic_variables": {...}}
|
||||
"""
|
||||
container = response_data.get("call_inbound")
|
||||
if not isinstance(container, dict):
|
||||
container = response_data
|
||||
|
||||
for key in ("initial_context", "dynamic_variables"):
|
||||
value = container.get(key)
|
||||
if isinstance(value, dict):
|
||||
return value
|
||||
|
||||
return {}
|
||||
|
||||
|
||||
async def execute_pre_call_fetch(
|
||||
*,
|
||||
url: str,
|
||||
|
|
@ -77,24 +100,16 @@ async def execute_pre_call_fetch(
|
|||
)
|
||||
return {}
|
||||
|
||||
# Extract dynamic_variables from Retell-compatible response
|
||||
# Supports: {call_inbound: {dynamic_variables: {...}}}
|
||||
# or: {dynamic_variables: {...}}
|
||||
dynamic_vars = {}
|
||||
call_inbound = response_data.get("call_inbound")
|
||||
if isinstance(call_inbound, dict):
|
||||
dynamic_vars = call_inbound.get("dynamic_variables", {})
|
||||
elif "dynamic_variables" in response_data:
|
||||
dynamic_vars = response_data["dynamic_variables"]
|
||||
|
||||
if not isinstance(dynamic_vars, dict):
|
||||
dynamic_vars = {}
|
||||
# Extract the variables to merge into initial_context. Prefers
|
||||
# the canonical `initial_context` key, falling back to the
|
||||
# legacy `dynamic_variables` key for backward compatibility.
|
||||
initial_context_vars = _extract_initial_context(response_data)
|
||||
|
||||
logger.info(
|
||||
f"Pre-call fetch: success ({response.status_code}), "
|
||||
f"dynamic_variables keys: {list(dynamic_vars.keys())}"
|
||||
f"initial_context keys: {list(initial_context_vars.keys())}"
|
||||
)
|
||||
return dynamic_vars
|
||||
return initial_context_vars
|
||||
else:
|
||||
logger.warning(
|
||||
f"Pre-call fetch: HTTP {response.status_code} - "
|
||||
|
|
|
|||
|
|
@ -162,15 +162,13 @@ async def run_pipeline_telephony(
|
|||
workflow_id: Workflow being executed.
|
||||
workflow_run_id: Workflow run row.
|
||||
user_id: Owner of the workflow.
|
||||
call_id: Provider call identifier (stored in cost_info for billing).
|
||||
call_id: Provider call identifier.
|
||||
transport_kwargs: Provider-specific kwargs forwarded to the transport
|
||||
factory (e.g. stream_sid + call_sid for Twilio).
|
||||
"""
|
||||
logger.debug(f"Running {provider_name} pipeline for workflow_run {workflow_run_id}")
|
||||
set_current_run_id(workflow_run_id)
|
||||
|
||||
await db_client.update_workflow_run(workflow_run_id, cost_info={"call_id": call_id})
|
||||
|
||||
workflow = await db_client.get_workflow(workflow_id, user_id)
|
||||
if workflow:
|
||||
set_current_org_id(workflow.organization_id)
|
||||
|
|
@ -195,14 +193,17 @@ async def run_pipeline_telephony(
|
|||
# Resolve effective user config here so the transport can tune its
|
||||
# bot-stopped-speaking fallback based on is_realtime; pass the resolved
|
||||
# values into _run_pipeline so it doesn't fetch them again.
|
||||
from api.services.configuration.resolve import resolve_effective_config
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
get_effective_ai_model_configuration_for_workflow,
|
||||
)
|
||||
|
||||
user_config = await db_client.get_user_configurations(user_id)
|
||||
run_configs = (
|
||||
(workflow_run.definition.workflow_configurations or {}) if workflow_run else {}
|
||||
)
|
||||
user_config = resolve_effective_config(
|
||||
user_config, run_configs.get("model_overrides")
|
||||
user_config = await get_effective_ai_model_configuration_for_workflow(
|
||||
user_id=user_id,
|
||||
organization_id=workflow.organization_id if workflow else None,
|
||||
workflow_configurations=run_configs,
|
||||
)
|
||||
is_realtime = bool(user_config.is_realtime and user_config.realtime is not None)
|
||||
|
||||
|
|
@ -272,15 +273,18 @@ async def run_pipeline_smallwebrtc(
|
|||
# Resolve workflow_run + effective user_config here so the transport can
|
||||
# tune its bot-stopped-speaking fallback based on is_realtime. _run_pipeline
|
||||
# reuses these via kwargs so we don't fetch twice.
|
||||
from api.services.configuration.resolve import resolve_effective_config
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
get_effective_ai_model_configuration_for_workflow,
|
||||
)
|
||||
|
||||
workflow_run = await db_client.get_workflow_run(workflow_run_id, user_id)
|
||||
user_config = await db_client.get_user_configurations(user_id)
|
||||
run_configs = (
|
||||
(workflow_run.definition.workflow_configurations or {}) if workflow_run else {}
|
||||
)
|
||||
user_config = resolve_effective_config(
|
||||
user_config, run_configs.get("model_overrides")
|
||||
user_config = await get_effective_ai_model_configuration_for_workflow(
|
||||
user_id=user_id,
|
||||
organization_id=workflow.organization_id if workflow else None,
|
||||
workflow_configurations=run_configs,
|
||||
)
|
||||
is_realtime = bool(user_config.is_realtime and user_config.realtime is not None)
|
||||
|
||||
|
|
@ -334,7 +338,7 @@ async def _run_pipeline(
|
|||
if workflow_run.is_completed:
|
||||
raise HTTPException(status_code=400, detail="Workflow run already completed")
|
||||
|
||||
merged_call_context_vars = workflow_run.initial_context
|
||||
merged_call_context_vars = dict(workflow_run.initial_context or {})
|
||||
# If there is some extra call_context_vars, fold them in. Persistence
|
||||
# happens once below, after runtime_configuration is also resolved.
|
||||
if call_context_vars:
|
||||
|
|
@ -380,15 +384,31 @@ async def _run_pipeline(
|
|||
# Resolve model overrides from the version onto global user config (skip
|
||||
# when the caller already resolved it).
|
||||
if resolved_user_config is None:
|
||||
from api.services.configuration.resolve import resolve_effective_config
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
get_effective_ai_model_configuration_for_workflow,
|
||||
)
|
||||
|
||||
user_config = await db_client.get_user_configurations(user_id)
|
||||
user_config = resolve_effective_config(
|
||||
user_config, run_configs.get("model_overrides")
|
||||
user_config = await get_effective_ai_model_configuration_for_workflow(
|
||||
user_id=user_id,
|
||||
organization_id=workflow.organization_id,
|
||||
workflow_configurations=run_configs,
|
||||
)
|
||||
else:
|
||||
user_config = resolved_user_config
|
||||
|
||||
from api.services.managed_model_services import (
|
||||
MPS_CORRELATION_ID_CONTEXT_KEY,
|
||||
ensure_mps_correlation_id,
|
||||
)
|
||||
|
||||
mps_correlation_id = await ensure_mps_correlation_id(
|
||||
ai_model_config=user_config,
|
||||
workflow_run_id=workflow_run_id,
|
||||
initial_context=merged_call_context_vars,
|
||||
)
|
||||
if mps_correlation_id:
|
||||
merged_call_context_vars[MPS_CORRELATION_ID_CONTEXT_KEY] = mps_correlation_id
|
||||
|
||||
# Detect realtime mode (speech-to-speech services like OpenAI Realtime, Gemini Live)
|
||||
is_realtime = user_config.is_realtime and user_config.realtime is not None
|
||||
|
||||
|
|
@ -400,11 +420,23 @@ async def _run_pipeline(
|
|||
# Realtime services don't implement run_inference, so create a
|
||||
# separate text LLM for variable extraction and other out-of-band
|
||||
# inference calls.
|
||||
inference_llm = create_llm_service(user_config)
|
||||
inference_llm = create_llm_service(
|
||||
user_config,
|
||||
correlation_id=mps_correlation_id,
|
||||
)
|
||||
else:
|
||||
stt = create_stt_service(user_config, audio_config, keyterms=keyterms)
|
||||
tts = create_tts_service(user_config, audio_config)
|
||||
llm = create_llm_service(user_config)
|
||||
stt = create_stt_service(
|
||||
user_config,
|
||||
audio_config,
|
||||
keyterms=keyterms,
|
||||
correlation_id=mps_correlation_id,
|
||||
)
|
||||
tts = create_tts_service(
|
||||
user_config,
|
||||
audio_config,
|
||||
correlation_id=mps_correlation_id,
|
||||
)
|
||||
llm = create_llm_service(user_config, correlation_id=mps_correlation_id)
|
||||
inference_llm = None
|
||||
|
||||
# Stamp the providers/models actually resolved for this run onto
|
||||
|
|
@ -508,10 +540,17 @@ async def _run_pipeline(
|
|||
embeddings_endpoint = None
|
||||
embeddings_api_version = None
|
||||
if user_config and user_config.embeddings:
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
apply_managed_embeddings_base_url,
|
||||
)
|
||||
|
||||
embeddings_api_key = user_config.embeddings.api_key
|
||||
embeddings_model = user_config.embeddings.model
|
||||
embeddings_provider = getattr(user_config.embeddings, "provider", None)
|
||||
embeddings_base_url = getattr(user_config.embeddings, "base_url", None)
|
||||
embeddings_base_url = apply_managed_embeddings_base_url(
|
||||
provider=embeddings_provider,
|
||||
base_url=getattr(user_config.embeddings, "base_url", None),
|
||||
)
|
||||
embeddings_endpoint = getattr(user_config.embeddings, "endpoint", None)
|
||||
embeddings_api_version = getattr(user_config.embeddings, "api_version", None)
|
||||
|
||||
|
|
@ -679,7 +718,10 @@ async def _run_pipeline(
|
|||
# Create a separate LLM instance for the voicemail sub-pipeline
|
||||
# (can't share with main pipeline as it would mess up frame linking)
|
||||
if voicemail_config.get("use_workflow_llm", True):
|
||||
voicemail_llm = create_llm_service(user_config)
|
||||
voicemail_llm = create_llm_service(
|
||||
user_config,
|
||||
correlation_id=mps_correlation_id,
|
||||
)
|
||||
else:
|
||||
voicemail_llm = create_llm_service_from_provider(
|
||||
provider=voicemail_config.get("provider", "openai"),
|
||||
|
|
|
|||
|
|
@ -78,7 +78,10 @@ def _validate_runtime_service_url(url: str, field_name: str) -> None:
|
|||
|
||||
|
||||
def create_stt_service(
|
||||
user_config, audio_config: "AudioConfig", keyterms: list[str] | None = None
|
||||
user_config,
|
||||
audio_config: "AudioConfig",
|
||||
keyterms: list[str] | None = None,
|
||||
correlation_id: str | None = None,
|
||||
):
|
||||
"""Create and return appropriate STT service based on user configuration
|
||||
|
||||
|
|
@ -160,6 +163,7 @@ def create_stt_service(
|
|||
return DograhSTTService(
|
||||
base_url=base_url,
|
||||
api_key=user_config.stt.api_key,
|
||||
correlation_id=correlation_id,
|
||||
settings=DograhSTTSettings(
|
||||
model=user_config.stt.model,
|
||||
language=language,
|
||||
|
|
@ -286,7 +290,9 @@ def create_stt_service(
|
|||
)
|
||||
|
||||
|
||||
def create_tts_service(user_config, audio_config: "AudioConfig"):
|
||||
def create_tts_service(
|
||||
user_config, audio_config: "AudioConfig", correlation_id: str | None = None
|
||||
):
|
||||
"""Create and return appropriate TTS service based on user configuration
|
||||
|
||||
Args:
|
||||
|
|
@ -404,6 +410,7 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
|
|||
return DograhTTSService(
|
||||
base_url=base_url,
|
||||
api_key=user_config.tts.api_key,
|
||||
correlation_id=correlation_id,
|
||||
settings=DograhTTSSettings(
|
||||
model=user_config.tts.model,
|
||||
voice=user_config.tts.voice,
|
||||
|
|
@ -564,6 +571,7 @@ def create_llm_service_from_provider(
|
|||
model: str,
|
||||
api_key: str | None,
|
||||
*,
|
||||
correlation_id: str | None = None,
|
||||
base_url: str | None = None,
|
||||
endpoint: str | None = None,
|
||||
aws_access_key: str | None = None,
|
||||
|
|
@ -637,6 +645,7 @@ def create_llm_service_from_provider(
|
|||
return DograhLLMService(
|
||||
base_url=f"{MPS_API_URL}/api/v1/llm",
|
||||
api_key=api_key,
|
||||
correlation_id=correlation_id,
|
||||
settings=OpenAILLMSettings(model=model),
|
||||
)
|
||||
elif provider == ServiceProviders.AWS_BEDROCK.value:
|
||||
|
|
@ -851,7 +860,7 @@ def create_realtime_llm_service(user_config, audio_config: "AudioConfig"):
|
|||
)
|
||||
|
||||
|
||||
def create_llm_service(user_config):
|
||||
def create_llm_service(user_config, correlation_id: str | None = None):
|
||||
"""Create and return appropriate LLM service based on user configuration."""
|
||||
provider = user_config.llm.provider
|
||||
model = user_config.llm.model
|
||||
|
|
@ -880,4 +889,10 @@ def create_llm_service(user_config):
|
|||
elif provider == ServiceProviders.SARVAM.value:
|
||||
kwargs["temperature"] = user_config.llm.temperature
|
||||
|
||||
return create_llm_service_from_provider(provider, model, api_key, **kwargs)
|
||||
return create_llm_service_from_provider(
|
||||
provider,
|
||||
model,
|
||||
api_key,
|
||||
correlation_id=correlation_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,76 +0,0 @@
|
|||
# Pricing Module
|
||||
|
||||
This module contains pricing models and registries for different AI services used in workflow cost calculations.
|
||||
|
||||
## Structure
|
||||
|
||||
```
|
||||
pricing/
|
||||
├── __init__.py # Main module exports
|
||||
├── models.py # Base pricing model classes
|
||||
├── llm.py # LLM pricing configurations
|
||||
├── tts.py # TTS pricing configurations
|
||||
├── stt.py # STT pricing configurations
|
||||
├── registry.py # Combined pricing registry
|
||||
└── README.md # This file
|
||||
```
|
||||
|
||||
## Pricing Models
|
||||
|
||||
### TokenPricingModel
|
||||
Used for LLM services that charge based on tokens:
|
||||
- `prompt_token_price`: Cost per prompt token
|
||||
- `completion_token_price`: Cost per completion token
|
||||
- `cache_read_discount`: Discount for cache read tokens (default 50%)
|
||||
- `cache_creation_multiplier`: Premium for cache creation tokens (default 25%)
|
||||
|
||||
### CharacterPricingModel
|
||||
Used for TTS services that charge based on character count:
|
||||
- `character_price`: Cost per character
|
||||
|
||||
### TimePricingModel
|
||||
Used for STT services that charge based on time:
|
||||
- `second_price`: Cost per second
|
||||
|
||||
## Adding New Pricing
|
||||
|
||||
### Adding a New LLM Model
|
||||
Edit `llm.py` and add the model to the appropriate provider:
|
||||
|
||||
```python
|
||||
ServiceProviders.OPENAI: {
|
||||
"new-model": TokenPricingModel(
|
||||
prompt_token_price=Decimal("2.00") / 1000000,
|
||||
completion_token_price=Decimal("8.00") / 1000000,
|
||||
),
|
||||
# ... existing models
|
||||
}
|
||||
```
|
||||
|
||||
### Adding a New Provider
|
||||
1. Add pricing configurations to the appropriate service file (llm.py, tts.py, stt.py)
|
||||
2. The registry will automatically include them
|
||||
|
||||
### Adding a New Service Type
|
||||
1. Create a new pricing file (e.g., `image.py`)
|
||||
2. Define the pricing models
|
||||
3. Import and add to `registry.py`
|
||||
|
||||
## Usage
|
||||
|
||||
The pricing registry is automatically imported and used by the cost calculator:
|
||||
|
||||
```python
|
||||
from api.services.pricing import PRICING_REGISTRY
|
||||
from api.services.workflow.cost_calculator import cost_calculator
|
||||
|
||||
# The cost calculator uses the pricing registry automatically
|
||||
result = cost_calculator.calculate_total_cost(usage_info)
|
||||
```
|
||||
|
||||
## Maintenance
|
||||
|
||||
- Update pricing when providers change their rates
|
||||
- All prices should use `Decimal` for precision
|
||||
- Include comments with current pricing from provider documentation
|
||||
- Test changes with existing test suite
|
||||
|
|
@ -1,9 +0,0 @@
|
|||
"""
|
||||
Pricing module for workflow cost calculation.
|
||||
|
||||
This module contains pricing models and registries for different AI services.
|
||||
"""
|
||||
|
||||
from .registry import PRICING_REGISTRY
|
||||
|
||||
__all__ = ["PRICING_REGISTRY"]
|
||||
|
|
@ -1,228 +0,0 @@
|
|||
"""
|
||||
Cost Calculator for Workflow Runs
|
||||
|
||||
This module provides a comprehensive cost calculation system for workflow runs based on usage metrics
|
||||
from different AI service providers (OpenAI, Groq, Deepgram, etc.).
|
||||
|
||||
Features:
|
||||
- Token-based pricing for LLM services with cache optimization support
|
||||
- Character-based pricing for TTS services
|
||||
- Time-based pricing for STT services
|
||||
- Configurable pricing models that can be updated
|
||||
- Support for multiple providers and models
|
||||
- Automatic provider inference from model names
|
||||
- JSON serialization support for database storage
|
||||
|
||||
Usage:
|
||||
from api.services.pricing.cost_calculator import cost_calculator
|
||||
|
||||
usage_info = {
|
||||
"llm": {
|
||||
"processor_name|||gpt-4o": {
|
||||
"prompt_tokens": 1000,
|
||||
"completion_tokens": 500,
|
||||
"total_tokens": 1500,
|
||||
"cache_read_input_tokens": 0,
|
||||
"cache_creation_input_tokens": 0
|
||||
}
|
||||
},
|
||||
"tts": {
|
||||
"processor_name|||aura-2-helena-en": 2000 # character count
|
||||
}
|
||||
}
|
||||
|
||||
cost_breakdown = cost_calculator.calculate_total_cost(usage_info)
|
||||
print(f"Total cost: ${cost_breakdown['total']:.6f}")
|
||||
"""
|
||||
|
||||
from decimal import Decimal
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from api.services.pricing import PRICING_REGISTRY
|
||||
from api.services.pricing.models import (
|
||||
PricingModel,
|
||||
)
|
||||
|
||||
|
||||
class CostCalculator:
    """Compute LLM / TTS / STT costs for a workflow run from a pricing registry.

    The registry is a nested mapping
    ``service_type -> provider -> model -> pricing model``. The key
    ``"default"`` acts as a fallback at both the provider and the model level.
    """

    def __init__(self, pricing_registry: Optional[Dict] = None):
        """Create a calculator.

        Args:
            pricing_registry: Registry to price against. When omitted
                (``None``) the module-level ``PRICING_REGISTRY`` is used.
        """
        # `is None` rather than `or`: an explicitly supplied empty registry is
        # falsy, but it must not be swapped for the shared global registry
        # (which `update_pricing` would then mutate).
        self.pricing_registry = (
            PRICING_REGISTRY if pricing_registry is None else pricing_registry
        )

    def get_pricing_model(
        self, service_type: str, provider: str, model: str
    ) -> Optional["PricingModel"]:
        """Get pricing model for a specific service, provider, and model.

        Lookup order: the provider's entry for ``model``, the provider's
        ``"default"`` entry, then the same two lookups under the ``"default"``
        provider of the service type. Returns ``None`` when nothing matches.
        """
        try:
            service_pricing = self.pricing_registry.get(service_type, {})

            # Try to get pricing for the specific provider
            provider_pricing = service_pricing.get(provider, {})
            pricing_model = provider_pricing.get(model) or provider_pricing.get(
                "default"
            )
            if pricing_model:
                return pricing_model

            # If not found, try the "default" provider for this service type
            default_provider_pricing = service_pricing.get("default", {})
            return default_provider_pricing.get(model) or default_provider_pricing.get(
                "default"
            )
        except (KeyError, AttributeError):
            # A malformed registry level (not a mapping) is treated as "no price".
            return None

    def calculate_llm_cost(
        self, provider: str, model: str, usage: Dict[str, int]
    ) -> Decimal:
        """Calculate cost for LLM usage; ``Decimal("0")`` when no price is known."""
        pricing_model = self.get_pricing_model("llm", provider, model)
        if not pricing_model:
            return Decimal("0")
        return pricing_model.calculate_cost(usage)

    def calculate_tts_cost(
        self, provider: str, model: str, character_count: int
    ) -> Decimal:
        """Calculate cost for TTS usage; ``Decimal("0")`` when no price is known."""
        pricing_model = self.get_pricing_model("tts", provider, model)
        if not pricing_model:
            return Decimal("0")
        return pricing_model.calculate_cost(character_count)

    def calculate_stt_cost(self, provider: str, model: str, seconds: float) -> Decimal:
        """Calculate cost for STT usage; ``Decimal("0")`` when no price is known."""
        pricing_model = self.get_pricing_model("stt", provider, model)
        if not pricing_model:
            return Decimal("0")
        return pricing_model.calculate_cost(seconds)

    def calculate_total_cost(self, usage_info: Dict) -> Dict[str, Any]:
        """Price every usage entry and return the per-service and total cost.

        Args:
            usage_info: Mapping with optional ``"llm"``, ``"tts"`` and ``"stt"``
                sections, each keyed by ``"processor|||model"``.

        Returns:
            Dict with float values under ``"llm_cost"``, ``"tts_cost"``,
            ``"stt_cost"`` and ``"total"``.
        """
        llm_cost_total = Decimal("0")
        tts_cost_total = Decimal("0")
        stt_cost_total = Decimal("0")

        # Calculate LLM costs
        for key, usage in usage_info.get("llm", {}).items():
            _, model = self._parse_key(key)
            provider = self._infer_provider_from_model(model, "llm")
            llm_cost_total += self.calculate_llm_cost(provider, model, usage)

        # Calculate TTS costs
        for key, character_count in usage_info.get("tts", {}).items():
            processor, model = self._parse_key(key)
            # Handle the case where model is "None" - infer from processor
            if model.lower() in ["none", "null", ""]:
                provider = self._infer_provider_from_processor(processor, "tts")
                model = "default"  # Use default model for the provider
            else:
                provider = self._infer_provider_from_model(model, "tts")
            tts_cost_total += self.calculate_tts_cost(provider, model, character_count)

        # Calculate STT costs from explicit stt usage
        for key, seconds in usage_info.get("stt", {}).items():
            _, model = self._parse_key(key)
            provider = self._infer_provider_from_model(model, "stt")
            stt_cost_total += self.calculate_stt_cost(provider, model, seconds)

        total_cost = llm_cost_total + tts_cost_total + stt_cost_total

        return {
            "llm_cost": float(llm_cost_total),
            "tts_cost": float(tts_cost_total),
            "stt_cost": float(stt_cost_total),
            "total": float(total_cost),
        }

    def _parse_key(self, key) -> Tuple[str, str]:
        """Parse key which is in format 'processor|||model'."""
        if isinstance(key, str) and "|||" in key:
            processor, model = key.split("|||", 1)
            return processor, model
        # Fallback for backwards compatibility or malformed keys
        return str(key), "unknown"

    def _infer_provider_from_model(self, model: str, service_type: str) -> str:
        """Infer provider from model name.

        Resolution order: keyword heuristics on the model name, then the
        provider that registers exactly this model for ``service_type``, then
        the first provider of the service type, then ``"unknown"``.
        """
        if not model:
            return "unknown"

        model_lower = model.lower()

        # OpenAI models
        if any(keyword in model_lower for keyword in ["gpt", "whisper", "openai"]):
            return ServiceProviders.OPENAI

        # Groq models
        if any(keyword in model_lower for keyword in ["groq"]):
            return ServiceProviders.GROQ

        # Elevenlabs models
        if any(keyword in model_lower for keyword in ["eleven"]):
            return ServiceProviders.ELEVENLABS

        # Deepgram models ("aura" covers Deepgram TTS voices such as
        # "aura-2-helena-en", which otherwise fell through to the first
        # TTS provider and were priced with its default rate)
        if any(
            keyword in model_lower
            for keyword in ["deepgram", "nova", "phonecall", "general", "aura"]
        ):
            return ServiceProviders.DEEPGRAM

        service_providers = self.pricing_registry.get(service_type, {})

        # No keyword matched: prefer the provider that explicitly registers
        # this model over a blind "first provider" guess, which would price a
        # model such as Groq's "llama-3.3-70b-versatile" under another
        # provider and silently yield a cost of 0.
        for provider, provider_models in service_providers.items():
            if provider != "default" and model in provider_models:
                return provider

        # Default to first available provider for the service type
        if service_providers:
            return list(service_providers.keys())[0]

        return "unknown"

    def _infer_provider_from_processor(self, processor: str, service_type: str) -> str:
        """Infer provider from processor name.

        Falls back to the first provider of ``service_type`` and finally to
        ``"unknown"`` when no keyword matches.
        """
        if not processor:
            return "unknown"

        processor_lower = processor.lower()

        # OpenAI processors
        if any(keyword in processor_lower for keyword in ["openai", "gpt"]):
            return ServiceProviders.OPENAI

        # Groq processors
        if any(keyword in processor_lower for keyword in ["groq"]):
            return ServiceProviders.GROQ

        # Elevenlabs processors (kept in step with _infer_provider_from_model)
        if any(keyword in processor_lower for keyword in ["eleven"]):
            return ServiceProviders.ELEVENLABS

        # Deepgram processors
        if any(keyword in processor_lower for keyword in ["deepgram"]):
            return ServiceProviders.DEEPGRAM

        # Default to first available provider for the service type
        service_providers = self.pricing_registry.get(service_type, {})
        if service_providers:
            return list(service_providers.keys())[0]

        return "unknown"

    def update_pricing(
        self,
        service_type: str,
        provider: str,
        model: str,
        pricing_model: "PricingModel",
    ):
        """Update pricing for a specific service/provider/model combination.

        Missing ``service_type`` / ``provider`` levels are created on demand.
        The registry held by this instance is modified in place.
        """
        provider_models = self.pricing_registry.setdefault(
            service_type, {}
        ).setdefault(provider, {})
        provider_models[model] = pricing_model
|
||||
|
||||
|
||||
# Global cost calculator instance
|
||||
cost_calculator = CostCalculator()
|
||||
|
|
@ -1,44 +0,0 @@
|
|||
"""
|
||||
Embeddings pricing models for different providers.
|
||||
|
||||
Prices are per token for embedding models.
|
||||
"""
|
||||
|
||||
from decimal import Decimal
|
||||
from typing import Dict
|
||||
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
|
||||
from .models import PricingModel
|
||||
|
||||
|
||||
class EmbeddingPricingModel(PricingModel):
    """Token-priced cost model for embedding services."""

    def __init__(self, token_price: Decimal):
        """Store the per-token rate.

        Args:
            token_price: Cost of embedding a single token.
        """
        self.token_price = token_price

    def calculate_cost(self, token_count: int) -> Decimal:
        """Return ``token_count`` multiplied by the per-token rate."""
        tokens = Decimal(token_count)
        return tokens * self.token_price
|
||||
|
||||
|
||||
# Embeddings pricing registry
|
||||
EMBEDDINGS_PRICING: Dict[str, Dict[str, EmbeddingPricingModel]] = {
    ServiceProviders.OPENAI: {
        # Each pair is (model name, USD per 1M tokens); the rate is stored per token.
        model_name: EmbeddingPricingModel(
            token_price=Decimal(usd_per_million) / 1_000_000
        )
        for model_name, usd_per_million in (
            ("text-embedding-3-small", "0.02"),
            ("text-embedding-3-large", "0.13"),
            ("text-embedding-ada-002", "0.10"),  # legacy model
        )
    },
}
|
||||
|
|
@ -1,143 +0,0 @@
|
|||
"""
|
||||
LLM pricing models for different providers.
|
||||
|
||||
Every entry stores a per-token price. Provider list prices are quoted per 1K tokens for older models and per 1M tokens for newer ones, and are divided down to a per-token rate here.
|
||||
"""
|
||||
|
||||
from decimal import Decimal
|
||||
from typing import Dict
|
||||
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
|
||||
from .models import TokenPricingModel
|
||||
|
||||
# LLM pricing registry
|
||||
# Per-token LLM prices, keyed by provider and then by model name.
# Each value is the provider's list price divided by the quoted unit
# (1K tokens for older models, 1M tokens for newer ones).
LLM_PRICING: Dict[str, Dict[str, TokenPricingModel]] = {
    ServiceProviders.OPENAI: {
        "gpt-3.5-turbo": TokenPricingModel(
            prompt_token_price=Decimal("0.0015") / 1000,  # $0.0015 per 1K tokens
            completion_token_price=Decimal("0.002") / 1000,  # $0.002 per 1K tokens
        ),
        "gpt-4": TokenPricingModel(
            prompt_token_price=Decimal("0.03") / 1000,  # $0.03 per 1K tokens
            completion_token_price=Decimal("0.06") / 1000,  # $0.06 per 1K tokens
        ),
        "gpt-4.1": TokenPricingModel(
            prompt_token_price=Decimal("2.00") / 1_000_000,  # $2.00 per 1M tokens
            completion_token_price=Decimal("8.00") / 1_000_000,  # $8.00 per 1M tokens
        ),
        "gpt-4.1-mini": TokenPricingModel(
            prompt_token_price=Decimal("0.40") / 1_000_000,  # $0.40 per 1M tokens
            completion_token_price=Decimal("1.60") / 1_000_000,  # $1.60 per 1M tokens
        ),
        "gpt-4.1-nano": TokenPricingModel(
            prompt_token_price=Decimal("0.10") / 1_000_000,  # $0.10 per 1M tokens
            completion_token_price=Decimal("0.40") / 1_000_000,  # $0.40 per 1M tokens
        ),
        "gpt-4.5-preview": TokenPricingModel(
            prompt_token_price=Decimal("75.00") / 1_000_000,  # $75.00 per 1M tokens
            completion_token_price=Decimal("150.00") / 1_000_000,  # $150.00 per 1M tokens
        ),
        "gpt-4o": TokenPricingModel(
            prompt_token_price=Decimal("2.50") / 1_000_000,  # $2.50 per 1M tokens
            completion_token_price=Decimal("10.00") / 1_000_000,  # $10.00 per 1M tokens
        ),
        "gpt-4o-audio-preview": TokenPricingModel(
            prompt_token_price=Decimal("2.50") / 1_000_000,  # $2.50 per 1M tokens
            completion_token_price=Decimal("10.00") / 1_000_000,  # $10.00 per 1M tokens
        ),
        "gpt-4o-realtime-preview": TokenPricingModel(
            prompt_token_price=Decimal("5.00") / 1_000_000,  # $5.00 per 1M tokens
            completion_token_price=Decimal("20.00") / 1_000_000,  # $20.00 per 1M tokens
        ),
        "gpt-4o-mini": TokenPricingModel(
            prompt_token_price=Decimal("0.15") / 1_000_000,  # $0.15 per 1M tokens
            completion_token_price=Decimal("0.60") / 1_000_000,  # $0.60 per 1M tokens
        ),
        "gpt-4o-mini-audio-preview": TokenPricingModel(
            prompt_token_price=Decimal("0.15") / 1_000_000,  # $0.15 per 1M tokens
            completion_token_price=Decimal("0.60") / 1_000_000,  # $0.60 per 1M tokens
        ),
        "gpt-4o-mini-realtime-preview": TokenPricingModel(
            prompt_token_price=Decimal("0.60") / 1_000_000,  # $0.60 per 1M tokens
            completion_token_price=Decimal("2.40") / 1_000_000,  # $2.40 per 1M tokens
        ),
        "gpt-4o-search-preview": TokenPricingModel(
            prompt_token_price=Decimal("2.50") / 1_000_000,  # $2.50 per 1M tokens
            completion_token_price=Decimal("10.00") / 1_000_000,  # $10.00 per 1M tokens
        ),
        "gpt-4o-mini-search-preview": TokenPricingModel(
            prompt_token_price=Decimal("0.15") / 1_000_000,  # $0.15 per 1M tokens
            completion_token_price=Decimal("0.60") / 1_000_000,  # $0.60 per 1M tokens
        ),
        "o1": TokenPricingModel(
            prompt_token_price=Decimal("15.00") / 1_000_000,  # $15.00 per 1M tokens
            completion_token_price=Decimal("60.00") / 1_000_000,  # $60.00 per 1M tokens
        ),
        "o1-pro": TokenPricingModel(
            prompt_token_price=Decimal("150.00") / 1_000_000,  # $150.00 per 1M tokens
            completion_token_price=Decimal("600.00") / 1_000_000,  # $600.00 per 1M tokens
        ),
        "o1-mini": TokenPricingModel(
            prompt_token_price=Decimal("1.10") / 1_000_000,  # $1.10 per 1M tokens
            completion_token_price=Decimal("4.40") / 1_000_000,  # $4.40 per 1M tokens
        ),
        "o3": TokenPricingModel(
            prompt_token_price=Decimal("10.00") / 1_000_000,  # $10.00 per 1M tokens
            completion_token_price=Decimal("40.00") / 1_000_000,  # $40.00 per 1M tokens
        ),
        "o3-mini": TokenPricingModel(
            prompt_token_price=Decimal("1.10") / 1_000_000,  # $1.10 per 1M tokens
            completion_token_price=Decimal("4.40") / 1_000_000,  # $4.40 per 1M tokens
        ),
        "o4-mini": TokenPricingModel(
            prompt_token_price=Decimal("1.10") / 1_000_000,  # $1.10 per 1M tokens
            completion_token_price=Decimal("4.40") / 1_000_000,  # $4.40 per 1M tokens
        ),
        "computer-use-preview": TokenPricingModel(
            prompt_token_price=Decimal("3.00") / 1_000_000,  # $3.00 per 1M tokens
            completion_token_price=Decimal("12.00") / 1_000_000,  # $12.00 per 1M tokens
        ),
        "gpt-image-1": TokenPricingModel(
            prompt_token_price=Decimal("5.00") / 1_000_000,  # $5.00 per 1M tokens
            completion_token_price=Decimal("0") / 1_000_000,  # No output pricing shown
        ),
        "codex-mini-latest": TokenPricingModel(
            prompt_token_price=Decimal("1.50") / 1_000_000,  # $1.50 per 1M tokens
            completion_token_price=Decimal("6.00") / 1_000_000,  # $6.00 per 1M tokens
        ),
        # Transcription models
        "gpt-4o-transcribe": TokenPricingModel(
            prompt_token_price=Decimal("2.50") / 1_000_000,  # $2.50 per 1M tokens
            completion_token_price=Decimal("10.00") / 1_000_000,  # $10.00 per 1M tokens
        ),
        "gpt-4o-mini-transcribe": TokenPricingModel(
            prompt_token_price=Decimal("1.25") / 1_000_000,  # $1.25 per 1M tokens
            completion_token_price=Decimal("5.00") / 1_000_000,  # $5.00 per 1M tokens
        ),
        # TTS models with token-based pricing
        "gpt-4o-mini-tts": TokenPricingModel(
            prompt_token_price=Decimal("0.60") / 1_000_000,  # $0.60 per 1M tokens
            completion_token_price=Decimal("0") / 1_000_000,  # No completion tokens for TTS
        ),
    },
    ServiceProviders.GROQ: {
        "llama-3.3-70b-versatile": TokenPricingModel(
            prompt_token_price=Decimal("0.00059") / 1000,  # $0.00059 per 1K tokens
            completion_token_price=Decimal("0.00079") / 1000,  # $0.00079 per 1K tokens
        ),
        "deepseek-r1-distill-llama-70b": TokenPricingModel(
            prompt_token_price=Decimal("0.00059") / 1000,  # Assuming similar pricing
            completion_token_price=Decimal("0.00079") / 1000,
        ),
    },
    ServiceProviders.AZURE: {
        # Data-zone rates, i.e. the OpenAI gpt-4.1-mini prices above plus 10%.
        # The completion rate was previously 8.80, which is the gpt-4.1 figure
        # (8.00 x 1.1) and did not match the 0.44 prompt rate of this model.
        # NOTE(review): confirm 0.44 / 1.76 against the current Azure price list.
        "gpt-4.1-mini": TokenPricingModel(
            prompt_token_price=Decimal("0.44") / 1_000_000,  # $0.44 per 1M tokens
            completion_token_price=Decimal("1.76") / 1_000_000,  # $1.76 per 1M tokens
        )
    },
}
|
||||
|
|
@ -1,89 +0,0 @@
|
|||
"""
|
||||
Base pricing models for different service types.
|
||||
"""
|
||||
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
class CostType(Enum):
    """Billing units used by the pricing models."""

    # Token-counted usage (LLM)
    LLM_TOKENS = "llm_tokens"
    # Character-counted usage (TTS)
    TTS_CHARACTERS = "tts_characters"
    # Duration-counted usage in seconds (STT)
    STT_SECONDS = "stt_seconds"
|
||||
|
||||
|
||||
class PricingModel:
    """Common interface of all pricing models.

    Subclasses override :meth:`calculate_cost` for their own usage unit.
    """

    def calculate_cost(self, usage: Any) -> Decimal:
        """Return the cost of ``usage``; the base class always raises.

        Raises:
            NotImplementedError: Always, since the base class has no pricing rule.
        """
        raise NotImplementedError
|
||||
|
||||
|
||||
class TokenPricingModel(PricingModel):
    """Pricing model for token-based services (LLM)."""

    def __init__(
        self,
        prompt_token_price: Decimal,
        completion_token_price: Decimal,
        cache_read_discount: Decimal = Decimal("0.5"),  # 50% discount for cache reads
        cache_creation_multiplier: Decimal = Decimal(
            "1.25"
        ),  # 25% premium for cache creation
    ):
        """Store the per-token rates and cache adjustments.

        Args:
            prompt_token_price: Cost of one prompt (input) token.
            completion_token_price: Cost of one completion (output) token.
            cache_read_discount: Fraction of the prompt price refunded for
                each cache-read token.
            cache_creation_multiplier: Price multiplier for cache-creation
                tokens; only the part above ``1`` is added as a surcharge.
        """
        self.prompt_token_price = prompt_token_price
        self.completion_token_price = completion_token_price
        self.cache_read_discount = cache_read_discount
        self.cache_creation_multiplier = cache_creation_multiplier

    def calculate_cost(self, usage: Dict[str, int]) -> Decimal:
        """Calculate cost for LLM token usage.

        Args:
            usage: Mapping that may contain ``prompt_tokens``,
                ``completion_tokens``, ``cache_read_input_tokens`` and
                ``cache_creation_input_tokens``. Missing or ``None`` counters
                count as zero.

        Returns:
            The total cost, never below zero.
        """
        # `or 0` on every counter: a key that is present with a null value
        # would otherwise reach Decimal(None) and raise TypeError. The cache
        # counters were already guarded this way; prompt and completion were not.
        prompt_tokens = usage.get("prompt_tokens") or 0
        completion_tokens = usage.get("completion_tokens") or 0
        cache_read_tokens = usage.get("cache_read_input_tokens") or 0
        cache_creation_tokens = usage.get("cache_creation_input_tokens") or 0

        # Base cost
        prompt_cost = Decimal(prompt_tokens) * self.prompt_token_price
        completion_cost = Decimal(completion_tokens) * self.completion_token_price

        # Cache adjustments, both expressed relative to the prompt price.
        # NOTE(review): this presumes cache tokens are already included in
        # prompt_tokens - confirm against the usage producers.
        cache_read_savings = (
            Decimal(cache_read_tokens)
            * self.prompt_token_price
            * self.cache_read_discount
        )
        cache_creation_premium = (
            Decimal(cache_creation_tokens)
            * self.prompt_token_price
            * (self.cache_creation_multiplier - 1)
        )

        total_cost = (
            prompt_cost + completion_cost - cache_read_savings + cache_creation_premium
        )
        return max(total_cost, Decimal("0"))  # Ensure non-negative
|
||||
|
||||
|
||||
class CharacterPricingModel(PricingModel):
    """Character-priced cost model (used for TTS)."""

    def __init__(self, character_price: Decimal):
        """Store the cost of a single character."""
        self.character_price = character_price

    def calculate_cost(self, character_count: int) -> Decimal:
        """Return ``character_count`` multiplied by the per-character rate."""
        characters = Decimal(character_count)
        return characters * self.character_price
|
||||
|
||||
|
||||
class TimePricingModel(PricingModel):
    """Duration-priced cost model (used for STT)."""

    def __init__(self, second_price: Decimal):
        """Store the cost of one second."""
        self.second_price = second_price

    def calculate_cost(self, seconds: float) -> Decimal:
        """Return ``seconds`` multiplied by the per-second rate."""
        # Converting through str() keeps the decimal digits as written
        # instead of importing the float's binary expansion.
        duration = Decimal(str(seconds))
        return duration * self.second_price
|
||||
|
|
@ -1,18 +0,0 @@
|
|||
"""
|
||||
Main pricing registry that combines all service type pricing models.
|
||||
"""
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from .embeddings import EMBEDDINGS_PRICING
|
||||
from .llm import LLM_PRICING
|
||||
from .stt import STT_PRICING
|
||||
from .tts import TTS_PRICING
|
||||
|
||||
# Combined pricing registry for all service types
|
||||
# Service type -> that service's provider/model pricing table.
PRICING_REGISTRY: Dict = dict(
    llm=LLM_PRICING,
    tts=TTS_PRICING,
    stt=STT_PRICING,
    embeddings=EMBEDDINGS_PRICING,
)
|
||||
|
|
@ -1,13 +0,0 @@
|
|||
"""Format workflow run usage for public API responses."""
|
||||
|
||||
|
||||
def format_public_usage_info(usage_info: dict | None) -> dict | None:
|
||||
if not usage_info:
|
||||
return None
|
||||
|
||||
return {
|
||||
"llm": usage_info.get("llm") or {},
|
||||
"tts": usage_info.get("tts") or {},
|
||||
"stt": usage_info.get("stt") or {},
|
||||
"call_duration_seconds": usage_info.get("call_duration_seconds"),
|
||||
}
|
||||
|
|
@ -1,26 +0,0 @@
|
|||
"""
|
||||
STT (Speech-to-Text) pricing models for different providers.
|
||||
|
||||
Prices are per second for STT services.
|
||||
"""
|
||||
|
||||
from decimal import Decimal
|
||||
from typing import Dict
|
||||
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
|
||||
from .models import TimePricingModel
|
||||
|
||||
# STT pricing registry
|
||||
# Each figure is divided by 60 to obtain the per-second rate that
# TimePricingModel expects.
STT_PRICING: Dict[str, Dict[str, TimePricingModel]] = {
    ServiceProviders.DEEPGRAM: {
        "nova-3-general": TimePricingModel(second_price=Decimal("0.0077") / 60),
        "nova-2": TimePricingModel(second_price=Decimal("0.0058") / 60),
        "default": TimePricingModel(second_price=Decimal("0.0077") / 60),
    },
    ServiceProviders.OPENAI: {
        "gpt-4o-transcribe": TimePricingModel(second_price=Decimal("0.015") / 60),
        "default": TimePricingModel(second_price=Decimal("0.015") / 60),
    },
    # Fallback provider, used when no provider-specific entry matches.
    "default": {
        "default": TimePricingModel(second_price=Decimal("0.0077") / 60),
    },
}
|
||||
|
|
@ -1,30 +0,0 @@
|
|||
"""
|
||||
TTS (Text-to-Speech) pricing models for different providers.
|
||||
|
||||
Prices are per character for TTS services.
|
||||
"""
|
||||
|
||||
from decimal import Decimal
|
||||
from typing import Dict
|
||||
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
|
||||
from .models import CharacterPricingModel
|
||||
|
||||
# TTS pricing registry
|
||||
TTS_PRICING: Dict[str, Dict[str, CharacterPricingModel]] = {
    ServiceProviders.OPENAI: {
        # NOTE(review): the divisor is 10,000,000 (originally spelled
        # 1_00_00_000). Value kept as-is; confirm it was not meant to be
        # 1,000,000.
        "gpt-4o-mini-tts": CharacterPricingModel(
            character_price=Decimal("0.6") / 10_000_000
        ),
        "default": CharacterPricingModel(character_price=Decimal("0.6") / 10_000_000),
    },
    ServiceProviders.DEEPGRAM: {
        "aura-2": CharacterPricingModel(character_price=Decimal("0.030") / 1_000),
        "aura-1": CharacterPricingModel(character_price=Decimal("0.015") / 1_000),
        "default": CharacterPricingModel(character_price=Decimal("0.030") / 1_000),
    },
    ServiceProviders.ELEVENLABS: {
        # 6400 usd per 250*1e6 characters
        "default": CharacterPricingModel(character_price=Decimal("0.0256") / 1_000)
    },
    # Fallback provider, used when no provider-specific entry matches.
    "default": {
        "default": CharacterPricingModel(character_price=Decimal("0.030") / 1_000),
    },
}
|
||||
|
|
@ -1,230 +0,0 @@
|
|||
from decimal import Decimal
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from api.db import db_client
|
||||
from api.enums import WorkflowRunMode
|
||||
from api.services.pricing.cost_calculator import cost_calculator
|
||||
from api.services.telephony.factory import get_telephony_provider_for_run
|
||||
|
||||
|
||||
async def _fetch_telephony_cost(workflow_run) -> dict | None:
    """Fetch telephony call cost.

    Args:
        workflow_run: Run whose ``mode``, ``cost_info`` and ``workflow_id``
            are read.

    Returns:
        A dict with ``cost_usd`` and ``provider_name``, or ``None`` when the
        run is not a Twilio/Vonage run, has no ``cost_info`` or ``call_id``,
        the provider reports an error, or the provider returns no cost.

    Raises:
        Exception: When the run's workflow cannot be found.
    """
    telephony_modes = [WorkflowRunMode.TWILIO.value, WorkflowRunMode.VONAGE.value]
    if workflow_run.mode not in telephony_modes or not workflow_run.cost_info:
        return None

    call_id = workflow_run.cost_info.get("call_id")
    if not call_id:
        logger.warning("call_id not found in cost_info")
        return None

    # The mode check above guarantees a non-empty telephony mode string.
    provider_name = workflow_run.mode.lower()

    workflow = await db_client.get_workflow_by_id(workflow_run.workflow_id)
    if not workflow:
        logger.warning("Workflow not found for workflow run")
        raise Exception("Workflow not found")

    provider = await get_telephony_provider_for_run(
        workflow_run, workflow.organization_id
    )
    call_cost_info = await provider.get_call_cost(call_id)

    if call_cost_info.get("status") == "error":
        logger.error(
            f"Failed to fetch {provider_name} call cost: {call_cost_info.get('error')}"
        )
        return None

    cost_usd = call_cost_info.get("cost_usd", 0.0)
    if cost_usd is None:
        # A null cost (key present, value missing) used to crash the
        # ":.6f" format below; treat it as "no cost available".
        logger.warning(
            f"{provider_name.title()} returned no cost for call {call_id}"
        )
        return None

    logger.info(
        f"{provider_name.title()} call cost: ${cost_usd:.6f} USD for call {call_id}"
    )
    return {"cost_usd": cost_usd, "provider_name": provider_name}
|
||||
|
||||
|
||||
async def _update_organization_usage(
    org, dograh_tokens: float, duration_seconds: float, charge_usd: float | None
) -> None:
    """Persist a run's usage on its organization and log what was recorded.

    Args:
        org: Organization object; only ``org.id`` is read.
        dograh_tokens: Dograh Tokens consumed by the run.
        duration_seconds: Call duration to add.
        charge_usd: USD charge, or ``None`` when none applies.
    """
    org_id = org.id
    await db_client.update_usage_after_run(
        org_id, dograh_tokens, duration_seconds, charge_usd
    )

    if charge_usd is None:
        logger.info(
            f"Updated organization usage with {dograh_tokens} Dograh Tokens and {duration_seconds}s duration for org {org_id}"
        )
        return

    logger.info(
        f"Updated organization usage with ${charge_usd:.2f} USD ({dograh_tokens} Dograh Tokens) and {duration_seconds}s duration for org {org_id}"
    )
|
||||
|
||||
|
||||
async def _get_pricing_organization(workflow_run):
    """Load the organization a run should be billed to.

    The id comes from ``workflow_run.workflow.organization_id``, falling back
    to the workflow user's ``selected_organization_id``. Returns ``None``
    when neither yields an id.
    """
    workflow = getattr(workflow_run, "workflow", None)
    organization_id = getattr(workflow, "organization_id", None)

    if organization_id is None:
        if workflow and workflow.user:
            organization_id = workflow.user.selected_organization_id
        if organization_id is None:
            return None

    return await db_client.get_organization_by_id(organization_id)
|
||||
|
||||
|
||||
async def _build_usage_cost_snapshot(
    usage_info: dict | None,
    *,
    workflow_run=None,
    include_telephony_cost: bool = False,
    organization=None,
    calculated_at: str | None = None,
) -> dict | None:
    """Build the ``cost_info`` snapshot for a usage record.

    Args:
        usage_info: Usage mapping passed to the cost calculator; also read
            for ``call_duration_seconds``.
        workflow_run: Optional run, used for the telephony cost, the
            organization lookup and the ``calculated_at`` fallback.
        include_telephony_cost: Whether to add the telephony provider's call
            cost to the breakdown (requires ``workflow_run``).
        organization: Organization to price against; resolved from
            ``workflow_run`` when omitted.
        calculated_at: Timestamp string to store; defaults to the run's
            ``created_at`` in ISO format when a run is given.

    Returns:
        The cost info dict, or ``None`` when ``usage_info`` is empty.
    """
    if not usage_info:
        logger.warning("No usage info available for workflow run")
        return None

    cost_breakdown = cost_calculator.calculate_total_cost(usage_info)

    if include_telephony_cost and workflow_run is not None:
        try:
            telephony_cost = await _fetch_telephony_cost(workflow_run)
            if telephony_cost:
                telephony_cost_usd = telephony_cost["cost_usd"]
                provider_name = telephony_cost["provider_name"]
                cost_breakdown["telephony_call"] = telephony_cost_usd
                cost_breakdown[f"{provider_name}_call"] = telephony_cost_usd
                cost_breakdown["total"] = (
                    float(cost_breakdown["total"]) + telephony_cost_usd
                )
        except Exception as e:
            logger.error(f"Failed to fetch telephony call cost: {e}")
            # Don't fail the whole cost calculation if telephony API fails

    total_cost_usd = Decimal(str(cost_breakdown["total"]))
    # One USD corresponds to 100 Dograh Tokens.
    dograh_tokens = float(total_cost_usd * Decimal("100"))

    if organization is None and workflow_run is not None:
        organization = await _get_pricing_organization(workflow_run)

    # `or 0` rather than a .get() default: a key that is present with a null
    # value would otherwise reach Decimal(str(None)) and raise.
    duration_seconds = usage_info.get("call_duration_seconds") or 0

    charge_usd = None
    if organization and organization.price_per_second_usd:
        charge_usd = float(
            Decimal(str(duration_seconds))
            * Decimal(str(organization.price_per_second_usd))
        )

    cost_info = {
        "cost_breakdown": cost_breakdown,
        "total_cost_usd": float(total_cost_usd),
        "dograh_token_usage": dograh_tokens,
        "calculated_at": calculated_at
        or (workflow_run.created_at.isoformat() if workflow_run is not None else None),
        "call_duration_seconds": duration_seconds,
    }

    if charge_usd is not None:
        cost_info["charge_usd"] = charge_usd
        cost_info["price_per_second_usd"] = organization.price_per_second_usd

    return cost_info
|
||||
|
||||
|
||||
async def build_workflow_run_cost_info(workflow_run) -> dict | None:
    """Compute a run's cost info, including telephony cost.

    Returns ``None`` when no snapshot could be built. Otherwise returns the
    run's existing ``cost_info`` overlaid with the fresh snapshot, so fresh
    keys win.
    """
    snapshot = await _build_usage_cost_snapshot(
        workflow_run.usage_info,
        workflow_run=workflow_run,
        include_telephony_cost=True,
        calculated_at=workflow_run.created_at.isoformat(),
    )
    if snapshot is None:
        return None

    merged = dict(workflow_run.cost_info or {})
    merged.update(snapshot)
    return merged
|
||||
|
||||
|
||||
async def save_workflow_run_cost_info(
    workflow_run_id: int, cost_info: dict | None
) -> None:
    """Store ``cost_info`` on the run; does nothing when it is ``None``."""
    if cost_info is not None:
        await db_client.update_workflow_run(
            run_id=workflow_run_id, cost_info=cost_info
        )
|
||||
|
||||
|
||||
async def apply_workflow_run_usage_to_organization(
    workflow_run, cost_info: dict | None
) -> None:
    """Add a run's computed usage to its organization.

    Does nothing when ``cost_info`` is ``None`` or the run has no resolvable
    organization.
    """
    if cost_info is None:
        return

    org = await _get_pricing_organization(workflow_run)
    if not org:
        return

    # Absent or null counters are recorded as zero.
    token_usage = float(cost_info.get("dograh_token_usage") or 0)
    duration = float(cost_info.get("call_duration_seconds") or 0)
    await _update_organization_usage(
        org, token_usage, duration, cost_info.get("charge_usd")
    )
|
||||
|
||||
|
||||
async def apply_usage_delta_to_organization(
    workflow_run, usage_info: dict | None
) -> dict | None:
    """Price a usage delta and add it to the run's organization.

    Returns:
        The cost info built for ``usage_info``, or ``None`` when the run has
        no resolvable organization or no snapshot could be built.
    """
    org = await _get_pricing_organization(workflow_run)
    if not org:
        return None

    snapshot = await _build_usage_cost_snapshot(usage_info, organization=org)
    if snapshot is None:
        return None

    # Absent or null counters are recorded as zero.
    token_usage = float(snapshot.get("dograh_token_usage") or 0)
    duration = float(snapshot.get("call_duration_seconds") or 0)
    await _update_organization_usage(
        org, token_usage, duration, snapshot.get("charge_usd")
    )
    return snapshot
|
||||
|
||||
|
||||
async def calculate_workflow_run_cost(workflow_run_id: int):
    """Compute and store a run's cost, then update its organization's usage.

    A failing organization usage update is logged but does not fail the
    calculation. Any other error is logged and re-raised.

    Args:
        workflow_run_id: Id of the run to price.
    """
    logger.debug("Calculating cost for workflow run")

    workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
    if not workflow_run:
        logger.warning("Workflow run not found")
        return

    try:
        cost_info = await build_workflow_run_cost_info(workflow_run)
        if cost_info is None:
            return

        await save_workflow_run_cost_info(workflow_run_id, cost_info)

        try:
            await apply_workflow_run_usage_to_organization(workflow_run, cost_info)
        except Exception as e:
            # The organization is looked up again only to enrich the log
            # line. That lookup can fail for the same reason the update did,
            # so it is guarded; otherwise its exception would escape to the
            # outer handler and be re-raised.
            try:
                org = await _get_pricing_organization(workflow_run)
            except Exception as lookup_error:
                logger.debug(
                    f"Could not resolve organization for error log: {lookup_error}"
                )
                org = None

            if org:
                logger.error(
                    f"Failed to update organization usage for org {org.id}: {e}"
                )
            else:
                logger.error(f"Failed to update organization usage: {e}")
            # Don't fail the whole cost calculation if usage update fails

        logger.info(
            f"Calculated cost for workflow run: ${cost_info['total_cost_usd']:.6f} USD ({cost_info['dograh_token_usage']} Dograh Tokens)"
        )
    except Exception as e:
        logger.error(f"Error calculating cost for workflow run: {e}")
        raise
|
||||
|
|
@ -5,15 +5,38 @@ across different endpoints (WebRTC signaling, telephony, public API triggers).
|
|||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from api.constants import DEPLOYMENT_MODE
|
||||
from api.db import db_client
|
||||
from api.db.models import UserModel
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
get_effective_ai_model_configuration_for_workflow,
|
||||
)
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from api.services.configuration.resolve import resolve_effective_config
|
||||
from api.services.managed_model_services import (
|
||||
MPS_CORRELATION_ID_CONTEXT_KEY,
|
||||
get_dograh_service_api_key,
|
||||
uses_managed_model_services_v2,
|
||||
)
|
||||
from api.services.mps_service_key_client import mps_service_key_client
|
||||
|
||||
MINIMUM_DOGRAH_CREDITS_FOR_CALL = 0.10
|
||||
|
||||
LEGACY_QUOTA_EXCEEDED_MESSAGE = (
|
||||
"You have exhausted your trial credits. "
|
||||
"Please email founders@dograh.com for additional Dograh credits "
|
||||
"or change providers in Models configurations."
|
||||
)
|
||||
|
||||
BILLING_V2_QUOTA_EXCEEDED_MESSAGE = (
|
||||
"You have exhausted your Dograh credits. "
|
||||
"Please purchase more credits from /billing "
|
||||
"or change providers in Models configurations."
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuotaCheckResult:
|
||||
|
|
@ -24,104 +47,359 @@ class QuotaCheckResult:
|
|||
error_code: str = ""
|
||||
|
||||
|
||||
async def check_dograh_quota(
|
||||
user: UserModel, workflow_id: int | None = None
|
||||
) -> QuotaCheckResult:
|
||||
"""Check if user has sufficient Dograh quota for making a call.
|
||||
|
||||
This function checks if the user is using any Dograh services (LLM, STT, TTS)
|
||||
and validates that they have sufficient credits remaining.
|
||||
|
||||
When ``workflow_id`` is provided, the workflow's per-workflow
|
||||
``model_overrides`` are merged onto the user's global config so the quota
|
||||
check runs against the credentials that will actually be used for the call
|
||||
(rather than always falling back to the user's defaults).
|
||||
|
||||
Args:
|
||||
user: The user to check quota for
|
||||
workflow_id: Optional workflow whose ``model_overrides`` should be
|
||||
applied when resolving the effective service config.
|
||||
|
||||
Returns:
|
||||
QuotaCheckResult with has_quota=True if user has sufficient quota or
|
||||
is not using Dograh services, or has_quota=False with error_message
|
||||
if quota is insufficient.
|
||||
"""
|
||||
def _safe_float(value: Any, default: float = 0.0) -> float:
|
||||
try:
|
||||
# Get user configurations
|
||||
user_config = await db_client.get_user_configurations(user.id)
|
||||
return float(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
if workflow_id is not None:
|
||||
workflow = await db_client.get_workflow_by_id(workflow_id)
|
||||
if workflow:
|
||||
model_overrides = (workflow.workflow_configurations or {}).get(
|
||||
"model_overrides"
|
||||
|
||||
def _insufficient_billing_v2_quota_result() -> QuotaCheckResult:
    """Build the denial result used when billing v2 credits are insufficient."""
    denial = QuotaCheckResult(
        has_quota=False,
        error_message=BILLING_V2_QUOTA_EXCEEDED_MESSAGE,
        error_code="insufficient_credits",
    )
    return denial
|
||||
|
||||
|
||||
def _insufficient_legacy_quota_result() -> QuotaCheckResult:
    """Build the denial result used when legacy service-key credits run out."""
    denial = QuotaCheckResult(
        has_quota=False,
        error_message=LEGACY_QUOTA_EXCEEDED_MESSAGE,
        error_code="quota_exceeded",
    )
    return denial
|
||||
|
||||
|
||||
def _service_uses_dograh(service: Any) -> bool:
    """Return True when ``service`` is configured with the Dograh provider.

    The provider is read defensively with ``getattr`` so a missing section
    (``None``) or an object without a ``provider`` attribute yields False.
    Both the enum member and its raw ``.value`` are accepted.
    """
    configured = getattr(service, "provider", None)
    if configured == ServiceProviders.DOGRAH:
        return True
    return configured == ServiceProviders.DOGRAH.value
|
||||
|
||||
|
||||
def _dograh_api_keys(user_config: Any) -> set[str]:
    """Collect the API keys of every Dograh-backed section of ``user_config``.

    The ``llm``, ``stt``, ``tts`` and ``embeddings`` sections are inspected.
    For a section that uses the Dograh provider, the non-empty string keys
    from ``get_all_api_keys()`` are preferred when that method exists and
    yields any; otherwise the section's single ``api_key`` is used if truthy.

    Args:
        user_config: Configuration object exposing the service sections as
            attributes; missing sections are skipped.

    Returns:
        The de-duplicated set of collected keys (possibly empty).
    """
    collected: set[str] = set()

    for section in ("llm", "stt", "tts", "embeddings"):
        svc = getattr(user_config, section, None)
        if not _service_uses_dograh(svc):
            continue

        # Prefer the multi-key accessor when the service offers one.
        candidates = svc.get_all_api_keys() if hasattr(svc, "get_all_api_keys") else ()
        usable = {key for key in candidates if isinstance(key, str) and key}

        if usable:
            collected |= usable
        else:
            # Fall back to the single configured key.
            single_key = getattr(svc, "api_key", None)
            if single_key:
                collected.add(single_key)

    return collected
|
||||
|
||||
|
||||
async def _store_run_correlation_id(
|
||||
workflow_run_id: int | None,
|
||||
correlation_id: str | None,
|
||||
) -> None:
|
||||
if not workflow_run_id or not correlation_id:
|
||||
return
|
||||
|
||||
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
|
||||
if not workflow_run:
|
||||
logger.warning(
|
||||
"Could not store MPS correlation id for missing workflow run {}",
|
||||
workflow_run_id,
|
||||
)
|
||||
return
|
||||
|
||||
initial_context = dict(workflow_run.initial_context or {})
|
||||
if initial_context.get(MPS_CORRELATION_ID_CONTEXT_KEY) == correlation_id:
|
||||
return
|
||||
|
||||
initial_context[MPS_CORRELATION_ID_CONTEXT_KEY] = correlation_id
|
||||
await db_client.update_workflow_run(
|
||||
workflow_run_id,
|
||||
initial_context=initial_context,
|
||||
)
|
||||
|
||||
|
||||
async def _authorize_hosted_workflow_run_start(
    *,
    workflow_owner: UserModel,
    organization_id: int | None,
    workflow_id: int | None,
    workflow_run_id: int | None,
    user_config: Any,
) -> tuple[QuotaCheckResult, bool]:
    """Authorize hosted v2 billing and return whether MPS handled enforcement.

    Args:
        workflow_owner: Owner of the workflow; its ``provider_id`` and ``id``
            are forwarded to MPS as ``created_by`` / metadata.
        organization_id: Organization the authorization is requested for.
        workflow_id: Workflow id forwarded to MPS as metadata.
        workflow_run_id: Run to authorize; also the run the returned
            correlation id is stored on.
        user_config: Effective model configuration for the workflow.

    Returns:
        ``(result, enforced)``. ``enforced`` is False when this check does
        not apply (OSS deployment, no organization, or MPS reports a billing
        mode other than ``"v2"``); in that case ``result`` always has quota.
        ``enforced`` is True when the decision in ``result`` is final.
    """

    def _verification_failed() -> tuple[QuotaCheckResult, bool]:
        # Shared denial for "MPS / storage could not be reached" outcomes.
        return (
            QuotaCheckResult(
                has_quota=False,
                error_code="quota_check_failed",
                error_message="Could not verify Dograh credits. Please try again.",
            ),
            True,
        )

    if organization_id is None or DEPLOYMENT_MODE == "oss":
        return QuotaCheckResult(has_quota=True), False

    # A correlation id is only needed for an existing run on managed v2.
    needs_correlation = bool(
        workflow_run_id and uses_managed_model_services_v2(user_config)
    )
    service_key = None
    if needs_correlation:
        service_key = get_dograh_service_api_key(user_config)
        if not service_key:
            return (
                QuotaCheckResult(
                    has_quota=False,
                    error_code="invalid_service_key",
                    error_message=(
                        "You have invalid keys in your model configuration. "
                        "Please validate the service keys."
                    ),
                ),
                True,
            )

    try:
        owner_provider_id = workflow_owner.provider_id
        authorization = await mps_service_key_client.authorize_workflow_run_start(
            organization_id=organization_id,
            workflow_run_id=workflow_run_id,
            service_key=service_key,
            require_correlation_id=needs_correlation,
            minimum_credits=MINIMUM_DOGRAH_CREDITS_FOR_CALL,
            created_by=(
                None if owner_provider_id is None else str(owner_provider_id)
            ),
            metadata={
                "dograh_user_id": str(workflow_owner.id),
                "workflow_id": workflow_id,
            },
        )
    except Exception as exc:
        logger.error(
            "Failed to authorize workflow start with MPS for org {}: {}",
            organization_id,
            exc,
        )
        return _verification_failed()

    if authorization.get("billing_mode") != "v2":
        # Not on billing v2: leave enforcement to the caller's other checks.
        return QuotaCheckResult(has_quota=True), False

    remaining = _safe_float(authorization.get("remaining_credits"))
    allowed = authorization.get("allowed", False)
    below_minimum = remaining < MINIMUM_DOGRAH_CREDITS_FOR_CALL
    if not allowed or below_minimum:
        logger.warning(
            "Insufficient Dograh billing v2 credits for org {}: {:.2f} credits remaining",
            organization_id,
            remaining,
        )
        return _insufficient_billing_v2_quota_result(), True

    try:
        await _store_run_correlation_id(
            workflow_run_id,
            authorization.get("correlation_id"),
        )
    except Exception as exc:
        logger.error(
            "Failed to store MPS correlation id for workflow_run_id {}: {}",
            workflow_run_id,
            exc,
        )
        return _verification_failed()

    logger.info(
        "Dograh billing v2 run authorization passed for org {}: {:.2f} credits remaining",
        organization_id,
        remaining,
    )
    return QuotaCheckResult(has_quota=True), True
|
||||
|
||||
|
||||
async def _authorize_legacy_dograh_keys(
|
||||
*,
|
||||
dograh_api_keys: set[str],
|
||||
organization_id: int | None,
|
||||
workflow_owner: UserModel,
|
||||
) -> QuotaCheckResult:
|
||||
for api_key in dograh_api_keys:
|
||||
try:
|
||||
usage = await mps_service_key_client.check_service_key_usage(
|
||||
api_key,
|
||||
organization_id=organization_id,
|
||||
created_by=workflow_owner.provider_id,
|
||||
)
|
||||
remaining = usage.get("remaining_credits", 0.0)
|
||||
|
||||
# Require at least $0.10 for a short call
|
||||
if remaining < MINIMUM_DOGRAH_CREDITS_FOR_CALL:
|
||||
logger.warning(
|
||||
f"Insufficient Dograh credits for key ...{api_key[-8:]}: "
|
||||
f"${remaining:.2f} remaining"
|
||||
)
|
||||
if model_overrides:
|
||||
user_config = resolve_effective_config(user_config, model_overrides)
|
||||
return _insufficient_legacy_quota_result()
|
||||
|
||||
# Check if user is using any Dograh service
|
||||
using_dograh = False
|
||||
dograh_api_keys = set()
|
||||
|
||||
if user_config.llm and user_config.llm.provider == ServiceProviders.DOGRAH:
|
||||
using_dograh = True
|
||||
dograh_api_keys.add(user_config.llm.api_key)
|
||||
|
||||
if user_config.stt and user_config.stt.provider == ServiceProviders.DOGRAH:
|
||||
using_dograh = True
|
||||
dograh_api_keys.add(user_config.stt.api_key)
|
||||
|
||||
if user_config.tts and user_config.tts.provider == ServiceProviders.DOGRAH:
|
||||
using_dograh = True
|
||||
dograh_api_keys.add(user_config.tts.api_key)
|
||||
|
||||
# If not using Dograh, quota check passes
|
||||
if not using_dograh:
|
||||
return QuotaCheckResult(has_quota=True)
|
||||
|
||||
# Check quota for ALL Dograh keys
|
||||
for api_key in dograh_api_keys:
|
||||
try:
|
||||
usage = await mps_service_key_client.check_service_key_usage(
|
||||
api_key, created_by=user.provider_id
|
||||
)
|
||||
remaining = usage.get("remaining_credits", 0.0)
|
||||
|
||||
# Require at least $0.10 for a short call
|
||||
if remaining < 0.10:
|
||||
logger.warning(
|
||||
f"Insufficient Dograh credits for key ...{api_key[-8:]}: "
|
||||
f"${remaining:.2f} remaining"
|
||||
)
|
||||
return QuotaCheckResult(
|
||||
has_quota=False,
|
||||
error_code="quota_exceeded",
|
||||
error_message=(
|
||||
"You have exhausted your trial credits. "
|
||||
"Please email founders@dograh.com for additional Dograh credits "
|
||||
"or change providers in Models configurations."
|
||||
),
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Dograh quota check passed for key ...{api_key[-8:]}: "
|
||||
f"{remaining:.2f} credits remaining"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check quota for Dograh key: {str(e)}")
|
||||
error_str = str(e)
|
||||
if "404" in error_str or "not found" in error_str.lower():
|
||||
return QuotaCheckResult(
|
||||
has_quota=False,
|
||||
error_code="invalid_service_key",
|
||||
error_message="You have invalid keys in your model configuration. Please validate the service keys.",
|
||||
)
|
||||
logger.info(
|
||||
f"Dograh quota check passed for key ...{api_key[-8:]}: "
|
||||
f"{remaining:.2f} credits remaining"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check quota for Dograh key: {str(e)}")
|
||||
error_str = str(e)
|
||||
if "404" in error_str or "not found" in error_str.lower():
|
||||
return QuotaCheckResult(
|
||||
has_quota=False,
|
||||
error_code="quota_check_failed",
|
||||
error_message="Could not verify Dograh credits. Please try again.",
|
||||
error_code="invalid_service_key",
|
||||
error_message="You have invalid keys in your model configuration. Please validate the service keys.",
|
||||
)
|
||||
return QuotaCheckResult(
|
||||
has_quota=False,
|
||||
error_code="quota_check_failed",
|
||||
error_message="Could not verify Dograh credits. Please try again.",
|
||||
)
|
||||
|
||||
return QuotaCheckResult(has_quota=True)
|
||||
|
||||
|
||||
async def _authorize_oss_managed_v2_correlation(
    *,
    workflow_id: int,
    workflow_run_id: int | None,
    user_config: Any,
) -> QuotaCheckResult:
    """Create and store an MPS correlation id for a managed v2 run.

    The check passes immediately when there is no run id or the
    configuration does not use managed model services v2. Otherwise a
    correlation id is requested with the Dograh service key and stored on
    the run.

    Args:
        workflow_id: Workflow id, used only in the failure log message.
        workflow_run_id: Run the correlation id is created for and stored on.
        user_config: Effective model configuration for the workflow.

    Returns:
        A passing result on success or when the check does not apply; a
        denial with ``invalid_service_key`` when no service key is
        configured, or ``quota_check_failed`` when creating/storing fails.
    """
    needs_correlation = workflow_run_id and uses_managed_model_services_v2(
        user_config
    )
    if not needs_correlation:
        return QuotaCheckResult(has_quota=True)

    service_key = get_dograh_service_api_key(user_config)
    if not service_key:
        return QuotaCheckResult(
            has_quota=False,
            error_code="invalid_service_key",
            error_message=(
                "You have invalid keys in your model configuration. "
                "Please validate the service keys."
            ),
        )

    try:
        created = await mps_service_key_client.create_correlation_id(
            service_key=service_key,
            workflow_run_id=workflow_run_id,
        )
        await _store_run_correlation_id(workflow_run_id, created.get("correlation_id"))
    except Exception as exc:
        logger.error(
            "Failed to authorize OSS managed v2 workflow start for workflow {} run {}: {}",
            workflow_id,
            workflow_run_id,
            exc,
        )
        return QuotaCheckResult(
            has_quota=False,
            error_code="quota_check_failed",
            error_message="Could not verify Dograh credits. Please try again.",
        )
    else:
        return QuotaCheckResult(has_quota=True)
|
||||
|
||||
|
||||
async def authorize_workflow_run_start(
|
||||
*,
|
||||
workflow_id: int,
|
||||
workflow_run_id: int | None = None,
|
||||
actor_user: UserModel | None = None,
|
||||
) -> QuotaCheckResult:
|
||||
"""Authorize a workflow run before any billable call/text runtime starts.
|
||||
|
||||
The workflow organization is the billing subject for hosted v2. The workflow
|
||||
owner is used only to resolve the effective model configuration and legacy
|
||||
service-key metadata.
|
||||
"""
|
||||
try:
|
||||
workflow = await db_client.get_workflow_by_id(workflow_id)
|
||||
if not workflow:
|
||||
return QuotaCheckResult(
|
||||
has_quota=False,
|
||||
error_code="workflow_not_found",
|
||||
error_message="Workflow not found",
|
||||
)
|
||||
|
||||
actor_org_id = getattr(actor_user, "selected_organization_id", None)
|
||||
if actor_org_id is not None and actor_org_id != workflow.organization_id:
|
||||
logger.warning(
|
||||
"Workflow start authorization denied: actor org {} does not match workflow {} org {}",
|
||||
actor_org_id,
|
||||
workflow_id,
|
||||
workflow.organization_id,
|
||||
)
|
||||
return QuotaCheckResult(
|
||||
has_quota=False,
|
||||
error_code="workflow_not_found",
|
||||
error_message="Workflow not found",
|
||||
)
|
||||
|
||||
workflow_owner = await db_client.get_user_by_id(workflow.user_id)
|
||||
if not workflow_owner:
|
||||
return QuotaCheckResult(
|
||||
has_quota=False,
|
||||
error_code="user_not_found",
|
||||
error_message="User not found",
|
||||
)
|
||||
|
||||
user_config = await get_effective_ai_model_configuration_for_workflow(
|
||||
user_id=workflow_owner.id,
|
||||
organization_id=workflow.organization_id,
|
||||
workflow_configurations=workflow.workflow_configurations,
|
||||
)
|
||||
|
||||
if DEPLOYMENT_MODE != "oss":
|
||||
hosted_result, hosted_enforced = await _authorize_hosted_workflow_run_start(
|
||||
workflow_owner=workflow_owner,
|
||||
organization_id=workflow.organization_id,
|
||||
workflow_id=workflow.id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
user_config=user_config,
|
||||
)
|
||||
if hosted_enforced or not hosted_result.has_quota:
|
||||
return hosted_result
|
||||
|
||||
dograh_api_keys = _dograh_api_keys(user_config)
|
||||
if not dograh_api_keys:
|
||||
return QuotaCheckResult(has_quota=True)
|
||||
|
||||
legacy_result = await _authorize_legacy_dograh_keys(
|
||||
dograh_api_keys=dograh_api_keys,
|
||||
organization_id=(
|
||||
None if DEPLOYMENT_MODE == "oss" else workflow.organization_id
|
||||
),
|
||||
workflow_owner=workflow_owner,
|
||||
)
|
||||
if not legacy_result.has_quota:
|
||||
return legacy_result
|
||||
|
||||
if DEPLOYMENT_MODE == "oss":
|
||||
return await _authorize_oss_managed_v2_correlation(
|
||||
workflow_id=workflow.id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
user_config=user_config,
|
||||
)
|
||||
|
||||
return QuotaCheckResult(has_quota=True)
|
||||
|
||||
|
|
@ -129,30 +407,3 @@ async def check_dograh_quota(
|
|||
logger.error(f"Error during quota check: {str(e)}")
|
||||
# On unexpected error, allow the call to proceed
|
||||
return QuotaCheckResult(has_quota=True)
|
||||
|
||||
|
||||
async def check_dograh_quota_by_user_id(
|
||||
user_id: int, workflow_id: int | None = None
|
||||
) -> QuotaCheckResult:
|
||||
"""Check Dograh quota by user ID.
|
||||
|
||||
Convenience function that fetches the user and then checks quota. When
|
||||
``workflow_id`` is provided, the workflow's ``model_overrides`` are
|
||||
applied so the quota check evaluates the credentials that will actually
|
||||
be used for the call.
|
||||
|
||||
Args:
|
||||
user_id: The ID of the user to check quota for
|
||||
workflow_id: Optional workflow whose per-workflow overrides should
|
||||
be applied to the user's config before checking quota.
|
||||
|
||||
Returns:
|
||||
QuotaCheckResult with quota status
|
||||
"""
|
||||
user = await db_client.get_user_by_id(user_id)
|
||||
if not user:
|
||||
return QuotaCheckResult(
|
||||
has_quota=False,
|
||||
error_message="User not found",
|
||||
)
|
||||
return await check_dograh_quota(user, workflow_id=workflow_id)
|
||||
|
|
|
|||
|
|
@ -53,7 +53,7 @@ def build_run_report_csv(runs: List[Any]) -> io.StringIO:
|
|||
for run in runs:
|
||||
initial = run.initial_context or {}
|
||||
gathered = run.gathered_context or {}
|
||||
cost = run.cost_info or {}
|
||||
usage = run.usage_info or {}
|
||||
|
||||
call_tags = gathered.get("call_tags", [])
|
||||
if isinstance(call_tags, list):
|
||||
|
|
@ -67,7 +67,7 @@ def build_run_report_csv(runs: List[Any]) -> io.StringIO:
|
|||
run.created_at.isoformat() if run.created_at else "",
|
||||
initial.get("phone_number", ""),
|
||||
gathered.get("mapped_call_disposition", ""),
|
||||
cost.get("call_duration_seconds", ""),
|
||||
usage.get("call_duration_seconds", ""),
|
||||
]
|
||||
|
||||
extracted = gathered.get("extracted_variables", {})
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ from loguru import logger
|
|||
from api.constants import REDIS_URL
|
||||
from api.db import db_client
|
||||
from api.enums import CallType, WorkflowRunMode
|
||||
from api.services.quota_service import check_dograh_quota_by_user_id
|
||||
from api.services.quota_service import authorize_workflow_run_start
|
||||
from api.services.telephony.call_transfer_manager import get_call_transfer_manager
|
||||
from api.services.telephony.transfer_event_protocol import (
|
||||
TransferEvent,
|
||||
|
|
@ -564,19 +564,7 @@ class ARIConnection:
|
|||
|
||||
user_id = workflow.user_id
|
||||
|
||||
# 3. Check quota (apply per-workflow model_overrides).
|
||||
quota_result = await check_dograh_quota_by_user_id(
|
||||
user_id, workflow_id=inbound_workflow_id
|
||||
)
|
||||
if not quota_result.has_quota:
|
||||
logger.warning(
|
||||
f"[ARI org={self.organization_id}] Quota exceeded for user {user_id} "
|
||||
f"— hanging up inbound call {channel_id}"
|
||||
)
|
||||
await self._delete_channel(channel_id)
|
||||
return
|
||||
|
||||
# 4. Create workflow run
|
||||
# 3. Create workflow run
|
||||
call_id = channel_id
|
||||
workflow_run = await db_client.create_workflow_run(
|
||||
name=f"ARI Inbound {caller_number}",
|
||||
|
|
@ -602,6 +590,20 @@ class ARIConnection:
|
|||
f"(caller={caller_number}, called={called_number})"
|
||||
)
|
||||
|
||||
# 4. Check quota after the run exists so hosted v2 can mint and
|
||||
# store the MPS correlation id before the pipeline starts.
|
||||
quota_result = await authorize_workflow_run_start(
|
||||
workflow_id=inbound_workflow_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
)
|
||||
if not quota_result.has_quota:
|
||||
logger.warning(
|
||||
f"[ARI org={self.organization_id}] Quota exceeded for user {user_id} "
|
||||
f"— hanging up inbound call {channel_id}"
|
||||
)
|
||||
await self._delete_channel(channel_id)
|
||||
return
|
||||
|
||||
# 5. Answer the inbound channel
|
||||
await self._answer_channel(channel_id)
|
||||
|
||||
|
|
|
|||
|
|
@ -103,7 +103,8 @@ async def handle_cloudonix_cdr(request: Request):
|
|||
return {"status": "error", "message": "Missing domain field"}
|
||||
|
||||
# Extract call_id to find workflow run
|
||||
call_id = cdr_data.get("session").get("token")
|
||||
session = cdr_data.get("session")
|
||||
call_id = session.get("token") if isinstance(session, dict) else None
|
||||
logger.info(f"Cloudonix CDR data for call id {call_id} - {cdr_data}")
|
||||
if not call_id:
|
||||
logger.warning("Cloudonix CDR missing call_id field")
|
||||
|
|
|
|||
|
|
@ -6,9 +6,8 @@ provider registry — see ProviderSpec.router.
|
|||
|
||||
import json
|
||||
from datetime import UTC, datetime
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Header, Request
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from loguru import logger
|
||||
from pipecat.utils.run_context import set_current_run_id
|
||||
from starlette.responses import HTMLResponse
|
||||
|
|
@ -29,6 +28,30 @@ from api.utils.telephony_helper import (
|
|||
router = APIRouter()
|
||||
|
||||
|
||||
async def _verify_vobiz_callback(
|
||||
provider,
|
||||
webhook_url: str,
|
||||
callback_data: dict,
|
||||
headers: dict,
|
||||
raw_body: str,
|
||||
*,
|
||||
log_prefix: str,
|
||||
) -> None:
|
||||
"""Verify a Vobiz callback signature, failing closed.
|
||||
|
||||
Vobiz signs every callback, so a missing signature header is an invalid
|
||||
request — ``provider.verify_inbound_signature`` returns ``False`` for both
|
||||
missing and forged signatures. Reject with HTTP 403 (per Vobiz's
|
||||
callback-validation docs) so the caller never reaches status processing.
|
||||
"""
|
||||
is_valid = await provider.verify_inbound_signature(
|
||||
webhook_url, callback_data, headers, raw_body
|
||||
)
|
||||
if not is_valid:
|
||||
logger.warning(f"{log_prefix} Invalid or missing Vobiz callback signature")
|
||||
raise HTTPException(status_code=403, detail="Invalid webhook signature")
|
||||
|
||||
|
||||
@router.post("/vobiz-xml", include_in_schema=False)
|
||||
async def handle_vobiz_xml_webhook(
|
||||
workflow_id: int, user_id: int, workflow_run_id: int, organization_id: int
|
||||
|
|
@ -65,8 +88,6 @@ async def handle_vobiz_xml_webhook(
|
|||
async def handle_vobiz_hangup_callback(
|
||||
workflow_run_id: int,
|
||||
request: Request,
|
||||
x_vobiz_signature: Optional[str] = Header(None),
|
||||
x_vobiz_timestamp: Optional[str] = Header(None),
|
||||
):
|
||||
"""Handle Vobiz hangup callback (sent when call ends).
|
||||
|
||||
|
|
@ -75,82 +96,23 @@ async def handle_vobiz_hangup_callback(
|
|||
"""
|
||||
set_current_run_id(workflow_run_id)
|
||||
|
||||
# Logging all headers and body to understand what Vobiz actually sends
|
||||
all_headers = dict(request.headers)
|
||||
logger.info(
|
||||
f"[run {workflow_run_id}] Vobiz hangup callback - Headers: {json.dumps(all_headers)}"
|
||||
)
|
||||
|
||||
# Parse the callback data from the raw body so signed webhooks can verify
|
||||
# the exact bytes Vobiz sent without draining the request stream first.
|
||||
callback_data, raw_body = await parse_webhook_request(request)
|
||||
|
||||
# TODO: Remove this debug logging after Vobiz team clarifies webhook authentication
|
||||
logger.info(
|
||||
f"[run {workflow_run_id}] Vobiz hangup callback - Body: {json.dumps(callback_data)}"
|
||||
)
|
||||
logger.info(
|
||||
f"[run {workflow_run_id}] Received Vobiz hangup callback {json.dumps(callback_data)}"
|
||||
)
|
||||
|
||||
# Verify signature if Vobiz provided any supported signature header.
|
||||
has_vobiz_signature = any(
|
||||
header in all_headers
|
||||
for header in (
|
||||
"x-vobiz-signature-v3",
|
||||
"x-vobiz-signature-ma-v3",
|
||||
"x-vobiz-signature-v2",
|
||||
"x-vobiz-signature-ma-v2",
|
||||
)
|
||||
)
|
||||
if has_vobiz_signature:
|
||||
# We need the workflow run to get organization for provider credentials
|
||||
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
|
||||
if not workflow_run:
|
||||
logger.warning(
|
||||
f"[run {workflow_run_id}] Workflow run not found for signature verification"
|
||||
)
|
||||
return {"status": "error", "reason": "workflow_run_not_found"}
|
||||
|
||||
workflow = await db_client.get_workflow_by_id(workflow_run.workflow_id)
|
||||
if not workflow:
|
||||
logger.warning(
|
||||
f"[run {workflow_run_id}] Workflow not found for signature verification"
|
||||
)
|
||||
return {"status": "error", "reason": "workflow_not_found"}
|
||||
|
||||
provider = await get_telephony_provider_for_run(
|
||||
workflow_run, workflow.organization_id
|
||||
)
|
||||
|
||||
# Verify signature
|
||||
backend_endpoint, _ = await get_backend_endpoints()
|
||||
webhook_url = f"{backend_endpoint}/api/v1/telephony/vobiz/hangup-callback/{workflow_run_id}"
|
||||
|
||||
is_valid = await provider.verify_inbound_signature(
|
||||
webhook_url,
|
||||
callback_data,
|
||||
all_headers,
|
||||
raw_body,
|
||||
)
|
||||
|
||||
if not is_valid:
|
||||
logger.warning(
|
||||
f"[run {workflow_run_id}] Invalid Vobiz hangup callback signature"
|
||||
)
|
||||
return {"status": "error", "reason": "invalid_signature"}
|
||||
|
||||
logger.info(f"[run {workflow_run_id}] Vobiz hangup callback signature verified")
|
||||
else:
|
||||
# Get workflow run for processing (signature verification already got it if needed)
|
||||
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
|
||||
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
|
||||
if not workflow_run:
|
||||
logger.warning(
|
||||
f"[run {workflow_run_id}] Workflow run not found for Vobiz hangup callback"
|
||||
)
|
||||
return {"status": "ignored", "reason": "workflow_run_not_found"}
|
||||
|
||||
# Get workflow and provider
|
||||
workflow = await db_client.get_workflow_by_id(workflow_run.workflow_id)
|
||||
if not workflow:
|
||||
logger.warning(f"[run {workflow_run_id}] Workflow not found")
|
||||
|
|
@ -160,6 +122,21 @@ async def handle_vobiz_hangup_callback(
|
|||
workflow_run, workflow.organization_id
|
||||
)
|
||||
|
||||
# Fail closed: Vobiz signs every callback, so reject unsigned/forged ones
|
||||
# before they can mutate call state.
|
||||
backend_endpoint, _ = await get_backend_endpoints()
|
||||
webhook_url = (
|
||||
f"{backend_endpoint}/api/v1/telephony/vobiz/hangup-callback/{workflow_run_id}"
|
||||
)
|
||||
await _verify_vobiz_callback(
|
||||
provider,
|
||||
webhook_url,
|
||||
callback_data,
|
||||
all_headers,
|
||||
raw_body,
|
||||
log_prefix=f"[run {workflow_run_id}]",
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"[run {workflow_run_id}] Processing Vobiz hangup with provider: {provider.PROVIDER_NAME}"
|
||||
)
|
||||
|
|
@ -167,10 +144,6 @@ async def handle_vobiz_hangup_callback(
|
|||
# Parse the callback data into generic format
|
||||
parsed_data = provider.parse_status_callback(callback_data)
|
||||
|
||||
logger.debug(
|
||||
f"[run {workflow_run_id}] Parsed Vobiz callback data: {json.dumps(parsed_data)}"
|
||||
)
|
||||
|
||||
# Create StatusCallbackRequest from parsed data
|
||||
status_update = StatusCallbackRequest(
|
||||
call_id=parsed_data["call_id"],
|
||||
|
|
@ -194,8 +167,6 @@ async def handle_vobiz_hangup_callback(
|
|||
async def handle_vobiz_ring_callback(
|
||||
workflow_run_id: int,
|
||||
request: Request,
|
||||
x_vobiz_signature: Optional[str] = Header(None),
|
||||
x_vobiz_timestamp: Optional[str] = Header(None),
|
||||
):
|
||||
"""Handle Vobiz ring callback (sent when call starts ringing).
|
||||
|
||||
|
|
@ -204,84 +175,46 @@ async def handle_vobiz_ring_callback(
|
|||
"""
|
||||
set_current_run_id(workflow_run_id)
|
||||
|
||||
# Logging all headers and body to understand what Vobiz actually sends
|
||||
all_headers = dict(request.headers)
|
||||
logger.info(
|
||||
f"[run {workflow_run_id}] Vobiz ring callback - Headers: {json.dumps(all_headers)}"
|
||||
)
|
||||
|
||||
# Parse the callback data from the raw body so signed webhooks can verify
|
||||
# the exact bytes Vobiz sent without draining the request stream first.
|
||||
callback_data, raw_body = await parse_webhook_request(request)
|
||||
|
||||
# TODO: Remove this debug logging after Vobiz team clarifies webhook authentication
|
||||
logger.info(
|
||||
f"[run {workflow_run_id}] Vobiz ring callback - Body: {json.dumps(callback_data)}"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[run {workflow_run_id}] Received Vobiz ring callback {json.dumps(callback_data)}"
|
||||
)
|
||||
|
||||
# Verify signature if Vobiz provided any supported signature header.
|
||||
has_vobiz_signature = any(
|
||||
header in all_headers
|
||||
for header in (
|
||||
"x-vobiz-signature-v3",
|
||||
"x-vobiz-signature-ma-v3",
|
||||
"x-vobiz-signature-v2",
|
||||
"x-vobiz-signature-ma-v2",
|
||||
)
|
||||
)
|
||||
if has_vobiz_signature:
|
||||
# We need the workflow run to get organization for provider credentials
|
||||
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
|
||||
if not workflow_run:
|
||||
logger.warning(
|
||||
f"[run {workflow_run_id}] Workflow run not found for signature verification"
|
||||
)
|
||||
return {"status": "error", "reason": "workflow_run_not_found"}
|
||||
|
||||
workflow = await db_client.get_workflow_by_id(workflow_run.workflow_id)
|
||||
if not workflow:
|
||||
logger.warning(
|
||||
f"[run {workflow_run_id}] Workflow not found for signature verification"
|
||||
)
|
||||
return {"status": "error", "reason": "workflow_not_found"}
|
||||
|
||||
provider = await get_telephony_provider_for_run(
|
||||
workflow_run, workflow.organization_id
|
||||
)
|
||||
|
||||
# Verify signature
|
||||
backend_endpoint, _ = await get_backend_endpoints()
|
||||
webhook_url = (
|
||||
f"{backend_endpoint}/api/v1/telephony/vobiz/ring-callback/{workflow_run_id}"
|
||||
)
|
||||
|
||||
is_valid = await provider.verify_inbound_signature(
|
||||
webhook_url,
|
||||
callback_data,
|
||||
all_headers,
|
||||
raw_body,
|
||||
)
|
||||
|
||||
if not is_valid:
|
||||
logger.warning(
|
||||
f"[run {workflow_run_id}] Invalid Vobiz ring callback signature"
|
||||
)
|
||||
return {"status": "error", "reason": "invalid_signature"}
|
||||
|
||||
logger.info(f"[run {workflow_run_id}] Vobiz ring callback signature verified")
|
||||
else:
|
||||
# Get workflow run for processing (signature verification already got it if needed)
|
||||
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
|
||||
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
|
||||
if not workflow_run:
|
||||
logger.warning(
|
||||
f"[run {workflow_run_id}] Workflow run not found for Vobiz ring callback"
|
||||
)
|
||||
return {"status": "ignored", "reason": "workflow_run_not_found"}
|
||||
|
||||
workflow = await db_client.get_workflow_by_id(workflow_run.workflow_id)
|
||||
if not workflow:
|
||||
logger.warning(f"[run {workflow_run_id}] Workflow not found")
|
||||
return {"status": "ignored", "reason": "workflow_not_found"}
|
||||
|
||||
provider = await get_telephony_provider_for_run(
|
||||
workflow_run, workflow.organization_id
|
||||
)
|
||||
|
||||
# Fail closed: reject unsigned/forged ring callbacks before logging them.
|
||||
backend_endpoint, _ = await get_backend_endpoints()
|
||||
webhook_url = (
|
||||
f"{backend_endpoint}/api/v1/telephony/vobiz/ring-callback/{workflow_run_id}"
|
||||
)
|
||||
await _verify_vobiz_callback(
|
||||
provider,
|
||||
webhook_url,
|
||||
callback_data,
|
||||
all_headers,
|
||||
raw_body,
|
||||
log_prefix=f"[run {workflow_run_id}]",
|
||||
)
|
||||
|
||||
# Log the ringing event
|
||||
telephony_callback_logs = workflow_run.logs.get("telephony_status_callbacks", [])
|
||||
ring_log = {
|
||||
|
|
@ -308,15 +241,10 @@ async def handle_vobiz_ring_callback(
|
|||
async def handle_vobiz_hangup_callback_by_workflow(
|
||||
workflow_id: int,
|
||||
request: Request,
|
||||
x_vobiz_signature: Optional[str] = Header(None),
|
||||
x_vobiz_timestamp: Optional[str] = Header(None),
|
||||
):
|
||||
"""Handle Vobiz hangup callback with workflow_id - finds workflow run by call_id."""
|
||||
|
||||
all_headers = dict(request.headers)
|
||||
logger.info(
|
||||
f"[workflow {workflow_id}] Vobiz hangup callback - Headers: {json.dumps(all_headers)}"
|
||||
)
|
||||
|
||||
try:
|
||||
callback_data, raw_body = await parse_webhook_request(request)
|
||||
|
|
@ -364,35 +292,18 @@ async def handle_vobiz_hangup_callback_by_workflow(
|
|||
workflow_run, workflow.organization_id
|
||||
)
|
||||
|
||||
has_vobiz_signature = any(
|
||||
header in all_headers
|
||||
for header in (
|
||||
"x-vobiz-signature-v3",
|
||||
"x-vobiz-signature-ma-v3",
|
||||
"x-vobiz-signature-v2",
|
||||
"x-vobiz-signature-ma-v2",
|
||||
)
|
||||
# Fail closed: Vobiz signs every callback, so reject unsigned/forged ones
|
||||
# before they can mutate call state.
|
||||
backend_endpoint, _ = await get_backend_endpoints()
|
||||
webhook_url = f"{backend_endpoint}/api/v1/telephony/vobiz/hangup-callback/workflow/{workflow_id}"
|
||||
await _verify_vobiz_callback(
|
||||
provider,
|
||||
webhook_url,
|
||||
callback_data,
|
||||
all_headers,
|
||||
raw_body,
|
||||
log_prefix=f"[workflow {workflow_id}]",
|
||||
)
|
||||
if has_vobiz_signature:
|
||||
backend_endpoint, _ = await get_backend_endpoints()
|
||||
webhook_url = f"{backend_endpoint}/api/v1/telephony/vobiz/hangup-callback/workflow/{workflow_id}"
|
||||
|
||||
is_valid = await provider.verify_inbound_signature(
|
||||
webhook_url,
|
||||
callback_data,
|
||||
all_headers,
|
||||
raw_body,
|
||||
)
|
||||
|
||||
if not is_valid:
|
||||
logger.warning(
|
||||
f"[workflow {workflow_id}] Invalid Vobiz hangup callback signature"
|
||||
)
|
||||
return {"status": "error", "message": "invalid_signature"}
|
||||
|
||||
logger.info(
|
||||
f"[workflow {workflow_id}] Vobiz hangup callback signature verified"
|
||||
)
|
||||
|
||||
try:
|
||||
parsed_data = provider.parse_status_callback(callback_data)
|
||||
|
|
|
|||
|
|
@ -66,34 +66,6 @@ async def handle_vonage_events(
|
|||
logger.error(f"[run {workflow_run_id}] Workflow run not found")
|
||||
return {"status": "error", "message": "Workflow run not found"}
|
||||
|
||||
# For a completed call that includes cost info, capture it immediately
|
||||
if event_data.get("status") == "completed":
|
||||
# Vonage sometimes includes price info in the webhook
|
||||
if "price" in event_data or "rate" in event_data:
|
||||
try:
|
||||
if workflow_run.cost_info:
|
||||
# Store immediate cost info if available
|
||||
cost_info = workflow_run.cost_info.copy()
|
||||
if "price" in event_data:
|
||||
cost_info["vonage_webhook_price"] = float(event_data["price"])
|
||||
if "rate" in event_data:
|
||||
cost_info["vonage_webhook_rate"] = float(event_data["rate"])
|
||||
if "duration" in event_data:
|
||||
cost_info["vonage_webhook_duration"] = int(
|
||||
event_data["duration"]
|
||||
)
|
||||
|
||||
await db_client.update_workflow_run(
|
||||
run_id=workflow_run_id, cost_info=cost_info
|
||||
)
|
||||
logger.info(
|
||||
f"[run {workflow_run_id}] Captured Vonage cost info from webhook"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[run {workflow_run_id}] Failed to capture Vonage cost from webhook: {e}"
|
||||
)
|
||||
|
||||
# Get workflow and provider
|
||||
workflow = await db_client.get_workflow_by_id(workflow_run.workflow_id)
|
||||
if not workflow:
|
||||
|
|
|
|||
|
|
@ -114,11 +114,13 @@ class StatusCallbackRequest(BaseModel):
|
|||
"NOANSWER": "no-answer",
|
||||
}
|
||||
|
||||
disposition = data.get("disposition", "")
|
||||
disposition = data.get("disposition") or ""
|
||||
status = disposition_map.get(disposition.upper(), disposition.lower())
|
||||
session = data.get("session")
|
||||
call_id = session.get("token") if isinstance(session, dict) else ""
|
||||
|
||||
return cls(
|
||||
call_id=data.get("session").get("token"),
|
||||
call_id=call_id or "",
|
||||
status=status,
|
||||
from_number=data.get("from"),
|
||||
to_number=data.get("to"),
|
||||
|
|
|
|||
|
|
@ -35,6 +35,7 @@ import asyncio
|
|||
|
||||
from loguru import logger
|
||||
|
||||
from api.services.managed_model_services import MPS_CORRELATION_ID_CONTEXT_KEY
|
||||
from api.services.workflow import pipecat_engine_callbacks as engine_callbacks
|
||||
from api.services.workflow.mcp_tool_session import McpToolSession
|
||||
from api.services.workflow.pipecat_engine_context_composer import (
|
||||
|
|
@ -382,6 +383,9 @@ class PipecatEngine:
|
|||
embeddings_provider=self._embeddings_provider,
|
||||
embeddings_endpoint=self._embeddings_endpoint,
|
||||
embeddings_api_version=self._embeddings_api_version,
|
||||
correlation_id=self._call_context_vars.get(
|
||||
MPS_CORRELATION_ID_CONTEXT_KEY
|
||||
),
|
||||
tracing_context=self._get_otel_context(),
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@
|
|||
|
||||
import random
|
||||
|
||||
from api.db import db_client
|
||||
from api.db.models import WorkflowRunModel
|
||||
from api.services.workflow.dto import QANodeData
|
||||
|
||||
|
|
@ -43,7 +42,7 @@ async def resolve_llm_config(
|
|||
async def resolve_user_llm_config(
|
||||
workflow_run: WorkflowRunModel,
|
||||
) -> tuple[str, str, str, dict]:
|
||||
"""Resolve the user's configured LLM (from UserConfiguration).
|
||||
"""Resolve the user's configured LLM (from EffectiveAIModelConfiguration).
|
||||
|
||||
Returns:
|
||||
(provider, model, api_key, service_kwargs) tuple
|
||||
|
|
@ -54,7 +53,27 @@ async def resolve_user_llm_config(
|
|||
|
||||
llm_config: dict = {}
|
||||
if user_id:
|
||||
user_configuration = await db_client.get_user_configurations(user_id)
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
get_effective_ai_model_configuration_for_workflow,
|
||||
)
|
||||
|
||||
workflow_configurations = {}
|
||||
if workflow_run.definition:
|
||||
workflow_configurations = (
|
||||
workflow_run.definition.workflow_configurations or {}
|
||||
)
|
||||
elif workflow_run.workflow:
|
||||
workflow_configurations = (
|
||||
workflow_run.workflow.workflow_configurations or {}
|
||||
)
|
||||
|
||||
user_configuration = await get_effective_ai_model_configuration_for_workflow(
|
||||
user_id=user_id,
|
||||
organization_id=workflow_run.workflow.organization_id
|
||||
if workflow_run.workflow
|
||||
else None,
|
||||
workflow_configurations=workflow_configurations,
|
||||
)
|
||||
llm_config = user_configuration.model_dump(exclude_none=True).get("llm", {})
|
||||
|
||||
provider = llm_config.get("provider", "openai")
|
||||
|
|
|
|||
41
api/services/workflow/run_usage_response.py
Normal file
41
api/services/workflow/run_usage_response.py
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
"""Format workflow run usage for public API responses."""
|
||||
|
||||
|
||||
def format_public_usage_info(usage_info: dict | None) -> dict | None:
|
||||
if not usage_info:
|
||||
return None
|
||||
|
||||
return {
|
||||
"llm": usage_info.get("llm") or {},
|
||||
"tts": usage_info.get("tts") or {},
|
||||
"stt": usage_info.get("stt") or {},
|
||||
"call_duration_seconds": usage_info.get("call_duration_seconds"),
|
||||
}
|
||||
|
||||
|
||||
def format_public_cost_info(
|
||||
cost_info: dict | None, usage_info: dict | None
|
||||
) -> dict | None:
|
||||
"""Return the legacy response shape without doing local cost accounting."""
|
||||
duration = None
|
||||
if usage_info and usage_info.get("call_duration_seconds") is not None:
|
||||
duration = int(round(usage_info.get("call_duration_seconds") or 0))
|
||||
elif cost_info and cost_info.get("call_duration_seconds") is not None:
|
||||
duration = int(round(cost_info.get("call_duration_seconds") or 0))
|
||||
|
||||
dograh_token_usage = 0
|
||||
if cost_info:
|
||||
if "dograh_token_usage" in cost_info:
|
||||
dograh_token_usage = cost_info.get("dograh_token_usage") or 0
|
||||
elif "total_cost_usd" in cost_info:
|
||||
dograh_token_usage = round(
|
||||
float(cost_info.get("total_cost_usd", 0)) * 100, 2
|
||||
)
|
||||
|
||||
if duration is None and dograh_token_usage == 0:
|
||||
return None
|
||||
|
||||
return {
|
||||
"dograh_token_usage": dograh_token_usage,
|
||||
"call_duration_seconds": duration,
|
||||
}
|
||||
|
|
@ -32,7 +32,6 @@ from pipecat.utils.run_context import set_current_org_id
|
|||
|
||||
from api.db import db_client
|
||||
from api.enums import WorkflowRunMode, WorkflowRunState
|
||||
from api.services.configuration.resolve import resolve_effective_config
|
||||
from api.services.pipecat.audio_config import create_audio_config
|
||||
from api.services.pipecat.pipeline_builder import create_pipeline_task
|
||||
from api.services.pipecat.pipeline_metrics_aggregator import (
|
||||
|
|
@ -410,14 +409,31 @@ async def execute_text_chat_pending_turn(
|
|||
run_definition = workflow_run.definition
|
||||
run_configs = run_definition.workflow_configurations or {}
|
||||
|
||||
user_config = await db_client.get_user_configurations(workflow_run.workflow.user.id)
|
||||
user_config = resolve_effective_config(
|
||||
user_config, run_configs.get("model_overrides")
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
get_effective_ai_model_configuration_for_workflow,
|
||||
)
|
||||
|
||||
user_config = await get_effective_ai_model_configuration_for_workflow(
|
||||
user_id=workflow_run.workflow.user.id,
|
||||
organization_id=workflow.organization_id,
|
||||
workflow_configurations=run_configs,
|
||||
)
|
||||
if user_config.llm is None:
|
||||
raise ValueError("Text chat requires an LLM configuration")
|
||||
|
||||
llm = create_llm_service(user_config)
|
||||
from api.services.managed_model_services import (
|
||||
MPS_CORRELATION_ID_CONTEXT_KEY,
|
||||
ensure_mps_correlation_id,
|
||||
)
|
||||
|
||||
base_initial_context = dict(workflow_run.initial_context or {})
|
||||
mps_correlation_id = await ensure_mps_correlation_id(
|
||||
ai_model_config=user_config,
|
||||
workflow_run_id=workflow_run_id,
|
||||
initial_context=base_initial_context,
|
||||
)
|
||||
|
||||
llm = create_llm_service(user_config, correlation_id=mps_correlation_id)
|
||||
inference_llm = llm
|
||||
|
||||
runtime_configuration = {
|
||||
|
|
@ -425,9 +441,15 @@ async def execute_text_chat_pending_turn(
|
|||
"llm_model": user_config.llm.model,
|
||||
}
|
||||
initial_context = {
|
||||
**(workflow_run.initial_context or {}),
|
||||
**base_initial_context,
|
||||
"runtime_configuration": runtime_configuration,
|
||||
}
|
||||
if mps_correlation_id:
|
||||
initial_context[MPS_CORRELATION_ID_CONTEXT_KEY] = mps_correlation_id
|
||||
await db_client.update_workflow_run(
|
||||
workflow_run_id,
|
||||
initial_context=initial_context,
|
||||
)
|
||||
|
||||
workflow_graph = WorkflowGraph(
|
||||
ReactFlowDTO.model_validate(run_definition.workflow_json)
|
||||
|
|
@ -466,9 +488,17 @@ async def execute_text_chat_pending_turn(
|
|||
embeddings_model = None
|
||||
embeddings_base_url = None
|
||||
if user_config.embeddings:
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
apply_managed_embeddings_base_url,
|
||||
)
|
||||
|
||||
embeddings_api_key = user_config.embeddings.api_key
|
||||
embeddings_model = user_config.embeddings.model
|
||||
embeddings_base_url = getattr(user_config.embeddings, "base_url", None)
|
||||
embeddings_provider = getattr(user_config.embeddings, "provider", None)
|
||||
embeddings_base_url = apply_managed_embeddings_base_url(
|
||||
provider=embeddings_provider,
|
||||
base_url=getattr(user_config.embeddings, "base_url", None),
|
||||
)
|
||||
|
||||
has_recordings = await db_client.has_active_recordings(workflow.organization_id)
|
||||
context_compaction_enabled = (workflow.workflow_configurations or {}).get(
|
||||
|
|
@ -606,8 +636,10 @@ async def execute_text_chat_pending_turn(
|
|||
"Transportless text chat pipeline failed while closing run {}",
|
||||
workflow_run_id,
|
||||
)
|
||||
await engine.close_mcp_sessions()
|
||||
await engine.cleanup()
|
||||
raise
|
||||
await engine.close_mcp_sessions()
|
||||
await engine.cleanup()
|
||||
|
||||
gathered_context = await engine.get_gathered_context()
|
||||
|
|
|
|||
|
|
@ -4,17 +4,11 @@ from datetime import UTC, datetime
|
|||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from api.db import db_client
|
||||
from api.db.models import WorkflowRunTextSessionModel
|
||||
from api.db.workflow_run_text_session_client import (
|
||||
WorkflowRunTextSessionRevisionConflictError,
|
||||
)
|
||||
from api.services.pricing.workflow_run_cost import (
|
||||
apply_usage_delta_to_organization,
|
||||
build_workflow_run_cost_info,
|
||||
)
|
||||
from api.services.workflow.text_chat_logs import (
|
||||
build_text_chat_realtime_feedback_events,
|
||||
)
|
||||
|
|
@ -261,20 +255,6 @@ async def execute_pending_text_chat_turn(
|
|||
state=execution.state,
|
||||
is_completed=execution.is_completed,
|
||||
)
|
||||
workflow_run = await db_client.get_workflow_run_by_id(run_id)
|
||||
if workflow_run:
|
||||
try:
|
||||
# Apply the per-turn delta so org usage tracks cumulative run cost
|
||||
# without replaying the full session totals on every turn.
|
||||
await apply_usage_delta_to_organization(workflow_run, execution.usage)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to update organization usage for text chat run {run_id}: {e}"
|
||||
)
|
||||
|
||||
cost_info = await build_workflow_run_cost_info(workflow_run)
|
||||
if cost_info is not None:
|
||||
await db_client.update_workflow_run(run_id, cost_info=cost_info)
|
||||
|
||||
return await _reload_text_chat_session(run_id)
|
||||
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ async def retrieve_from_knowledge_base(
|
|||
embeddings_provider: Optional[str] = None,
|
||||
embeddings_endpoint: Optional[str] = None,
|
||||
embeddings_api_version: Optional[str] = None,
|
||||
correlation_id: Optional[str] = None,
|
||||
tracing_context=None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Retrieve relevant information from the knowledge base using vector similarity search.
|
||||
|
|
@ -75,6 +76,7 @@ async def retrieve_from_knowledge_base(
|
|||
embeddings_provider,
|
||||
embeddings_endpoint,
|
||||
embeddings_api_version,
|
||||
correlation_id,
|
||||
)
|
||||
|
||||
# Create span with parent context
|
||||
|
|
@ -115,6 +117,7 @@ async def retrieve_from_knowledge_base(
|
|||
embeddings_provider,
|
||||
embeddings_endpoint,
|
||||
embeddings_api_version,
|
||||
correlation_id,
|
||||
)
|
||||
|
||||
# Add result metadata to span
|
||||
|
|
@ -192,6 +195,7 @@ async def retrieve_from_knowledge_base(
|
|||
embeddings_provider,
|
||||
embeddings_endpoint,
|
||||
embeddings_api_version,
|
||||
correlation_id,
|
||||
)
|
||||
else:
|
||||
# Tracing is disabled - perform retrieval without tracing
|
||||
|
|
@ -206,6 +210,7 @@ async def retrieve_from_knowledge_base(
|
|||
embeddings_provider,
|
||||
embeddings_endpoint,
|
||||
embeddings_api_version,
|
||||
correlation_id,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -220,6 +225,7 @@ async def _perform_retrieval(
|
|||
embeddings_provider: Optional[str] = None,
|
||||
embeddings_endpoint: Optional[str] = None,
|
||||
embeddings_api_version: Optional[str] = None,
|
||||
correlation_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Internal function to perform the actual retrieval operation.
|
||||
|
||||
|
|
@ -272,11 +278,20 @@ async def _perform_retrieval(
|
|||
api_version=embeddings_api_version or "2024-02-15-preview",
|
||||
)
|
||||
else:
|
||||
default_headers = None
|
||||
if (
|
||||
embeddings_provider == ServiceProviders.DOGRAH.value
|
||||
and correlation_id
|
||||
):
|
||||
default_headers = {
|
||||
"X-Dograh-Correlation-Id": correlation_id,
|
||||
}
|
||||
embedding_service = OpenAIEmbeddingService(
|
||||
db_client=db_client,
|
||||
api_key=embeddings_api_key,
|
||||
model_id=embeddings_model or "text-embedding-3-small",
|
||||
base_url=embeddings_base_url,
|
||||
default_headers=default_headers,
|
||||
)
|
||||
|
||||
results = await embedding_service.search_similar_chunks(
|
||||
|
|
|
|||
111
api/services/workflow_run_billing.py
Normal file
111
api/services/workflow_run_billing.py
Normal file
|
|
@ -0,0 +1,111 @@
|
|||
"""Workflow-run billing hooks.
|
||||
|
||||
Dograh does not rate or deduct credits locally. MPS owns credit accounting.
|
||||
For hosted deployments, Dograh reports completed platform usage to MPS.
|
||||
When a server-minted MPS correlation id exists, MPS uses model-service usage
|
||||
as the canonical duration. Otherwise Dograh reports the completed run duration.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from api.constants import DEPLOYMENT_MODE
|
||||
from api.db import db_client
|
||||
from api.services.managed_model_services import get_mps_correlation_id
|
||||
from api.services.mps_service_key_client import mps_service_key_client
|
||||
|
||||
|
||||
def _workflow_run_organization_id(workflow_run) -> int | None:
|
||||
workflow = getattr(workflow_run, "workflow", None)
|
||||
return getattr(workflow, "organization_id", None)
|
||||
|
||||
|
||||
def _duration_seconds_from_usage_info(workflow_run) -> float | None:
|
||||
usage_info: dict[str, Any] = getattr(workflow_run, "usage_info", None) or {}
|
||||
duration = usage_info.get("call_duration_seconds")
|
||||
try:
|
||||
duration_seconds = float(duration)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
return duration_seconds if duration_seconds > 0 else None
|
||||
|
||||
|
||||
async def _organization_uses_mps_billing_v2(organization_id: int) -> bool:
    """Tell whether the organization's MPS billing account reports ``billing_mode == "v2"``.

    A missing or empty account status counts as "not v2".
    """
    status = await mps_service_key_client.get_billing_account_status(
        organization_id=organization_id
    )
    if not status:
        return False
    return bool(status.get("billing_mode") == "v2")
|
||||
|
||||
|
||||
async def report_workflow_run_platform_usage(workflow_run) -> None:
    """Report hosted platform usage for a completed workflow run to MPS.

    Nothing is reported when the deployment mode is ``"oss"``, when the run is
    not flagged ``is_completed``, when no organization id can be resolved,
    when there is neither an MPS correlation id nor a positive
    ``call_duration_seconds``, or when the organization's billing account is
    not in ``"v2"`` mode. Exceptions raised while talking to MPS are logged
    and swallowed, so this function never raises because of them.
    """
    if DEPLOYMENT_MODE == "oss" or not getattr(workflow_run, "is_completed", False):
        return

    org_id = _workflow_run_organization_id(workflow_run)
    if org_id is None:
        logger.warning(
            "Skipping platform usage report for workflow run {}: no organization_id",
            workflow_run.id,
        )
        return

    correlation_id = get_mps_correlation_id(
        getattr(workflow_run, "initial_context", None)
    )
    if correlation_id:
        # With a correlation id no local duration is sent; per the module
        # docstring MPS uses model-service usage as the canonical duration.
        duration_seconds = None
        duration_source = "mps_correlation"
    else:
        duration_seconds = _duration_seconds_from_usage_info(workflow_run)
        duration_source = "dograh_usage_info"
        if duration_seconds is None:
            logger.warning(
                "Skipping platform usage report for workflow run {}: no billable duration",
                workflow_run.id,
            )
            return

    try:
        uses_billing_v2 = await _organization_uses_mps_billing_v2(org_id)
        if not uses_billing_v2:
            return

        report_result = await mps_service_key_client.report_platform_usage(
            organization_id=org_id,
            correlation_id=correlation_id,
            duration_seconds=duration_seconds,
            workflow_run_id=workflow_run.id,
            metadata={
                "source": "workflow_run_completion",
                "workflow_id": getattr(workflow_run, "workflow_id", None),
                "duration_source": duration_source,
            },
        )
        logger.info(
            "Reported platform usage for workflow run {} to MPS: {}",
            workflow_run.id,
            report_result,
        )
    except Exception as exc:
        # Best effort: a reporting failure must not break the caller.
        logger.error(
            "Failed to report platform usage for workflow run {}: {}",
            workflow_run.id,
            exc,
        )
|
||||
|
||||
|
||||
async def report_completed_workflow_run_platform_usage(workflow_run_id: int) -> None:
    """Fetch the workflow run by id and hand it to the platform-usage reporter.

    Logs a warning and returns without reporting when the run cannot be found.
    """
    loaded_run = await db_client.get_workflow_run_by_id(workflow_run_id)
    if loaded_run:
        await report_workflow_run_platform_usage(loaded_run)
        return

    logger.warning(
        "Skipping platform usage report: workflow run {} not found",
        workflow_run_id,
    )
|
||||
|
|
@ -45,10 +45,8 @@ from api.tasks.campaign_tasks import (
|
|||
)
|
||||
from api.tasks.knowledge_base_processing import process_knowledge_base_document
|
||||
from api.tasks.run_integrations import run_integrations_post_workflow_run
|
||||
from api.tasks.s3_upload import (
|
||||
process_workflow_completion,
|
||||
upload_voicemail_audio_to_s3,
|
||||
)
|
||||
from api.tasks.s3_upload import upload_voicemail_audio_to_s3
|
||||
from api.tasks.workflow_completion import process_workflow_completion
|
||||
|
||||
|
||||
class WorkerSettings:
|
||||
|
|
|
|||
|
|
@ -157,15 +157,31 @@ async def process_knowledge_base_document(
|
|||
embeddings_endpoint = None
|
||||
embeddings_api_version = None
|
||||
if document.created_by:
|
||||
user_config = await db_client.get_user_configurations(document.created_by)
|
||||
if user_config.embeddings:
|
||||
embeddings_provider = getattr(user_config.embeddings, "provider", None)
|
||||
embeddings_api_key = user_config.embeddings.api_key
|
||||
embeddings_model = user_config.embeddings.model
|
||||
embeddings_base_url = getattr(user_config.embeddings, "base_url", None)
|
||||
embeddings_endpoint = getattr(user_config.embeddings, "endpoint", None)
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
apply_managed_embeddings_base_url,
|
||||
get_resolved_ai_model_configuration,
|
||||
)
|
||||
|
||||
resolved_config = await get_resolved_ai_model_configuration(
|
||||
user_id=document.created_by,
|
||||
organization_id=document.organization_id,
|
||||
)
|
||||
effective_config = resolved_config.effective
|
||||
if effective_config.embeddings:
|
||||
embeddings_provider = getattr(
|
||||
effective_config.embeddings, "provider", None
|
||||
)
|
||||
embeddings_api_key = effective_config.embeddings.api_key
|
||||
embeddings_model = effective_config.embeddings.model
|
||||
embeddings_base_url = apply_managed_embeddings_base_url(
|
||||
provider=embeddings_provider,
|
||||
base_url=getattr(effective_config.embeddings, "base_url", None),
|
||||
)
|
||||
embeddings_endpoint = getattr(
|
||||
effective_config.embeddings, "endpoint", None
|
||||
)
|
||||
embeddings_api_version = getattr(
|
||||
user_config.embeddings, "api_version", None
|
||||
effective_config.embeddings, "api_version", None
|
||||
)
|
||||
logger.info(
|
||||
f"Using user embeddings config: provider={embeddings_provider}, "
|
||||
|
|
|
|||
|
|
@ -1,13 +1,9 @@
|
|||
import os
|
||||
from typing import Optional
|
||||
|
||||
from loguru import logger
|
||||
from pipecat.utils.run_context import set_current_run_id
|
||||
|
||||
from api.db import db_client
|
||||
from api.services.pricing.workflow_run_cost import calculate_workflow_run_cost
|
||||
from api.services.storage import get_current_storage_backend, storage_fs
|
||||
from api.tasks.run_integrations import run_integrations_post_workflow_run
|
||||
from api.services.storage import storage_fs
|
||||
|
||||
|
||||
async def upload_voicemail_audio_to_s3(
|
||||
|
|
@ -69,110 +65,3 @@ async def upload_voicemail_audio_to_s3(
|
|||
logger.warning(
|
||||
f"Failed to clean up temp voicemail audio file {temp_file_path}: {e}"
|
||||
)
|
||||
|
||||
|
||||
async def process_workflow_completion(
|
||||
_ctx,
|
||||
workflow_run_id: int,
|
||||
audio_temp_path: Optional[str] = None,
|
||||
transcript_temp_path: Optional[str] = None,
|
||||
):
|
||||
"""Process workflow completion: upload artifacts and run integrations.
|
||||
|
||||
This task combines audio upload, transcript upload, and webhook integrations
|
||||
into a single sequential task to ensure integrations run after uploads complete.
|
||||
|
||||
Args:
|
||||
_ctx: ARQ context (unused)
|
||||
workflow_run_id: The workflow run ID
|
||||
audio_temp_path: Optional path to temp audio file
|
||||
transcript_temp_path: Optional path to temp transcript file
|
||||
"""
|
||||
run_id = str(workflow_run_id)
|
||||
set_current_run_id(run_id)
|
||||
|
||||
logger.info(f"Processing workflow completion for run {workflow_run_id}")
|
||||
|
||||
storage_backend = get_current_storage_backend()
|
||||
|
||||
# Step 1: Upload audio if provided
|
||||
if audio_temp_path:
|
||||
try:
|
||||
if os.path.exists(audio_temp_path):
|
||||
file_size = os.path.getsize(audio_temp_path)
|
||||
logger.debug(f"Audio file size: {file_size} bytes")
|
||||
|
||||
recording_url = f"recordings/{workflow_run_id}.wav"
|
||||
logger.info(
|
||||
f"Uploading audio to {storage_backend.name} - workflow_run_id: {workflow_run_id}"
|
||||
)
|
||||
|
||||
await storage_fs.aupload_file(audio_temp_path, recording_url)
|
||||
await db_client.update_workflow_run(
|
||||
run_id=workflow_run_id,
|
||||
recording_url=recording_url,
|
||||
storage_backend=storage_backend.value,
|
||||
)
|
||||
logger.info(f"Successfully uploaded audio: {recording_url}")
|
||||
else:
|
||||
logger.warning(f"Audio temp file not found: {audio_temp_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error uploading audio for workflow {workflow_run_id}: {e}")
|
||||
finally:
|
||||
if audio_temp_path and os.path.exists(audio_temp_path):
|
||||
try:
|
||||
os.remove(audio_temp_path)
|
||||
logger.debug(f"Cleaned up temp audio file: {audio_temp_path}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to clean up temp audio file: {e}")
|
||||
|
||||
# Step 2: Upload transcript if provided
|
||||
if transcript_temp_path:
|
||||
try:
|
||||
if os.path.exists(transcript_temp_path):
|
||||
file_size = os.path.getsize(transcript_temp_path)
|
||||
logger.debug(f"Transcript file size: {file_size} bytes")
|
||||
|
||||
transcript_url = f"transcripts/{workflow_run_id}.txt"
|
||||
logger.info(
|
||||
f"Uploading transcript to {storage_backend.name} - workflow_run_id: {workflow_run_id}"
|
||||
)
|
||||
|
||||
await storage_fs.aupload_file(transcript_temp_path, transcript_url)
|
||||
await db_client.update_workflow_run(
|
||||
run_id=workflow_run_id,
|
||||
transcript_url=transcript_url,
|
||||
storage_backend=storage_backend.value,
|
||||
)
|
||||
logger.info(f"Successfully uploaded transcript: {transcript_url}")
|
||||
else:
|
||||
logger.warning(
|
||||
f"Transcript temp file not found: {transcript_temp_path}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error uploading transcript for workflow {workflow_run_id}: {e}"
|
||||
)
|
||||
finally:
|
||||
if transcript_temp_path and os.path.exists(transcript_temp_path):
|
||||
try:
|
||||
os.remove(transcript_temp_path)
|
||||
logger.debug(
|
||||
f"Cleaned up temp transcript file: {transcript_temp_path}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to clean up temp transcript file: {e}")
|
||||
|
||||
# Step 3: Run integrations including QA analysis (after uploads are complete)
|
||||
try:
|
||||
await run_integrations_post_workflow_run(_ctx, workflow_run_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Error running integrations for workflow {workflow_run_id}: {e}")
|
||||
|
||||
# Step 4: Calculate cost after integrations (so QA token usage is included)
|
||||
try:
|
||||
await calculate_workflow_run_cost(workflow_run_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating cost for workflow {workflow_run_id}: {e}")
|
||||
|
||||
logger.info(f"Completed workflow completion processing for run {workflow_run_id}")
|
||||
|
|
|
|||
121
api/tasks/workflow_completion.py
Normal file
121
api/tasks/workflow_completion.py
Normal file
|
|
@ -0,0 +1,121 @@
|
|||
import os
|
||||
from typing import Optional
|
||||
|
||||
from loguru import logger
|
||||
from pipecat.utils.run_context import set_current_run_id
|
||||
|
||||
from api.db import db_client
|
||||
from api.services.storage import get_current_storage_backend, storage_fs
|
||||
from api.services.workflow_run_billing import (
|
||||
report_completed_workflow_run_platform_usage,
|
||||
)
|
||||
from api.tasks.run_integrations import run_integrations_post_workflow_run
|
||||
|
||||
|
||||
async def _upload_run_artifact(
    workflow_run_id: int,
    temp_path: str,
    remote_path: str,
    url_field: str,
    label: str,
    storage_backend,
) -> None:
    """Upload one temp artifact, record its URL on the run, then delete the temp file.

    Args:
        workflow_run_id: The workflow run the artifact belongs to.
        temp_path: Local temp file to upload and afterwards remove.
        remote_path: Destination key in storage; also the value stored on the run.
        url_field: Name of the ``update_workflow_run`` keyword that receives
            ``remote_path`` (``"recording_url"`` or ``"transcript_url"``).
        label: Lower-case artifact name used in log messages
            (``"audio"`` or ``"transcript"``).
        storage_backend: Current storage backend; only its ``name`` and
            ``value`` attributes are read here.

    Upload and database errors are logged and swallowed so the remaining
    completion steps still run.
    """
    title = label.capitalize()
    try:
        if os.path.exists(temp_path):
            file_size = os.path.getsize(temp_path)
            logger.debug(f"{title} file size: {file_size} bytes")

            logger.info(
                f"Uploading {label} to {storage_backend.name} - workflow_run_id: {workflow_run_id}"
            )

            await storage_fs.aupload_file(temp_path, remote_path)
            await db_client.update_workflow_run(
                run_id=workflow_run_id,
                **{url_field: remote_path},
                storage_backend=storage_backend.value,
            )
            logger.info(f"Successfully uploaded {label}: {remote_path}")
        else:
            logger.warning(f"{title} temp file not found: {temp_path}")
    except Exception as e:
        logger.error(f"Error uploading {label} for workflow {workflow_run_id}: {e}")
    finally:
        # Always remove the temp file, whether or not the upload succeeded.
        if os.path.exists(temp_path):
            try:
                os.remove(temp_path)
                logger.debug(f"Cleaned up temp {label} file: {temp_path}")
            except Exception as e:
                logger.warning(f"Failed to clean up temp {label} file: {e}")


async def process_workflow_completion(
    _ctx,
    workflow_run_id: int,
    audio_temp_path: Optional[str] = None,
    transcript_temp_path: Optional[str] = None,
):
    """Process workflow completion: upload artifacts, run integrations, report usage.

    The steps run sequentially in one task so that integrations only start
    after the uploads have finished, and platform usage is reported to MPS
    last. Each step logs and swallows its own errors; a failure in one step
    does not prevent the following steps from running.

    Args:
        _ctx: ARQ context; forwarded to ``run_integrations_post_workflow_run``.
        workflow_run_id: The workflow run ID
        audio_temp_path: Optional path to temp audio file
        transcript_temp_path: Optional path to temp transcript file
    """
    run_id = str(workflow_run_id)
    set_current_run_id(run_id)

    logger.info(f"Processing workflow completion for run {workflow_run_id}")

    storage_backend = get_current_storage_backend()

    # Step 1: Upload audio if provided
    if audio_temp_path:
        await _upload_run_artifact(
            workflow_run_id,
            audio_temp_path,
            f"recordings/{workflow_run_id}.wav",
            "recording_url",
            "audio",
            storage_backend,
        )

    # Step 2: Upload transcript if provided
    if transcript_temp_path:
        await _upload_run_artifact(
            workflow_run_id,
            transcript_temp_path,
            f"transcripts/{workflow_run_id}.txt",
            "transcript_url",
            "transcript",
            storage_backend,
        )

    # Step 3: Run integrations including QA analysis (after uploads are complete)
    try:
        await run_integrations_post_workflow_run(_ctx, workflow_run_id)
    except Exception as e:
        logger.error(f"Error running integrations for workflow {workflow_run_id}: {e}")

    # Step 4: Notify MPS after completion. MPS owns credit accounting.
    try:
        await report_completed_workflow_run_platform_usage(workflow_run_id)
    except Exception as e:
        logger.error(
            f"Error reporting platform usage for workflow {workflow_run_id}: {e}"
        )

    logger.info(f"Completed workflow completion processing for run {workflow_run_id}")
|
||||
|
|
@ -203,7 +203,7 @@ async def create_workflow_run_rows(
|
|||
Returns:
|
||||
Tuple of (workflow_run, user, workflow).
|
||||
"""
|
||||
from api.schemas.user_configuration import UserConfiguration
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
|
||||
org = OrganizationModel(provider_id=f"test-org-{provider_id_suffix}")
|
||||
async_session.add(org)
|
||||
|
|
@ -218,7 +218,7 @@ async def create_workflow_run_rows(
|
|||
|
||||
await db_session.update_user_configuration(
|
||||
user_id=user.id,
|
||||
configuration=UserConfiguration.model_validate(USER_CONFIGURATION),
|
||||
configuration=EffectiveAIModelConfiguration.model_validate(USER_CONFIGURATION),
|
||||
)
|
||||
|
||||
workflow = await db_session.create_workflow(
|
||||
|
|
|
|||
119
api/tests/telephony/cloudonix/test_routes.py
Normal file
119
api/tests/telephony/cloudonix/test_routes.py
Normal file
|
|
@ -0,0 +1,119 @@
|
|||
"""Regression tests for Cloudonix CDR webhook handling.
|
||||
|
||||
A Cloudonix CDR webhook is a public, unauthenticated endpoint that parses
|
||||
arbitrary external JSON. A partial / malformed payload (missing ``session``,
|
||||
or a ``null`` ``session`` / ``disposition``) must produce a graceful error
|
||||
response, not an unhandled ``AttributeError`` (HTTP 500).
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from starlette.requests import Request
|
||||
|
||||
from api.services.telephony.providers.cloudonix.routes import handle_cloudonix_cdr
|
||||
from api.services.telephony.status_processor import StatusCallbackRequest
|
||||
|
||||
|
||||
def _json_request(body: bytes) -> Request:
    """Build a minimal POST ``Request`` for the Cloudonix CDR path.

    The request carries a JSON content-type header and delivers *body* as a
    single, final ASGI ``http.request`` message.
    """
    scope = {
        "type": "http",
        "method": "POST",
        "scheme": "https",
        "server": ("example.test", 443),
        "path": "/api/v1/telephony/cloudonix/cdr",
        "query_string": b"",
        "headers": [(b"content-type", b"application/json")],
    }

    async def _deliver_body():
        # One message with more_body=False: the whole payload arrives at once.
        return {"type": "http.request", "body": body, "more_body": False}

    return Request(scope, _deliver_body)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_cdr_route_handles_payload_without_session():
    """A CDR payload with no ``session`` key yields the "Missing call_id"
    error dict rather than an ``AttributeError`` on ``None.get("token")``."""
    payload = b'{"domain": "acme.cloudonix.io", "disposition": "ANSWER"}'
    cdr_request = _json_request(payload)
    db_patch = patch("api.services.telephony.providers.cloudonix.routes.db_client")

    with db_patch as db_client:
        # No workflow run is ever found for the (absent) call id.
        db_client.get_workflow_run_by_call_id = AsyncMock(return_value=None)
        result = await handle_cloudonix_cdr(cdr_request)

    assert result == {"status": "error", "message": "Missing call_id field"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_cdr_route_handles_null_session():
    """An explicit ``null`` session in the CDR payload is handled gracefully."""
    payload = b'{"domain": "acme.cloudonix.io", "session": null}'
    cdr_request = _json_request(payload)
    db_patch = patch("api.services.telephony.providers.cloudonix.routes.db_client")

    with db_patch as db_client:
        db_client.get_workflow_run_by_call_id = AsyncMock(return_value=None)
        result = await handle_cloudonix_cdr(cdr_request)

    assert result == {"status": "error", "message": "Missing call_id field"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_cdr_route_handles_string_session():
    """A CDR payload whose session is a string (not an object) is handled
    gracefully."""
    payload = b'{"domain": "acme.cloudonix.io", "session": "abc"}'
    cdr_request = _json_request(payload)
    db_patch = patch("api.services.telephony.providers.cloudonix.routes.db_client")

    with db_patch as db_client:
        db_client.get_workflow_run_by_call_id = AsyncMock(return_value=None)
        result = await handle_cloudonix_cdr(cdr_request)

    assert result == {"status": "error", "message": "Missing call_id field"}
|
||||
|
||||
|
||||
def test_from_cloudonix_cdr_tolerates_missing_session_and_disposition():
    """``from_cloudonix_cdr`` must not crash on a partial CDR payload."""
    partial_payloads = (
        # Missing both session and disposition.
        {"domain": "acme.cloudonix.io"},
        # Explicit null values.
        {"session": None, "disposition": None},
    )

    for payload in partial_payloads:
        parsed = StatusCallbackRequest.from_cloudonix_cdr(payload)
        assert parsed.call_id == ""
        assert parsed.status == ""
|
||||
|
||||
|
||||
def test_from_cloudonix_cdr_tolerates_string_session():
    """A non-object session yields an empty call_id; the ``ANSWER``
    disposition is still mapped to ``completed``."""
    payload = {"session": "abc", "disposition": "ANSWER"}

    parsed = StatusCallbackRequest.from_cloudonix_cdr(payload)

    assert parsed.call_id == ""
    assert parsed.status == "completed"
|
||||
|
||||
|
||||
def test_from_cloudonix_cdr_maps_disposition_and_session_token():
    """Normal, well-formed CDR payloads still map correctly."""
    well_formed = {
        "session": {"token": "abc123"},
        "disposition": "BUSY",
        "from": "+15551230001",
        "to": "+15551230002",
        "billsec": 12,
    }

    parsed = StatusCallbackRequest.from_cloudonix_cdr(well_formed)

    # The session token becomes the call id, the disposition is lower-cased to
    # a status, and the integer billsec is exposed as a string duration.
    assert parsed.call_id == "abc123"
    assert parsed.status == "busy"
    assert parsed.duration == "12"
|
||||
|
|
@ -6,11 +6,13 @@ from unittest.mock import AsyncMock, patch
|
|||
from urllib.parse import urlencode
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from starlette.requests import Request
|
||||
|
||||
from api.services.telephony.providers.vobiz.provider import VobizProvider
|
||||
from api.services.telephony.providers.vobiz.routes import (
|
||||
handle_vobiz_hangup_callback,
|
||||
handle_vobiz_hangup_callback_by_workflow,
|
||||
handle_vobiz_ring_callback,
|
||||
)
|
||||
|
||||
|
|
@ -225,3 +227,154 @@ async def test_vobiz_verify_inbound_signature_rejects_missing_signature():
|
|||
{},
|
||||
{},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_vobiz_hangup_callback_rejects_missing_signature():
    """An unsigned hangup callback must be rejected before status processing."""
    vobiz_provider = _provider()
    hangup_fields = {
        "CallUUID": "call-123",
        "CallStatus": "completed",
        "From": "15551230001",
        "To": "15551230002",
        "Direction": "outbound",
        "Duration": "12",
    }
    # No x-vobiz-signature-* headers — the callback is unsigned.
    unsigned_request = _request(
        path="/api/v1/telephony/vobiz/hangup-callback/123",
        form_data=hangup_fields,
    )
    run_row = SimpleNamespace(workflow_id=7)
    workflow_row = SimpleNamespace(organization_id=11)

    with (
        patch("api.services.telephony.providers.vobiz.routes.db_client") as db_client,
        patch(
            "api.services.telephony.providers.vobiz.routes.get_telephony_provider_for_run",
            new_callable=AsyncMock,
            return_value=vobiz_provider,
        ),
        patch(
            "api.services.telephony.providers.vobiz.routes.get_backend_endpoints",
            new_callable=AsyncMock,
            return_value=("https://example.test", "wss://example.test"),
        ),
        patch(
            "api.services.telephony.providers.vobiz.routes._process_status_update",
            new_callable=AsyncMock,
        ) as process_status,
    ):
        db_client.get_workflow_run_by_id = AsyncMock(return_value=run_row)
        db_client.get_workflow_by_id = AsyncMock(return_value=workflow_row)

        with pytest.raises(HTTPException) as rejection:
            await handle_vobiz_hangup_callback(
                workflow_run_id=123,
                request=unsigned_request,
            )

    assert rejection.value.status_code == 403
    # The status pipeline must never have been reached.
    process_status.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_vobiz_ring_callback_rejects_missing_signature():
    """An unsigned ring callback must be rejected before it is logged."""
    vobiz_provider = _provider()
    ring_fields = {
        "CallUUID": "call-123",
        "CallStatus": "ringing",
        "From": "15551230001",
        "To": "15551230002",
    }
    # No x-vobiz-signature-* headers — the callback is unsigned.
    unsigned_request = _request(
        path="/api/v1/telephony/vobiz/ring-callback/123",
        form_data=ring_fields,
    )

    run_row = SimpleNamespace(workflow_id=7, logs={})
    workflow_row = SimpleNamespace(organization_id=11)

    with (
        patch("api.services.telephony.providers.vobiz.routes.db_client") as db_client,
        patch(
            "api.services.telephony.providers.vobiz.routes.get_telephony_provider_for_run",
            new_callable=AsyncMock,
            return_value=vobiz_provider,
        ),
        patch(
            "api.services.telephony.providers.vobiz.routes.get_backend_endpoints",
            new_callable=AsyncMock,
            return_value=("https://example.test", "wss://example.test"),
        ),
    ):
        db_client.get_workflow_run_by_id = AsyncMock(return_value=run_row)
        db_client.get_workflow_by_id = AsyncMock(return_value=workflow_row)
        db_client.update_workflow_run = AsyncMock()

        with pytest.raises(HTTPException) as rejection:
            await handle_vobiz_ring_callback(
                workflow_run_id=123,
                request=unsigned_request,
            )

    assert rejection.value.status_code == 403
    # Nothing may have been written to the workflow run.
    db_client.update_workflow_run.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_vobiz_hangup_callback_by_workflow_rejects_missing_signature():
    """An unsigned by-workflow hangup callback must be rejected before processing."""
    vobiz_provider = _provider()
    hangup_fields = {
        "CallUUID": "call-123",
        "CallStatus": "completed",
        "From": "15551230001",
        "To": "15551230002",
        "Direction": "outbound",
        "Duration": "12",
    }
    # No x-vobiz-signature-* headers — the callback is unsigned.
    unsigned_request = _request(
        path="/api/v1/telephony/vobiz/hangup-callback/workflow/7",
        form_data=hangup_fields,
    )
    workflow_row = SimpleNamespace(organization_id=11)
    run_row = SimpleNamespace(id=123, workflow_id=7)

    with (
        patch("api.services.telephony.providers.vobiz.routes.db_client") as db_client,
        patch(
            "api.services.telephony.providers.vobiz.routes.get_telephony_provider_for_run",
            new_callable=AsyncMock,
            return_value=vobiz_provider,
        ),
        patch(
            "api.services.telephony.providers.vobiz.routes.get_backend_endpoints",
            new_callable=AsyncMock,
            return_value=("https://example.test", "wss://example.test"),
        ),
        patch(
            "api.services.telephony.providers.vobiz.routes._process_status_update",
            new_callable=AsyncMock,
        ) as process_status,
    ):
        db_client.get_workflow_by_id = AsyncMock(return_value=workflow_row)
        db_client.get_workflow_run_by_call_id = AsyncMock(return_value=run_row)

        with pytest.raises(HTTPException) as rejection:
            await handle_vobiz_hangup_callback_by_workflow(
                workflow_id=7,
                request=unsigned_request,
            )

    assert rejection.value.status_code == 403
    # The status pipeline must never have been reached.
    process_status.assert_not_awaited()
|
||||
|
|
|
|||
459
api/tests/test_ai_model_configuration_v2.py
Normal file
459
api/tests/test_ai_model_configuration_v2.py
Normal file
|
|
@ -0,0 +1,459 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from api.schemas.ai_model_configuration import (
|
||||
DograhManagedAIModelConfiguration,
|
||||
EffectiveAIModelConfiguration,
|
||||
OrganizationAIModelConfigurationResponse,
|
||||
OrganizationAIModelConfigurationV2,
|
||||
compile_ai_model_configuration_v2,
|
||||
)
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY,
|
||||
check_for_masked_keys_in_ai_model_configuration_v2,
|
||||
convert_legacy_ai_model_configuration_to_v2,
|
||||
mask_ai_model_configuration_v2,
|
||||
merge_ai_model_configuration_v2_secrets,
|
||||
migrate_workflow_configuration_model_override_to_v2,
|
||||
)
|
||||
from api.services.configuration.check_validity import UserConfigurationValidator
|
||||
from api.services.configuration.masking import mask_key
|
||||
from api.services.configuration.registry import (
|
||||
DeepgramSTTConfiguration,
|
||||
DograhLLMService,
|
||||
DograhSTTService,
|
||||
DograhTTSService,
|
||||
ElevenlabsTTSConfiguration,
|
||||
GoogleLLMService,
|
||||
GoogleRealtimeLLMConfiguration,
|
||||
OpenAIEmbeddingsConfiguration,
|
||||
OpenAILLMService,
|
||||
)
|
||||
|
||||
|
||||
def test_dograh_v2_compiles_to_effective_managed_pipeline_with_embeddings():
    """A ``dograh``-mode v2 config compiles into a non-realtime effective
    configuration whose LLM, TTS, STT and embeddings all use the ``dograh``
    provider."""
    managed = DograhManagedAIModelConfiguration(
        api_key="mps-secret",
        voice="default",
        speed=1.2,
        language="multi",
    )
    config = OrganizationAIModelConfigurationV2(mode="dograh", dograh=managed)

    effective = compile_ai_model_configuration_v2(config)

    assert effective.is_realtime is False

    # LLM
    assert effective.llm.provider == "dograh"
    assert effective.llm.model == "default"

    # TTS carries the configured speed.
    assert effective.tts.provider == "dograh"
    assert effective.tts.speed == 1.2

    # STT carries the configured language.
    assert effective.stt.provider == "dograh"
    assert effective.stt.language == "multi"

    # Embeddings
    assert effective.embeddings.provider == "dograh"
    assert effective.embeddings.model == "default"

    assert effective.managed_service_version == 2
|
||||
|
||||
|
||||
def test_dograh_v2_rejects_non_predefined_speed():
    """Building a dograh-mode config with ``speed=1.5`` raises
    ``ValidationError``."""
    with pytest.raises(ValidationError):
        unsupported = DograhManagedAIModelConfiguration(
            api_key="mps-secret",
            speed=1.5,
        )
        OrganizationAIModelConfigurationV2(mode="dograh", dograh=unsupported)
|
||||
|
||||
|
||||
def test_byok_v2_rejects_dograh_provider():
    """A BYOK pipeline whose services all use the ``dograh`` provider fails
    validation."""
    llm_section = {
        "provider": "dograh",
        "api_key": "mps-secret",
        "model": "default",
    }
    tts_section = {
        "provider": "dograh",
        "api_key": "mps-secret",
        "model": "default",
        "voice": "default",
    }
    stt_section = {
        "provider": "dograh",
        "api_key": "mps-secret",
        "model": "default",
    }
    raw_config = {
        "mode": "byok",
        "byok": {
            "mode": "pipeline",
            "pipeline": {
                "llm": llm_section,
                "tts": tts_section,
                "stt": stt_section,
            },
        },
    }

    with pytest.raises(ValidationError):
        OrganizationAIModelConfigurationV2.model_validate(raw_config)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_byok_realtime_validator_does_not_require_stt_or_tts():
    """A BYOK realtime config compiles with ``stt``/``tts`` unset and still
    passes ``UserConfigurationValidator``."""
    realtime_section = {
        "provider": "google_realtime",
        "api_key": "google-realtime-key",
        "model": "gemini-3.1-flash-live-preview",
        "voice": "Puck",
        "language": "en",
    }
    llm_section = {
        "provider": "google",
        "api_key": "google-llm-key",
        "model": "gemini-2.0-flash",
    }
    raw_config = {
        "mode": "byok",
        "byok": {
            "mode": "realtime",
            "realtime": {
                "realtime": realtime_section,
                "llm": llm_section,
            },
        },
    }

    config = OrganizationAIModelConfigurationV2.model_validate(raw_config)
    effective = compile_ai_model_configuration_v2(config)

    assert effective.is_realtime is True
    assert effective.stt is None
    assert effective.tts is None

    verdict = await UserConfigurationValidator().validate(effective)
    assert verdict == {"status": [{"model": "all", "message": "ok"}]}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_pipeline_validator_requires_stt_and_tts_when_not_realtime():
    """With ``is_realtime=False`` and no STT/TTS configured, validation raises
    ``ValueError`` reporting a missing API key for both services."""
    text_llm = GoogleLLMService(
        provider="google",
        api_key="google-llm-key",
        model="gemini-2.0-flash",
    )
    realtime_llm = GoogleRealtimeLLMConfiguration(
        provider="google_realtime",
        api_key="google-realtime-key",
        model="gemini-3.1-flash-live-preview",
        voice="Puck",
        language="en",
    )
    # A realtime service is present, but the pipeline is flagged non-realtime.
    effective = EffectiveAIModelConfiguration(
        llm=text_llm,
        realtime=realtime_llm,
        is_realtime=False,
    )

    with pytest.raises(ValueError) as failure:
        await UserConfigurationValidator().validate(effective)

    reported_problems = failure.value.args[0]
    assert reported_problems == [
        {"model": "stt", "message": "API key is missing"},
        {"model": "tts", "message": "API key is missing"},
    ]
|
||||
|
||||
|
||||
def test_masked_dograh_key_is_preserved_when_saving_same_mode():
    """Merging an incoming config that carries the masked key with the stored
    config restores the real key, and the merged result passes the
    masked-key check."""
    stored_managed = DograhManagedAIModelConfiguration(api_key="mps-real-secret")
    existing = OrganizationAIModelConfigurationV2(mode="dograh", dograh=stored_managed)

    masked_managed = DograhManagedAIModelConfiguration(
        api_key=mask_key("mps-real-secret")
    )
    incoming = OrganizationAIModelConfigurationV2(mode="dograh", dograh=masked_managed)

    merged = merge_ai_model_configuration_v2_secrets(incoming, existing)

    assert merged.dograh.api_key == "mps-real-secret"
    check_for_masked_keys_in_ai_model_configuration_v2(merged)
|
||||
|
||||
|
||||
def test_masked_v2_configuration_masks_nested_service_keys():
    """``mask_ai_model_configuration_v2`` masks the API key of each nested
    BYOK pipeline service (llm, tts, stt)."""
    byok_section = {
        "mode": "pipeline",
        "pipeline": {
            "llm": {
                "provider": "openai",
                "api_key": "sk-real-secret",
                "model": "gpt-4.1",
            },
            "tts": {
                "provider": "elevenlabs",
                "api_key": "el-real-secret",
                "model": "eleven_flash_v2_5",
                "voice": "Rachel",
            },
            "stt": {
                "provider": "deepgram",
                "api_key": "dg-real-secret",
                "model": "nova-3-general",
            },
        },
    }
    config = OrganizationAIModelConfigurationV2(mode="byok", byok=byok_section)

    masked = mask_ai_model_configuration_v2(config)

    masked_pipeline = masked["byok"]["pipeline"]
    assert masked_pipeline["llm"]["api_key"] == mask_key("sk-real-secret")
    assert masked_pipeline["tts"]["api_key"] == mask_key("el-real-secret")
    assert masked_pipeline["stt"]["api_key"] == mask_key("dg-real-secret")
|
||||
|
||||
|
||||
def test_legacy_all_dograh_pipeline_converts_to_dograh_v2():
    """A legacy config whose LLM, TTS and STT all use the dograh provider (API
    keys given as one-element lists) converts to ``dograh`` mode with the key
    as a plain string."""
    legacy_llm = DograhLLMService(
        provider="dograh",
        api_key=["mps-secret"],
        model="default",
    )
    legacy_tts = DograhTTSService(
        provider="dograh",
        api_key=["mps-secret"],
        model="default",
        voice="default",
        speed=1.0,
    )
    legacy_stt = DograhSTTService(
        provider="dograh",
        api_key=["mps-secret"],
        model="default",
        language="multi",
    )
    legacy = EffectiveAIModelConfiguration(
        llm=legacy_llm,
        tts=legacy_tts,
        stt=legacy_stt,
    )

    converted = convert_legacy_ai_model_configuration_to_v2(legacy)

    assert converted.mode == "dograh"
    assert converted.dograh.api_key == "mps-secret"
|
||||
|
||||
|
||||
def test_legacy_mixed_dograh_pipeline_converts_to_dograh_v2():
    """A legacy config mixing an OpenAI LLM/embeddings with dograh TTS/STT
    converts to ``dograh`` mode, taking the API key from the TTS service."""
    openai_llm = OpenAILLMService(
        provider="openai",
        api_key="sk-llm",
        model="gpt-4.1",
    )
    dograh_tts = DograhTTSService(
        provider="dograh",
        api_key="mps-tts",
        model="default",
        voice="default",
    )
    dograh_stt = DograhSTTService(
        provider="dograh",
        api_key="mps-stt",
        model="default",
    )
    openai_embeddings = OpenAIEmbeddingsConfiguration(
        provider="openai",
        api_key="sk-emb",
        model="text-embedding-3-small",
    )
    legacy = EffectiveAIModelConfiguration(
        llm=openai_llm,
        tts=dograh_tts,
        stt=dograh_stt,
        embeddings=openai_embeddings,
    )

    converted = convert_legacy_ai_model_configuration_to_v2(legacy)

    assert converted.mode == "dograh"
    assert converted.dograh.api_key == "mps-tts"
    assert converted.dograh.voice == "default"
|
||||
|
||||
|
||||
def test_legacy_byok_pipeline_converts_to_byok_v2():
    """A legacy config with no dograh services converts to a BYOK pipeline
    that keeps the LLM and TTS providers."""
    openai_llm = OpenAILLMService(
        provider="openai",
        api_key="sk-llm",
        model="gpt-4.1",
    )
    elevenlabs_tts = ElevenlabsTTSConfiguration(
        provider="elevenlabs",
        api_key="el-tts",
        model="eleven_flash_v2_5",
        voice="Rachel",
    )
    deepgram_stt = DeepgramSTTConfiguration(
        provider="deepgram",
        api_key="dg-stt",
        model="nova-3-general",
    )
    openai_embeddings = OpenAIEmbeddingsConfiguration(
        provider="openai",
        api_key="sk-emb",
        model="text-embedding-3-small",
    )
    legacy = EffectiveAIModelConfiguration(
        llm=openai_llm,
        tts=elevenlabs_tts,
        stt=deepgram_stt,
        embeddings=openai_embeddings,
    )

    converted = convert_legacy_ai_model_configuration_to_v2(legacy)

    assert converted.mode == "byok"
    assert converted.byok.mode == "pipeline"
    assert converted.byok.pipeline.llm.provider == "openai"
    assert converted.byok.pipeline.tts.provider == "elevenlabs"
|
||||
|
||||
|
||||
def test_workflow_model_override_migration_removes_v1_override_and_sets_v2():
    """Migrating a workflow with a v1 ``model_overrides`` entry drops that key,
    keeps unrelated settings, and stores a v2 override in ``dograh`` mode."""
    base = EffectiveAIModelConfiguration(
        llm=OpenAILLMService(
            provider="openai",
            api_key="sk-llm",
            model="gpt-4.1",
        ),
        tts=ElevenlabsTTSConfiguration(
            provider="elevenlabs",
            api_key="el-tts",
            model="eleven_flash_v2_5",
            voice="Rachel",
        ),
        stt=DeepgramSTTConfiguration(
            provider="deepgram",
            api_key="dg-stt",
            model="nova-3-general",
        ),
    )
    # v1-style override: only TTS is switched to the dograh provider.
    tts_override = {
        "provider": "dograh",
        "api_key": "mps-workflow",
        "model": "default",
        "voice": "default",
    }
    workflow_configurations = {
        "ambient_noise_configuration": {"enabled": False},
        "model_overrides": {"tts": tts_override},
    }

    outcome = migrate_workflow_configuration_model_override_to_v2(
        workflow_configurations,
        base,
    )
    migrated, changed = outcome

    assert changed is True
    assert "model_overrides" not in migrated
    assert migrated["ambient_noise_configuration"] == {"enabled": False}

    v2_override = migrated[WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY]
    assert v2_override["mode"] == "dograh"
    assert v2_override["dograh"]["api_key"] == "mps-workflow"
|
||||
|
||||
|
||||
def test_workflow_model_override_migration_removes_invalid_v1_override_marker():
    """A ``model_overrides`` key holding ``None`` is removed and reported as a
    change; unrelated settings are left intact."""
    workflow_configurations = {
        "ambient_noise_configuration": {"enabled": False},
        "model_overrides": None,
    }
    empty_base = EffectiveAIModelConfiguration()

    outcome = migrate_workflow_configuration_model_override_to_v2(
        workflow_configurations,
        empty_base,
    )
    migrated, changed = outcome

    assert changed is True
    assert "model_overrides" not in migrated
    assert migrated["ambient_noise_configuration"] == {"enabled": False}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_migrate_model_configuration_v2_initializes_hosted_mps_billing(
    monkeypatch,
):
    """In ``saas`` deployment mode with no stored v2 config, migrating a legacy
    all-dograh config ensures the hosted MPS billing account, upserts the v2
    config and migrates workflows."""
    from api.routes import organization as organization_routes

    class FakeValidator:
        async def validate(self, *args, **kwargs):
            return {"status": [{"model": "all", "message": "ok"}]}

    legacy = EffectiveAIModelConfiguration(
        llm=DograhLLMService(
            provider="dograh",
            api_key=["mps-secret"],
            model="default",
        ),
        tts=DograhTTSService(
            provider="dograh",
            api_key=["mps-secret"],
            model="default",
            voice="default",
        ),
        stt=DograhSTTService(
            provider="dograh",
            api_key=["mps-secret"],
            model="default",
        ),
    )
    expected_response = OrganizationAIModelConfigurationResponse(
        configuration={"version": 2, "mode": "dograh"},
        effective_configuration={},
        source="organization_v2",
    )

    ensure_billing = AsyncMock(return_value={"billing_mode": "v2"})
    upsert = AsyncMock()
    migrate_workflows = AsyncMock()

    # Attributes replaced directly on the organization routes module.
    route_overrides = {
        "DEPLOYMENT_MODE": "saas",
        "get_organization_ai_model_configuration_v2": AsyncMock(return_value=None),
        "UserConfigurationValidator": lambda: FakeValidator(),
        "ensure_hosted_mps_billing_account_v2": ensure_billing,
        "upsert_organization_ai_model_configuration_v2": upsert,
        "migrate_workflow_model_configurations_to_v2": migrate_workflows,
        "_model_configuration_v2_response": AsyncMock(return_value=expected_response),
    }
    for attribute, replacement in route_overrides.items():
        monkeypatch.setattr(organization_routes, attribute, replacement)

    # The legacy configuration is what the database layer returns.
    monkeypatch.setattr(
        organization_routes.db_client,
        "get_user_configurations",
        AsyncMock(return_value=legacy),
    )

    user = SimpleNamespace(
        id=7,
        provider_id="provider-123",
        selected_organization_id=42,
    )

    response = await organization_routes.migrate_model_configuration_v2(
        force=False,
        user=user,
    )

    ensure_billing.assert_awaited_once_with(42, created_by="provider-123")
    upsert.assert_awaited_once()
    migrate_workflows.assert_awaited_once_with(
        organization_id=42,
        fallback_user_config=legacy,
    )
    assert response == expected_response
|
||||
68
api/tests/test_auth_depends.py
Normal file
68
api/tests/test_auth_depends.py
Normal file
|
|
@ -0,0 +1,68 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from api.services.auth import depends as auth_depends
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_user_initializes_hosted_mps_billing_for_new_org(monkeypatch):
    """When the organization lookup reports a newly created org, ``get_user``
    selects it for the user and ensures the hosted MPS billing account."""
    stack_user = {
        "id": "stack-user-1",
        "selected_team_id": "team-1",
        "primary_email_verified": False,
    }
    user = SimpleNamespace(
        id=7,
        email=None,
        provider_id="stack-user-1",
        selected_organization_id=None,
    )
    organization = SimpleNamespace(id=42)
    existing_config = SimpleNamespace(llm=object(), tts=None, stt=None)
    ensure_billing = AsyncMock(return_value={"billing_mode": "v2"})

    monkeypatch.setattr(auth_depends, "AUTH_PROVIDER", "stack")
    monkeypatch.setattr(
        auth_depends.stackauth,
        "get_user",
        AsyncMock(return_value=stack_user),
    )

    # Database layer stubs: existing user (created=False), new org (created=True).
    db_stubs = {
        "get_or_create_user_by_provider_id": AsyncMock(return_value=(user, False)),
        "get_or_create_organization_by_provider_id": AsyncMock(
            return_value=(organization, True)
        ),
        "add_user_to_organization": AsyncMock(),
        "update_user_selected_organization": AsyncMock(),
        "get_user_configurations": AsyncMock(return_value=existing_config),
    }
    for method_name, stub in db_stubs.items():
        monkeypatch.setattr(auth_depends.db_client, method_name, stub)

    monkeypatch.setattr(
        auth_depends,
        "ensure_hosted_mps_billing_account_v2",
        ensure_billing,
    )

    resolved = await auth_depends.get_user(authorization="Bearer token")

    assert resolved is user
    assert resolved.selected_organization_id == 42
    ensure_billing.assert_awaited_once_with(42, created_by="stack-user-1")
|
||||
45
api/tests/test_cartesia_tts_service_factory.py
Normal file
45
api/tests/test_cartesia_tts_service_factory.py
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
from api.services.configuration.registry import (
|
||||
CARTESIA_TTS_MODELS,
|
||||
CartesiaTTSConfiguration,
|
||||
ServiceProviders,
|
||||
)
|
||||
from api.services.pipecat.service_factory import create_tts_service
|
||||
|
||||
|
||||
def test_cartesia_tts_configuration_defaults_to_sonic_3_5():
    """A Cartesia TTS config built with only an API key uses the Cartesia
    provider and the ``sonic-3.5`` model; the model list has the expected
    order."""
    default_config = CartesiaTTSConfiguration(api_key="test-key")

    assert default_config.provider == ServiceProviders.CARTESIA
    assert default_config.model == "sonic-3.5"
    assert CARTESIA_TTS_MODELS == ["sonic-3.5", "sonic-3"]
|
||||
|
||||
|
||||
def test_create_cartesia_tts_service_passes_selected_model():
    """``create_tts_service`` constructs ``CartesiaTTSService`` exactly once,
    passing the configured API key and settings with the chosen model and
    voice."""
    tts_settings = SimpleNamespace(
        provider=ServiceProviders.CARTESIA.value,
        api_key="test-key",
        model="sonic-3.5",
        voice="test-voice-id",
        speed=1.0,
        volume=1.0,
    )
    user_config = SimpleNamespace(tts=tts_settings)
    audio_config = SimpleNamespace(
        transport_out_sample_rate=24000,
        transport_in_sample_rate=16000,
    )

    service_patch = patch("api.services.pipecat.service_factory.CartesiaTTSService")
    with service_patch as mock_service:
        create_tts_service(user_config, audio_config)

    assert mock_service.call_count == 1
    passed_kwargs = mock_service.call_args.kwargs
    assert passed_kwargs["api_key"] == "test-key"
    assert passed_kwargs["settings"].model == "sonic-3.5"
    assert passed_kwargs["settings"].voice == "test-voice-id"
|
||||
|
|
@ -1,31 +0,0 @@
|
|||
from api.services.pricing.cost_calculator import cost_calculator
|
||||
|
||||
|
||||
def test_cost_calculator():
    """Verify per-service and total cost calculation for a sample usage record.

    Per-service costs are compared with ``math.isclose`` instead of ``==``:
    the expected values are floating-point expressions, so an exact match
    would depend on the calculator evaluating them in the same operation
    order as this test.
    """
    # Local import keeps this block self-contained; the module's top-level
    # imports do not include ``math``.
    import math

    sample_usage = {
        "llm": {
            "OpenAILLMService#0|||gpt-4.1-mini": {
                "prompt_tokens": 45380,
                "completion_tokens": 496,
                "total_tokens": 45876,
                "cache_read_input_tokens": 0,
                "cache_creation_input_tokens": 0,
            }
        },
        "tts": {"ElevenLabsTTSService#0|||eleven_flash_v2_5": 2399},
        "stt": {"DeepgramSTTService#0|||nova-3-general": 177.21536946296692},
        "call_duration_seconds": 179,
    }

    result = cost_calculator.calculate_total_cost(sample_usage)

    # Expected values use the same rates as the original assertions.
    expected_llm_cost = 45380 * 0.40 / 1_000_000 + 496 * 1.60 / 1_000_000
    expected_tts_cost = 2399 * 0.0256 / 1_000
    expected_stt_cost = 177.21536946296692 / 60 * 0.0077

    assert math.isclose(result["llm_cost"], expected_llm_cost, rel_tol=1e-9)
    assert math.isclose(result["tts_cost"], expected_tts_cost, rel_tol=1e-9)
    assert math.isclose(result["stt_cost"], expected_stt_cost, rel_tol=1e-9)

    # The total must equal the sum of its parts within an absolute tolerance.
    component_sum = result["llm_cost"] + result["tts_cost"] + result["stt_cost"]
    assert abs(result["total"] - component_sum) < 1e-10
|
||||
110
api/tests/test_dograh_managed_correlation.py
Normal file
110
api/tests/test_dograh_managed_correlation.py
Normal file
|
|
@ -0,0 +1,110 @@
|
|||
import json
|
||||
|
||||
import pytest
|
||||
from openai._types import NOT_GIVEN as OPENAI_NOT_GIVEN
|
||||
from pipecat.frames.frames import TTSStartedFrame
|
||||
from pipecat.services.dograh.llm import DograhLLMService
|
||||
from pipecat.services.dograh.stt import DograhSTTService
|
||||
from pipecat.services.dograh.tts import DograhTTSService
|
||||
from pipecat.services.openai.base_llm import OpenAILLMSettings
|
||||
from websockets.protocol import State
|
||||
|
||||
|
||||
class _FakeWebSocket:
    """In-memory websocket stand-in that records sent JSON messages."""

    def __init__(self):
        # Decoded JSON payloads, in the order they were sent.
        self.messages: list[dict] = []
        self.state = State.OPEN

    async def send(self, message: str) -> None:
        """Decode *message* as JSON and record it."""
        decoded = json.loads(message)
        self.messages.append(decoded)

    async def close(self, *args, **kwargs) -> None:
        """Mark the socket closed; all arguments are ignored."""
        self.state = State.CLOSED
|
||||
|
||||
|
||||
def test_dograh_llm_uses_explicit_mps_correlation_id():
    """Chat-completion params carry the explicitly supplied correlation id and
    billing version ``"2"`` in their metadata."""
    llm_service = DograhLLMService(
        api_key="mps-secret",
        correlation_id="mps-corr-123",
        settings=OpenAILLMSettings(model="default"),
    )
    llm_service._start_metadata = {"workflow_run_id": 99}

    invocation = {
        "messages": [],
        "tools": OPENAI_NOT_GIVEN,
        "tool_choice": OPENAI_NOT_GIVEN,
    }
    params = llm_service.build_chat_completion_params(invocation)

    metadata = params["metadata"]
    assert metadata["correlation_id"] == "mps-corr-123"
    assert metadata["mps_billing_version"] == "2"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_dograh_stt_config_uses_explicit_mps_correlation_id(monkeypatch):
    """The first STT websocket message is a config with the explicit correlation id."""
    socket = _FakeWebSocket()

    async def fake_connect(url, additional_headers):
        # Hand back the recording socket instead of opening a connection.
        return socket

    monkeypatch.setattr(
        "pipecat.services.dograh.stt.websocket_connect", fake_connect
    )

    service = DograhSTTService(
        api_key="mps-secret",
        correlation_id="mps-corr-123",
        sample_rate=16000,
    )
    service._start_metadata = {"workflow_run_id": 99}
    await service._connect_websocket()

    config_message = socket.messages[0]
    assert config_message["type"] == "config"
    assert config_message["correlation_id"] == "mps-corr-123"
    assert config_message["mps_billing_version"] == "2"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_dograh_tts_messages_use_explicit_mps_correlation_id(monkeypatch):
    """Both the TTS config and create_context messages carry the explicit correlation id."""
    socket = _FakeWebSocket()

    async def fake_connect(url, additional_headers):
        # Hand back the recording socket instead of opening a connection.
        return socket

    monkeypatch.setattr(
        "pipecat.services.dograh.tts.websocket_connect", fake_connect
    )

    service = DograhTTSService(
        api_key="mps-secret",
        correlation_id="mps-corr-123",
        sample_rate=24000,
    )
    service._start_metadata = {"workflow_run_id": 99}

    await service._connect_websocket()
    config_message = socket.messages[0]
    assert config_message["type"] == "config"
    assert config_message["correlation_id"] == "mps-corr-123"
    assert config_message["mps_billing_version"] == "2"

    async def _noop(*args, **kwargs):
        return None

    def _no_audio_context(context_id):
        return False

    # Replace the audio-context and metrics hooks with inert stand-ins before
    # driving run_tts.
    service.audio_context_available = _no_audio_context
    service.create_audio_context = _noop
    service.start_ttfb_metrics = _noop
    service.start_tts_usage_metrics = _noop

    frames = [frame async for frame in service.run_tts("hello", "ctx-1")]

    assert isinstance(frames[0], TTSStartedFrame)
    context_message = socket.messages[1]
    assert context_message["type"] == "create_context"
    assert context_message["correlation_id"] == "mps-corr-123"
    assert context_message["mps_billing_version"] == "2"
|
||||
|
|
@ -270,6 +270,12 @@ class TestDispatcherThreadsTelephonyConfig:
|
|||
"api.services.campaign.campaign_call_dispatcher.get_backend_endpoints",
|
||||
AsyncMock(return_value=("https://example.com", None)),
|
||||
),
|
||||
patch(
|
||||
"api.services.campaign.campaign_call_dispatcher.authorize_workflow_run_start",
|
||||
AsyncMock(
|
||||
return_value=SimpleNamespace(has_quota=True, error_message="")
|
||||
),
|
||||
),
|
||||
):
|
||||
mock_db.get_workflow_by_id = AsyncMock(return_value=SimpleNamespace(id=1))
|
||||
mock_db.create_workflow_run = AsyncMock(return_value=workflow_run)
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from pipecat.processors.aggregators.llm_context import LLMContext
|
|||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.xai.realtime import events
|
||||
|
||||
from api.schemas.user_configuration import UserConfiguration
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.configuration.registry import GrokRealtimeLLMConfiguration
|
||||
from api.services.pipecat.realtime.grok_realtime import (
|
||||
DograhGrokRealtimeLLMService,
|
||||
|
|
@ -120,7 +120,7 @@ async def test_completed_input_transcription_is_broadcast_as_finalized():
|
|||
|
||||
|
||||
def test_factory_creates_dograh_grok_realtime_service():
|
||||
user_config = UserConfiguration(
|
||||
effective_config = EffectiveAIModelConfiguration(
|
||||
is_realtime=True,
|
||||
realtime=GrokRealtimeLLMConfiguration(
|
||||
provider="grok_realtime",
|
||||
|
|
@ -131,7 +131,7 @@ def test_factory_creates_dograh_grok_realtime_service():
|
|||
)
|
||||
|
||||
service = create_realtime_llm_service(
|
||||
user_config,
|
||||
effective_config,
|
||||
audio_config=SimpleNamespace(),
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from api.routes.user import router
|
||||
from api.schemas.user_configuration import UserConfiguration
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.configuration.masking import mask_key
|
||||
from api.services.configuration.registry import (
|
||||
|
|
@ -14,14 +15,14 @@ from api.services.configuration.registry import (
|
|||
)
|
||||
|
||||
|
||||
def _make_test_app():
|
||||
def _make_test_app(selected_organization_id=None):
|
||||
app = FastAPI()
|
||||
app.include_router(router)
|
||||
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 1
|
||||
mock_user.is_superuser = False
|
||||
mock_user.selected_organization_id = None
|
||||
mock_user.selected_organization_id = selected_organization_id
|
||||
|
||||
app.dependency_overrides[get_user] = lambda: mock_user
|
||||
return app
|
||||
|
|
@ -32,7 +33,7 @@ MASKED_KEY = mask_key(REAL_KEY) # "**************************cdef"
|
|||
|
||||
|
||||
def _existing_openai_config():
|
||||
return UserConfiguration(
|
||||
return EffectiveAIModelConfiguration(
|
||||
llm=OpenAILLMService(
|
||||
provider="openai",
|
||||
api_key=REAL_KEY,
|
||||
|
|
@ -110,7 +111,7 @@ class TestMaskedKeyRejection:
|
|||
client = TestClient(app)
|
||||
|
||||
new_key = "AIzaSyNewRealKey12345678"
|
||||
updated = UserConfiguration(
|
||||
updated = EffectiveAIModelConfiguration(
|
||||
llm=GoogleLLMService(
|
||||
provider="google",
|
||||
api_key=new_key,
|
||||
|
|
@ -177,7 +178,7 @@ class TestMaskedKeyRejection:
|
|||
|
||||
real_credentials = '{"type":"service_account","project_id":"demo-project"}'
|
||||
masked_credentials = mask_key(real_credentials)
|
||||
existing = UserConfiguration(
|
||||
existing = EffectiveAIModelConfiguration(
|
||||
llm=GoogleVertexLLMConfiguration(
|
||||
provider="google_vertex",
|
||||
api_key=None,
|
||||
|
|
@ -210,3 +211,38 @@ class TestMaskedKeyRejection:
|
|||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_preference_only_update_does_not_validate_or_save_model_config(self):
    """A PUT carrying only test_phone_number must not validate or save model config."""
    client = TestClient(_make_test_app(selected_organization_id=11))
    preferences = SimpleNamespace(test_phone_number=None, timezone=None)
    get_preferences = AsyncMock(return_value=preferences)
    upsert_preferences = AsyncMock(return_value=preferences)

    with (
        patch("api.routes.user.db_client") as mock_db,
        patch("api.routes.user.UserConfigurationValidator") as mock_validator,
        patch("api.routes.user.get_organization_preferences", new=get_preferences),
        patch(
            "api.routes.user.upsert_organization_preferences",
            new=upsert_preferences,
        ),
    ):
        mock_db.get_user_configurations = AsyncMock(
            return_value=_existing_openai_config()
        )
        mock_db.update_user_configuration = AsyncMock()
        mock_db.get_organization_by_id = AsyncMock(return_value=None)
        mock_validator.return_value.validate = AsyncMock()

        response = client.put(
            "/user/configurations/user",
            json={"test_phone_number": "+15551234567"},
        )

    assert response.status_code == 200
    assert response.json()["test_phone_number"] == "+15551234567"
    # Only the preferences path may have been exercised.
    mock_db.update_user_configuration.assert_not_called()
    mock_validator.return_value.validate.assert_not_called()
    upsert_preferences.assert_awaited_once()
|
||||
|
|
|
|||
|
|
@ -87,3 +87,387 @@ async def test_check_service_key_usage_uses_bearer_self_usage(monkeypatch):
|
|||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_create_correlation_id_uses_bearer_auth(monkeypatch):
    """create_correlation_id POSTs the run id to the self endpoint with a Bearer key."""
    recorded = []

    class FakeAsyncClient:
        """Async-context-manager double that records POST calls."""

        def __init__(self, timeout):
            self.timeout = timeout

        async def __aenter__(self):
            return self

        async def __aexit__(self, exc_type, exc, tb):
            return None

        async def post(self, url, json, headers):
            recorded.append(("POST", url, json, headers))
            return _Response(200, {"correlation_id": "mps-corr-123"})

    monkeypatch.setattr(
        "api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient
    )

    client = MPSServiceKeyClient()
    result = await client.create_correlation_id(
        service_key="mps_sk_paid",
        workflow_run_id=42,
    )
    assert result == {"correlation_id": "mps-corr-123"}

    expected_call = (
        "POST",
        f"{client.base_url}/api/v1/service-keys/correlation-id/self",
        {"workflow_run_id": 42},
        {
            "Authorization": "Bearer mps_sk_paid",
            "Content-Type": "application/json",
        },
    )
    assert recorded == [expected_call]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_billing_account_status_uses_hosted_org_auth(monkeypatch):
    """The status lookup GETs the account status URL with secret-key and org headers."""
    recorded = []

    class FakeAsyncClient:
        """Async-context-manager double that records GET calls."""

        def __init__(self, timeout):
            self.timeout = timeout

        async def __aenter__(self):
            return self

        async def __aexit__(self, exc_type, exc, tb):
            return None

        async def get(self, url, headers):
            recorded.append(("GET", url, headers))
            return _Response(200, {"organization_id": 42, "billing_mode": "v2"})

    module = "api.services.mps_service_key_client"
    monkeypatch.setattr(
        "api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient
    )
    monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas")
    monkeypatch.setattr(
        "api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret"
    )
    del module  # kept only as a reading aid for the three patch targets above

    client = MPSServiceKeyClient()
    status = await client.get_billing_account_status(organization_id=42)
    assert status == {
        "organization_id": 42,
        "billing_mode": "v2",
    }

    expected_call = (
        "GET",
        f"{client.base_url}/api/v1/billing/accounts/42/status",
        {
            "Content-Type": "application/json",
            "X-Secret-Key": "mps-secret",
            "X-Organization-Id": "42",
        },
    )
    assert recorded == [expected_call]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_authorize_workflow_run_start_uses_hosted_org_auth(monkeypatch):
    """Run authorization POSTs to the account endpoint with secret-key and org headers."""
    recorded = []

    class FakeAsyncClient:
        """Async-context-manager double that records POST calls."""

        def __init__(self, timeout):
            self.timeout = timeout

        async def __aenter__(self):
            return self

        async def __aexit__(self, exc_type, exc, tb):
            return None

        async def post(self, url, json, headers):
            recorded.append(("POST", url, json, headers))
            body = {
                "allowed": True,
                "billing_mode": "v2",
                "remaining_credits": "25.0000",
                "correlation_id": "mps-corr-123",
            }
            return _Response(200, body)

    monkeypatch.setattr(
        "api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient
    )
    monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas")
    monkeypatch.setattr(
        "api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret"
    )

    client = MPSServiceKeyClient()
    authorization = await client.authorize_workflow_run_start(
        organization_id=42,
        workflow_run_id=88,
        service_key="mps_sk_paid",
        require_correlation_id=True,
        minimum_credits=0.1,
        metadata={"workflow_id": 7},
        created_by="provider-123",
    )
    assert authorization == {
        "allowed": True,
        "billing_mode": "v2",
        "remaining_credits": "25.0000",
        "correlation_id": "mps-corr-123",
    }

    # created_by is passed to the client but is not expected in the JSON body.
    expected_body = {
        "workflow_run_id": 88,
        "service_key": "mps_sk_paid",
        "require_correlation_id": True,
        "minimum_credits": 0.1,
        "metadata": {"workflow_id": 7},
    }
    expected_headers = {
        "Content-Type": "application/json",
        "X-Secret-Key": "mps-secret",
        "X-Organization-Id": "42",
    }
    assert recorded == [
        (
            "POST",
            f"{client.base_url}/api/v1/billing/accounts/42/run-authorization",
            expected_body,
            expected_headers,
        )
    ]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_ensure_billing_account_v2_uses_balance_endpoint(monkeypatch):
    """ensure_billing_account_v2 issues one GET to the account balance endpoint."""
    recorded = []

    class FakeAsyncClient:
        """Async-context-manager double that records GET calls."""

        def __init__(self, timeout):
            self.timeout = timeout

        async def __aenter__(self):
            return self

        async def __aexit__(self, exc_type, exc, tb):
            return None

        async def get(self, url, headers):
            recorded.append(("GET", url, headers))
            body = {
                "id": 7,
                "organization_id": 42,
                "billing_mode": "v2",
                "cached_balance_credits": "0.0000",
                "currency": "USD",
            }
            return _Response(200, body)

    monkeypatch.setattr(
        "api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient
    )
    monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas")
    monkeypatch.setattr(
        "api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret"
    )

    client = MPSServiceKeyClient()
    account = await client.ensure_billing_account_v2(
        organization_id=42,
        created_by="provider-123",
    )
    assert account == {
        "id": 7,
        "organization_id": 42,
        "billing_mode": "v2",
        "cached_balance_credits": "0.0000",
        "currency": "USD",
    }

    expected_headers = {
        "Content-Type": "application/json",
        "X-Secret-Key": "mps-secret",
        "X-Organization-Id": "42",
    }
    assert recorded == [
        (
            "GET",
            f"{client.base_url}/api/v1/billing/accounts/42/balance",
            expected_headers,
        )
    ]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_credit_ledger_sends_page_and_limit(monkeypatch):
    """get_credit_ledger forwards page and limit as query params to the ledger endpoint."""
    recorded = []

    class FakeAsyncClient:
        """Async-context-manager double that records GET calls with params."""

        def __init__(self, timeout):
            self.timeout = timeout

        async def __aenter__(self):
            return self

        async def __aexit__(self, exc_type, exc, tb):
            return None

        async def get(self, url, params, headers):
            recorded.append(("GET", url, params, headers))
            body = {
                "account": {"organization_id": 42},
                "ledger_entries": [],
                "total_count": 0,
                "page": 3,
                "limit": 25,
                "total_pages": 0,
            }
            return _Response(200, body)

    monkeypatch.setattr(
        "api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient
    )
    monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas")
    monkeypatch.setattr(
        "api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret"
    )

    client = MPSServiceKeyClient()
    ledger = await client.get_credit_ledger(
        organization_id=42,
        page=3,
        limit=25,
    )
    assert ledger == {
        "account": {"organization_id": 42},
        "ledger_entries": [],
        "total_count": 0,
        "page": 3,
        "limit": 25,
        "total_pages": 0,
    }

    expected_headers = {
        "Content-Type": "application/json",
        "X-Secret-Key": "mps-secret",
        "X-Organization-Id": "42",
    }
    assert recorded == [
        (
            "GET",
            f"{client.base_url}/api/v1/billing/accounts/42/ledger",
            {"page": 3, "limit": 25},
            expected_headers,
        )
    ]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_report_platform_usage_uses_hosted_secret_auth(monkeypatch):
    """Platform usage is POSTed with the correlation id plus secret-key and org headers."""
    recorded = []

    class FakeAsyncClient:
        """Async-context-manager double that records POST calls."""

        def __init__(self, timeout):
            self.timeout = timeout

        async def __aenter__(self):
            return self

        async def __aexit__(self, exc_type, exc, tb):
            return None

        async def post(self, url, json, headers):
            recorded.append(("POST", url, json, headers))
            return _Response(200, {"metered": True})

    monkeypatch.setattr(
        "api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient
    )
    monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas")
    monkeypatch.setattr(
        "api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret"
    )

    client = MPSServiceKeyClient()
    outcome = await client.report_platform_usage(
        organization_id=42,
        correlation_id="mps-corr-123",
        workflow_run_id=123,
        metadata={"source": "workflow_run_completion"},
    )
    assert outcome == {"metered": True}

    expected_body = {
        "correlation_id": "mps-corr-123",
        "workflow_run_id": 123,
        "metadata": {"source": "workflow_run_completion"},
    }
    expected_headers = {
        "Content-Type": "application/json",
        "X-Secret-Key": "mps-secret",
        "X-Organization-Id": "42",
    }
    assert recorded == [
        (
            "POST",
            f"{client.base_url}/api/v1/billing/accounts/42/platform-usage",
            expected_body,
            expected_headers,
        )
    ]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_report_platform_usage_sends_duration_without_correlation(monkeypatch):
    """With no correlation id, the usage body carries duration_seconds instead."""
    recorded = []

    class FakeAsyncClient:
        """Async-context-manager double that records POST calls."""

        def __init__(self, timeout):
            self.timeout = timeout

        async def __aenter__(self):
            return self

        async def __aexit__(self, exc_type, exc, tb):
            return None

        async def post(self, url, json, headers):
            recorded.append(("POST", url, json, headers))
            return _Response(200, {"metered": True})

    monkeypatch.setattr(
        "api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient
    )
    monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas")
    monkeypatch.setattr(
        "api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret"
    )

    client = MPSServiceKeyClient()
    outcome = await client.report_platform_usage(
        organization_id=42,
        duration_seconds=87.0,
        workflow_run_id=123,
        metadata={"source": "workflow_run_completion"},
    )
    assert outcome == {"metered": True}

    # The expected body has no "correlation_id" key at all.
    expected_body = {
        "duration_seconds": 87.0,
        "workflow_run_id": 123,
        "metadata": {"source": "workflow_run_completion"},
    }
    expected_headers = {
        "Content-Type": "application/json",
        "X-Secret-Key": "mps-secret",
        "X-Organization-Id": "42",
    }
    assert recorded == [
        (
            "POST",
            f"{client.base_url}/api/v1/billing/accounts/42/platform-usage",
            expected_body,
            expected_headers,
        )
    ]
|
||||
|
|
|
|||
99
api/tests/test_organization_usage_billing.py
Normal file
99
api/tests/test_organization_usage_billing.py
Normal file
|
|
@ -0,0 +1,99 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from api.routes import organization_usage
|
||||
|
||||
|
||||
def test_is_mps_billing_v2_depends_only_on_account_mode():
    """Only an account whose billing_mode is "v2" counts as billing v2."""
    cases = (
        ({"billing_mode": "v2"}, True),
        ({"billing_mode": "v1"}, False),
        ({"billing_mode": "shadow"}, False),
        (None, False),
    )
    for account_status, expected in cases:
        assert organization_usage._is_mps_billing_v2(account_status) is expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_mps_billing_account_status_uses_user_provider_id(monkeypatch):
    """The helper passes the user's provider_id to the MPS client as created_by."""
    get_status = AsyncMock(return_value={"billing_mode": "v2"})
    monkeypatch.setattr(
        organization_usage.mps_service_key_client,
        "get_billing_account_status",
        get_status,
    )
    user = SimpleNamespace(provider_id="provider-123")

    status = await organization_usage._get_mps_billing_account_status(user, 42)

    assert status == {"billing_mode": "v2"}
    get_status.assert_awaited_once_with(
        organization_id=42,
        created_by="provider-123",
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_billing_credits_pages_v2_ledger(monkeypatch):
    """For a v2 account, get_billing_credits fetches the requested ledger page."""
    monkeypatch.setattr(organization_usage, "DEPLOYMENT_MODE", "saas")
    monkeypatch.setattr(
        organization_usage,
        "_get_mps_billing_account_status",
        AsyncMock(return_value={"billing_mode": "v2"}),
    )

    account = {
        "id": 7,
        "organization_id": 42,
        "billing_mode": "v2",
        "cached_balance_credits": 250,
        "currency": "USD",
    }
    grant_entry = {
        "id": 99,
        "entry_type": "grant",
        "origin": "account_creation",
        "credits_delta": 250,
        "balance_after": 250,
        "created_at": "2026-06-12T00:00:00Z",
    }
    get_ledger = AsyncMock(
        return_value={
            "account": account,
            "ledger_entries": [grant_entry],
            "total_debits_credits": 75,
            "total_count": 101,
            "page": 3,
            "limit": 25,
            "total_pages": 5,
        }
    )
    monkeypatch.setattr(
        organization_usage.mps_service_key_client,
        "get_credit_ledger",
        get_ledger,
    )

    user = SimpleNamespace(
        provider_id="provider-123",
        selected_organization_id=42,
    )
    response = await organization_usage.get_billing_credits(
        page=3,
        limit=25,
        user=user,
    )

    get_ledger.assert_awaited_once_with(
        organization_id=42,
        page=3,
        limit=25,
        created_by="provider-123",
    )
    assert response.billing_version == "v2"
    # total_credits_used is expected to equal the ledger's total_debits_credits.
    assert response.total_credits_used == 75
    assert response.total_count == 101
    assert response.page == 3
    assert response.limit == 25
    assert response.total_pages == 5
    assert response.ledger_entries[0].id == 99
|
||||
66
api/tests/test_pre_call_fetch.py
Normal file
66
api/tests/test_pre_call_fetch.py
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
from api.services.pipecat.pre_call_fetch import _extract_initial_context
|
||||
|
||||
|
||||
class TestExtractInitialContext:
    """Behavioural tests for ``_extract_initial_context`` on pre-call fetch responses."""

    def test_initial_context_nested_under_call_inbound(self):
        """`initial_context` is read when nested under `call_inbound`."""
        payload = {"call_inbound": {"initial_context": {"customer_name": "Jane"}}}
        extracted = _extract_initial_context(payload)
        assert extracted == {"customer_name": "Jane"}

    def test_initial_context_at_top_level(self):
        """`initial_context` is read when it sits at the top level."""
        payload = {"initial_context": {"customer_name": "Jane"}}
        extracted = _extract_initial_context(payload)
        assert extracted == {"customer_name": "Jane"}

    def test_legacy_dynamic_variables_nested(self):
        """`dynamic_variables` nested under `call_inbound` is still accepted."""
        payload = {"call_inbound": {"dynamic_variables": {"customer_name": "Jane"}}}
        extracted = _extract_initial_context(payload)
        assert extracted == {"customer_name": "Jane"}

    def test_legacy_dynamic_variables_at_top_level(self):
        """`dynamic_variables` at the top level is still accepted."""
        payload = {"dynamic_variables": {"customer_name": "Jane"}}
        extracted = _extract_initial_context(payload)
        assert extracted == {"customer_name": "Jane"}

    def test_initial_context_takes_precedence_over_legacy(self):
        """With both keys present, the `initial_context` value is returned."""
        inbound = {
            "initial_context": {"source": "new"},
            "dynamic_variables": {"source": "legacy"},
        }
        extracted = _extract_initial_context({"call_inbound": inbound})
        assert extracted == {"source": "new"}

    def test_falls_back_to_legacy_when_initial_context_not_a_dict(self):
        """A ``None`` `initial_context` falls back to `dynamic_variables`."""
        payload = {
            "initial_context": None,
            "dynamic_variables": {"customer_name": "Jane"},
        }
        extracted = _extract_initial_context(payload)
        assert extracted == {"customer_name": "Jane"}

    def test_nested_values_preserved(self):
        """Nested objects inside the context are returned unchanged."""
        payload = {
            "call_inbound": {
                "initial_context": {"customer": {"address": {"city": "LA"}}}
            }
        }
        extracted = _extract_initial_context(payload)
        assert extracted == {"customer": {"address": {"city": "LA"}}}

    def test_empty_when_no_known_keys(self):
        """A `call_inbound` without either key yields an empty dict."""
        extracted = _extract_initial_context({"call_inbound": {"agent_id": 1}})
        assert extracted == {}

    def test_empty_when_call_inbound_missing(self):
        """An empty response yields an empty dict."""
        extracted = _extract_initial_context({})
        assert extracted == {}

    def test_non_dict_vars_yield_empty(self):
        """A string under `initial_context` yields an empty dict."""
        extracted = _extract_initial_context({"initial_context": "nope"})
        assert extracted == {}
|
||||
|
|
@ -57,7 +57,7 @@ def test_trigger_route_executes_as_workflow_owner():
|
|||
with (
|
||||
patch("api.routes.public_agent.db_client") as mock_db,
|
||||
patch(
|
||||
"api.routes.public_agent.check_dograh_quota_by_user_id",
|
||||
"api.routes.public_agent.authorize_workflow_run_start",
|
||||
new=quota_mock,
|
||||
),
|
||||
patch(
|
||||
|
|
@ -92,7 +92,10 @@ def test_trigger_route_executes_as_workflow_owner():
|
|||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
quota_mock.assert_awaited_once_with(workflow.user_id, workflow_id=workflow.id)
|
||||
quota_mock.assert_awaited_once_with(
|
||||
workflow_id=workflow.id,
|
||||
workflow_run_id=501,
|
||||
)
|
||||
mock_db.get_workflow.assert_awaited_once_with(workflow.id, organization_id=11)
|
||||
|
||||
create_kwargs = mock_db.create_workflow_run.await_args.kwargs
|
||||
|
|
@ -124,7 +127,7 @@ def test_workflow_uuid_route_uses_scoped_lookup_and_shared_execution():
|
|||
with (
|
||||
patch("api.routes.public_agent.db_client") as mock_db,
|
||||
patch(
|
||||
"api.routes.public_agent.check_dograh_quota_by_user_id",
|
||||
"api.routes.public_agent.authorize_workflow_run_start",
|
||||
new=quota_mock,
|
||||
),
|
||||
patch(
|
||||
|
|
|
|||
274
api/tests/test_public_embed_cors.py
Normal file
274
api/tests/test_public_embed_cors.py
Normal file
|
|
@ -0,0 +1,274 @@
|
|||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from api.routes.public_embed import PublicEmbedCORSMiddleware, router
|
||||
|
||||
# Test application. The stock CORS middleware allows only
# https://app.dograh.com; PublicEmbedCORSMiddleware is added after it and the
# embed router is mounted under /api/v1.
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["https://app.dograh.com"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(PublicEmbedCORSMiddleware, api_prefix="/api/v1")
app.include_router(router, prefix="/api/v1")
client = TestClient(app, raise_server_exceptions=False)


def _embed_token(token_id, workflow_id, allowed_domains):
    """Build an active, non-expiring embed-token stand-in with a fresh settings dict."""
    return SimpleNamespace(
        id=token_id,
        is_active=True,
        expires_at=None,
        allowed_domains=allowed_domains,
        workflow_id=workflow_id,
        created_by=7,
        usage_limit=None,
        usage_count=0,
        settings={},
    )


# Token with an empty allowed_domains list.
_ACTIVE_TOKEN = _embed_token(10, 1, [])

# Token limited to a single domain.
_RESTRICTED_TOKEN = _embed_token(20, 2, ["allowed.example.com"])

# Token limited to two specific localhost ports.
_LOCALHOST_TOKEN = _embed_token(30, 3, ["localhost:3000", "localhost:3020"])
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _patch_db(monkeypatch):
    """Replace the embed route's DB calls and TURN helpers with in-memory fakes."""
    known_tokens = (
        ("valid", _ACTIVE_TOKEN),
        ("restricted", _RESTRICTED_TOKEN),
        ("localhost", _LOCALHOST_TOKEN),
    )

    async def _get_token(token):
        for value, record in known_tokens:
            if token == value:
                return record
        return None

    async def _get_token_by_id(token_id):
        for _value, record in known_tokens:
            if token_id == record.id:
                return record
        return None

    async def _get_session(session_token):
        # Each lookup returns a fresh, non-expiring session object.
        if session_token == "session-valid":
            owner = _ACTIVE_TOKEN
        elif session_token == "session-restricted":
            owner = _RESTRICTED_TOKEN
        else:
            return None
        return SimpleNamespace(embed_token_id=owner.id, expires_at=None)

    async def _create_workflow_run(**_kwargs):
        return SimpleNamespace(id=123)

    async def _noop(*_args, **_kwargs):
        return None

    def _turn_credentials(_user_id):
        return {
            "username": "turn-user",
            "password": "turn-password",
            "ttl": 3600,
            "uris": ["turn:example.com:3478"],
        }

    replacements = (
        ("api.routes.public_embed.db_client.get_embed_token_by_token", _get_token),
        ("api.routes.public_embed.db_client.get_embed_token_by_id", _get_token_by_id),
        ("api.routes.public_embed.db_client.get_embed_session_by_token", _get_session),
        ("api.routes.public_embed.db_client.create_workflow_run", _create_workflow_run),
        ("api.routes.public_embed.db_client.create_embed_session", _noop),
        ("api.routes.public_embed.db_client.increment_embed_token_usage", _noop),
        ("api.routes.public_embed.TURN_SECRET", "test-secret"),
        ("api.routes.public_embed.generate_turn_credentials", _turn_credentials),
    )
    for target, replacement in replacements:
        monkeypatch.setattr(target, replacement)
|
||||
|
||||
|
||||
def _assert_embed_cors(resp, origin: str):
    """Assert *resp* echoes *origin* in ACAO and lists ``Origin`` in ``Vary``.

    The ``Vary`` header is split on commas and compared case-insensitively,
    ignoring surrounding whitespace; a missing header is treated as empty.
    """
    allowed_origin = resp.headers.get("access-control-allow-origin")
    assert allowed_origin == origin

    vary_header = resp.headers.get("vary", "")
    vary_fields = {field.strip().lower() for field in vary_header.split(",")}
    assert "origin" in vary_fields
|
||||
|
||||
|
||||
def test_options_config_returns_acao_for_allowed_origin():
    """A config preflight for the "valid" token returns 200 with embed CORS headers."""
    origin = "https://mysite.vercel.app"
    preflight_headers = {
        "Origin": origin,
        "Access-Control-Request-Method": "GET",
    }

    resp = client.options(
        "/api/v1/public/embed/config/valid", headers=preflight_headers
    )

    assert resp.status_code == 200
    _assert_embed_cors(resp, origin)
|
||||
|
||||
|
||||
def test_options_config_accepts_allowed_localhost_port():
    """A config preflight from a listed localhost port returns 200 with embed CORS headers."""
    origin = "http://localhost:3020"
    preflight_headers = {
        "Origin": origin,
        "Access-Control-Request-Method": "GET",
    }

    resp = client.options(
        "/api/v1/public/embed/config/localhost", headers=preflight_headers
    )

    assert resp.status_code == 200
    _assert_embed_cors(resp, origin)
|
||||
|
||||
|
||||
def test_options_config_rejects_unknown_token():
    """A config preflight for a token the DB fake does not know returns 403."""
    preflight_headers = {
        "Origin": "https://mysite.vercel.app",
        "Access-Control-Request-Method": "GET",
    }

    resp = client.options(
        "/api/v1/public/embed/config/unknown", headers=preflight_headers
    )

    assert resp.status_code == 403
|
||||
|
||||
|
||||
def test_options_config_rejects_disallowed_origin():
    """A config preflight from an origin outside the token's allowed domains returns 403."""
    preflight_headers = {
        "Origin": "https://notallowed.example.com",
        "Access-Control-Request-Method": "GET",
    }

    resp = client.options(
        "/api/v1/public/embed/config/restricted", headers=preflight_headers
    )

    assert resp.status_code == 403
|
||||
|
||||
|
||||
def test_get_config_includes_acao_header():
    """GET on the ``valid`` config token returns 200 and the embed CORS headers."""
    request_origin = "https://mysite.vercel.app"
    request_headers = {"Origin": request_origin}

    response = client.get(
        "/api/v1/public/embed/config/valid", headers=request_headers
    )

    assert response.status_code == 200
    _assert_embed_cors(response, request_origin)
|
||||
|
||||
|
||||
def test_get_config_accepts_allowed_localhost_port():
    """GET from ``http://localhost:3020`` on the ``localhost`` token returns 200 with CORS headers."""
    request_origin = "http://localhost:3020"
    request_headers = {"Origin": request_origin}

    response = client.get(
        "/api/v1/public/embed/config/localhost", headers=request_headers
    )

    assert response.status_code == 200
    _assert_embed_cors(response, request_origin)
|
||||
|
||||
|
||||
def test_get_config_rejects_unlisted_localhost_port():
    """GET from ``http://localhost:3021`` on the ``localhost`` token is rejected with 403."""
    request_headers = {"Origin": "http://localhost:3021"}

    response = client.get(
        "/api/v1/public/embed/config/localhost", headers=request_headers
    )

    assert response.status_code == 403
|
||||
|
||||
|
||||
def test_get_config_rejects_disallowed_origin():
    """GET on the ``restricted`` config token from a non-listed origin yields 403."""
    request_headers = {"Origin": "https://notallowed.example.com"}

    response = client.get(
        "/api/v1/public/embed/config/restricted", headers=request_headers
    )

    assert response.status_code == 403
|
||||
|
||||
|
||||
def test_init_includes_acao_header():
    """POST to the embed ``init`` endpoint with the ``valid`` token returns 200 and CORS headers."""
    request_origin = "https://mysite.vercel.app"
    payload = {"token": "valid"}

    response = client.post(
        "/api/v1/public/embed/init",
        headers={"Origin": request_origin},
        json=payload,
    )

    assert response.status_code == 200
    _assert_embed_cors(response, request_origin)
|
||||
|
||||
|
||||
def test_turn_credentials_includes_acao_header():
    """GET on turn-credentials for ``session-valid`` returns 200 and the embed CORS headers."""
    request_origin = "https://mysite.vercel.app"
    request_headers = {"Origin": request_origin}

    response = client.get(
        "/api/v1/public/embed/turn-credentials/session-valid",
        headers=request_headers,
    )

    assert response.status_code == 200
    _assert_embed_cors(response, request_origin)
|
||||
|
||||
|
||||
def test_options_init_returns_acao_for_allowed_origin():
    """OPTIONS preflight (requesting POST) on the ``init`` endpoint returns 200 with CORS headers."""
    request_origin = "https://mysite.vercel.app"
    preflight_headers = {
        "Origin": request_origin,
        "Access-Control-Request-Method": "POST",
    }

    response = client.options(
        "/api/v1/public/embed/init",
        headers=preflight_headers,
    )

    assert response.status_code == 200
    _assert_embed_cors(response, request_origin)
|
||||
|
||||
|
||||
def test_options_turn_credentials_returns_acao_for_allowed_origin():
    """OPTIONS preflight on turn-credentials for ``session-valid`` returns 200 with CORS headers."""
    request_origin = "https://mysite.vercel.app"
    preflight_headers = {
        "Origin": request_origin,
        "Access-Control-Request-Method": "GET",
    }

    response = client.options(
        "/api/v1/public/embed/turn-credentials/session-valid",
        headers=preflight_headers,
    )

    assert response.status_code == 200
    _assert_embed_cors(response, request_origin)
|
||||
|
||||
|
||||
def test_options_turn_credentials_rejects_disallowed_origin():
    """OPTIONS preflight on turn-credentials for ``session-restricted`` from a non-listed origin yields 403."""
    preflight_headers = {
        "Origin": "https://notallowed.example.com",
        "Access-Control-Request-Method": "GET",
    }

    response = client.options(
        "/api/v1/public/embed/turn-credentials/session-restricted",
        headers=preflight_headers,
    )

    assert response.status_code == 403
|
||||
369
api/tests/test_quota_service.py
Normal file
369
api/tests/test_quota_service.py
Normal file
|
|
@ -0,0 +1,369 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from api.services import quota_service
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from api.services.managed_model_services import MPS_CORRELATION_ID_CONTEXT_KEY
|
||||
|
||||
|
||||
def _dograh_config(
    api_key: str = "mps_sk_12345678",
    *,
    managed_service_version: int = 1,
):
    """Build a config stand-in whose LLM uses the Dograh provider.

    Args:
        api_key: Key placed on the ``llm`` entry.
        managed_service_version: Value stored on the returned namespace.

    Returns:
        A ``SimpleNamespace`` with ``llm`` populated and ``stt``, ``tts`` and
        ``embeddings`` set to ``None``.
    """
    dograh_llm = SimpleNamespace(provider=ServiceProviders.DOGRAH, api_key=api_key)
    return SimpleNamespace(
        managed_service_version=managed_service_version,
        llm=dograh_llm,
        stt=None,
        tts=None,
        embeddings=None,
    )
|
||||
|
||||
|
||||
def _byok_config():
|
||||
return SimpleNamespace(
|
||||
managed_service_version=2,
|
||||
llm=SimpleNamespace(provider="openai", api_key="sk-openai"),
|
||||
stt=None,
|
||||
tts=None,
|
||||
embeddings=None,
|
||||
)
|
||||
|
||||
|
||||
def _workflow():
|
||||
return SimpleNamespace(
|
||||
id=7,
|
||||
user_id=123,
|
||||
organization_id=42,
|
||||
workflow_configurations={"model_overrides": {}},
|
||||
)
|
||||
|
||||
|
||||
def _workflow_owner():
|
||||
return SimpleNamespace(
|
||||
id=123,
|
||||
provider_id="provider-123",
|
||||
)
|
||||
|
||||
|
||||
def _actor():
|
||||
return SimpleNamespace(
|
||||
id=456,
|
||||
provider_id="actor-456",
|
||||
selected_organization_id=42,
|
||||
)
|
||||
|
||||
|
||||
def _patch_workflow_context(monkeypatch, *, workflow=None, owner=None):
    """Replace the DB lookups for the workflow and its owner with async mocks.

    Args:
        monkeypatch: pytest's ``monkeypatch`` fixture.
        workflow: Object returned by ``get_workflow_by_id``; a falsy value
            falls back to ``_workflow()``.
        owner: Object returned by ``get_user_by_id``; a falsy value falls back
            to ``_workflow_owner()``.
    """
    workflow_lookup = AsyncMock(return_value=workflow or _workflow())
    monkeypatch.setattr(
        quota_service.db_client, "get_workflow_by_id", workflow_lookup
    )

    owner_lookup = AsyncMock(return_value=owner or _workflow_owner())
    monkeypatch.setattr(quota_service.db_client, "get_user_by_id", owner_lookup)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_authorize_workflow_run_uses_workflow_org_for_hosted_v2(
    monkeypatch,
):
    """In ``saas`` mode with a v2 reply, authorization is keyed on the workflow's org.

    Expects the config lookup and the hosted authorize call to receive the
    workflow's ``organization_id`` (42) and owner details, and the legacy
    key-usage check never to be awaited.
    """
    config_lookup = AsyncMock(return_value=_dograh_config())
    hosted_authorize = AsyncMock(
        return_value={
            "allowed": True,
            "billing_mode": "v2",
            "remaining_credits": "25.0000",
        }
    )
    usage_check = AsyncMock()

    monkeypatch.setattr(quota_service, "DEPLOYMENT_MODE", "saas")
    _patch_workflow_context(monkeypatch)
    replacements = (
        (
            quota_service,
            "get_effective_ai_model_configuration_for_workflow",
            config_lookup,
        ),
        (
            quota_service.mps_service_key_client,
            "authorize_workflow_run_start",
            hosted_authorize,
        ),
        (
            quota_service.mps_service_key_client,
            "check_service_key_usage",
            usage_check,
        ),
    )
    for target, attribute, replacement in replacements:
        monkeypatch.setattr(target, attribute, replacement)

    outcome = await quota_service.authorize_workflow_run_start(workflow_id=7)

    assert outcome.has_quota is True
    config_lookup.assert_awaited_once_with(
        user_id=123,
        organization_id=42,
        workflow_configurations={"model_overrides": {}},
    )
    hosted_authorize.assert_awaited_once_with(
        organization_id=42,
        workflow_run_id=None,
        service_key=None,
        require_correlation_id=False,
        minimum_credits=quota_service.MINIMUM_DOGRAH_CREDITS_FOR_CALL,
        created_by="provider-123",
        metadata={"dograh_user_id": "123", "workflow_id": 7},
    )
    usage_check.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_authorize_workflow_run_v2_insufficient_credits_prompts_billing(
    monkeypatch,
):
    """A v2 ``insufficient_credits`` denial points the user at ``/billing``.

    Expects ``has_quota`` to be false, the error code to be
    ``insufficient_credits``, the message to mention ``/billing`` but not the
    founders e-mail, and the legacy key-usage check never to be awaited.
    """
    config_lookup = AsyncMock(return_value=_byok_config())
    hosted_authorize = AsyncMock(
        return_value={
            "allowed": False,
            "billing_mode": "v2",
            "remaining_credits": "0.0000",
            "error": "insufficient_credits",
        }
    )
    usage_check = AsyncMock()

    monkeypatch.setattr(quota_service, "DEPLOYMENT_MODE", "saas")
    _patch_workflow_context(monkeypatch)
    replacements = (
        (
            quota_service,
            "get_effective_ai_model_configuration_for_workflow",
            config_lookup,
        ),
        (
            quota_service.mps_service_key_client,
            "authorize_workflow_run_start",
            hosted_authorize,
        ),
        (
            quota_service.mps_service_key_client,
            "check_service_key_usage",
            usage_check,
        ),
    )
    for target, attribute, replacement in replacements:
        monkeypatch.setattr(target, attribute, replacement)

    outcome = await quota_service.authorize_workflow_run_start(workflow_id=7)

    assert outcome.has_quota is False
    assert outcome.error_code == "insufficient_credits"
    message = outcome.error_message
    assert "/billing" in message
    assert "founders@dograh.com" not in message
    hosted_authorize.assert_awaited_once()
    usage_check.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_authorize_workflow_run_v1_uses_legacy_key_usage(
    monkeypatch,
):
    """A v1 billing-mode reply falls through to the legacy service-key usage check.

    With the usage check reporting zero remaining credits, expects
    ``quota_exceeded``, a message naming the founders e-mail (and not
    ``/billing``), and the usage check to be awaited with the API key, the
    workflow's organization and the owner's provider id.
    """
    service_key = "mps_sk_12345678"
    config_lookup = AsyncMock(return_value=_dograh_config(service_key))
    hosted_authorize = AsyncMock(
        return_value={
            "allowed": True,
            "billing_mode": "v1",
            "remaining_credits": "0.0000",
        }
    )
    usage_check = AsyncMock(
        return_value={"total_credits_used": 500.0, "remaining_credits": 0.0}
    )

    monkeypatch.setattr(quota_service, "DEPLOYMENT_MODE", "saas")
    _patch_workflow_context(monkeypatch)
    replacements = (
        (
            quota_service,
            "get_effective_ai_model_configuration_for_workflow",
            config_lookup,
        ),
        (
            quota_service.mps_service_key_client,
            "authorize_workflow_run_start",
            hosted_authorize,
        ),
        (
            quota_service.mps_service_key_client,
            "check_service_key_usage",
            usage_check,
        ),
    )
    for target, attribute, replacement in replacements:
        monkeypatch.setattr(target, attribute, replacement)

    outcome = await quota_service.authorize_workflow_run_start(workflow_id=7)

    assert outcome.has_quota is False
    assert outcome.error_code == "quota_exceeded"
    message = outcome.error_message
    assert "founders@dograh.com" in message
    assert "/billing" not in message
    hosted_authorize.assert_awaited_once()
    usage_check.assert_awaited_once_with(
        service_key,
        organization_id=42,
        created_by="provider-123",
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_authorize_workflow_run_managed_v2_stores_hosted_correlation(
    monkeypatch,
):
    """A managed v2 config stores the hosted correlation id on the run's context.

    Expects the hosted authorize call to receive the service key with
    ``require_correlation_id=True``, and the run's ``initial_context`` to be
    updated with the returned correlation id while keeping its existing keys.
    """
    service_key = "mps_sk_12345678"
    existing_run = SimpleNamespace(initial_context={"existing": "value"})
    config_lookup = AsyncMock(
        return_value=_dograh_config(service_key, managed_service_version=2)
    )
    hosted_authorize = AsyncMock(
        return_value={
            "allowed": True,
            "billing_mode": "v2",
            "remaining_credits": "25.0000",
            "correlation_id": "mps-corr-123",
        }
    )
    run_update = AsyncMock()

    monkeypatch.setattr(quota_service, "DEPLOYMENT_MODE", "saas")
    _patch_workflow_context(monkeypatch)
    replacements = (
        (
            quota_service.db_client,
            "get_workflow_run_by_id",
            AsyncMock(return_value=existing_run),
        ),
        (quota_service.db_client, "update_workflow_run", run_update),
        (
            quota_service,
            "get_effective_ai_model_configuration_for_workflow",
            config_lookup,
        ),
        (
            quota_service.mps_service_key_client,
            "authorize_workflow_run_start",
            hosted_authorize,
        ),
        (
            quota_service.mps_service_key_client,
            "check_service_key_usage",
            AsyncMock(),
        ),
    )
    for target, attribute, replacement in replacements:
        monkeypatch.setattr(target, attribute, replacement)

    outcome = await quota_service.authorize_workflow_run_start(
        workflow_id=7,
        workflow_run_id=88,
    )

    assert outcome.has_quota is True
    hosted_authorize.assert_awaited_once_with(
        organization_id=42,
        workflow_run_id=88,
        service_key=service_key,
        require_correlation_id=True,
        minimum_credits=quota_service.MINIMUM_DOGRAH_CREDITS_FOR_CALL,
        created_by="provider-123",
        metadata={"dograh_user_id": "123", "workflow_id": 7},
    )
    expected_context = {
        "existing": "value",
        MPS_CORRELATION_ID_CONTEXT_KEY: "mps-corr-123",
    }
    run_update.assert_awaited_once_with(88, initial_context=expected_context)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_authorize_workflow_run_oss_uses_key_paths_not_workflow_org(
    monkeypatch,
):
    """In ``oss`` mode authorization goes through the service-key endpoints.

    Expects the hosted authorize call never to be awaited, the key-usage
    check to be awaited with ``organization_id=None``, a correlation id to be
    created for the run, and that id to be written to the run's
    ``initial_context``.
    """
    service_key = "mps_sk_12345678"
    existing_run = SimpleNamespace(initial_context={})
    config_lookup = AsyncMock(
        return_value=_dograh_config(service_key, managed_service_version=2)
    )
    hosted_authorize = AsyncMock()
    usage_check = AsyncMock(
        return_value={"total_credits_used": 1.0, "remaining_credits": 499.0}
    )
    correlation_create = AsyncMock(return_value={"correlation_id": "oss-corr-123"})
    run_update = AsyncMock()

    monkeypatch.setattr(quota_service, "DEPLOYMENT_MODE", "oss")
    _patch_workflow_context(monkeypatch)
    replacements = (
        (
            quota_service.db_client,
            "get_workflow_run_by_id",
            AsyncMock(return_value=existing_run),
        ),
        (quota_service.db_client, "update_workflow_run", run_update),
        (
            quota_service,
            "get_effective_ai_model_configuration_for_workflow",
            config_lookup,
        ),
        (
            quota_service.mps_service_key_client,
            "authorize_workflow_run_start",
            hosted_authorize,
        ),
        (
            quota_service.mps_service_key_client,
            "check_service_key_usage",
            usage_check,
        ),
        (
            quota_service.mps_service_key_client,
            "create_correlation_id",
            correlation_create,
        ),
    )
    for target, attribute, replacement in replacements:
        monkeypatch.setattr(target, attribute, replacement)

    outcome = await quota_service.authorize_workflow_run_start(
        workflow_id=7,
        workflow_run_id=88,
    )

    assert outcome.has_quota is True
    hosted_authorize.assert_not_awaited()
    usage_check.assert_awaited_once_with(
        service_key,
        organization_id=None,
        created_by="provider-123",
    )
    correlation_create.assert_awaited_once_with(
        service_key=service_key,
        workflow_run_id=88,
    )
    run_update.assert_awaited_once_with(
        88,
        initial_context={MPS_CORRELATION_ID_CONTEXT_KEY: "oss-corr-123"},
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_authorize_workflow_run_rejects_actor_from_another_org(monkeypatch):
    """An actor whose selected organization is 999 gets ``workflow_not_found``.

    The patched workflow belongs to organization 42, so the mismatch is
    expected to produce ``has_quota=False``.
    """
    monkeypatch.setattr(quota_service, "DEPLOYMENT_MODE", "saas")
    _patch_workflow_context(monkeypatch)
    foreign_actor = SimpleNamespace(selected_organization_id=999)

    outcome = await quota_service.authorize_workflow_run_start(
        workflow_id=7,
        actor_user=foreign_actor,
    )

    assert outcome.has_quota is False
    assert outcome.error_code == "workflow_not_found"
|
||||
|
|
@ -2,14 +2,14 @@
|
|||
TDD tests for resolve_effective_config().
|
||||
|
||||
This function deep-merges workflow-level model_overrides onto the global
|
||||
UserConfiguration. Fields not overridden inherit from global.
|
||||
EffectiveAIModelConfiguration. Fields not overridden inherit from global.
|
||||
|
||||
Module under test: api.services.configuration.resolve
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from api.schemas.user_configuration import UserConfiguration
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.configuration.masking import (
|
||||
contains_masked_key,
|
||||
mask_workflow_configurations,
|
||||
|
|
@ -35,9 +35,9 @@ from api.services.configuration.resolve import (
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def global_config() -> UserConfiguration:
|
||||
def global_config() -> EffectiveAIModelConfiguration:
|
||||
"""A realistic global user configuration."""
|
||||
return UserConfiguration(
|
||||
return EffectiveAIModelConfiguration(
|
||||
llm=OpenAILLMService(
|
||||
provider="openai", api_key="sk-global-llm", model="gpt-4.1"
|
||||
),
|
||||
|
|
@ -59,9 +59,9 @@ def global_config() -> UserConfiguration:
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def global_config_realtime() -> UserConfiguration:
|
||||
def global_config_realtime() -> EffectiveAIModelConfiguration:
|
||||
"""Global config with realtime enabled."""
|
||||
return UserConfiguration(
|
||||
return EffectiveAIModelConfiguration(
|
||||
llm=OpenAILLMService(
|
||||
provider="openai", api_key="sk-global-llm", model="gpt-4.1"
|
||||
),
|
||||
|
|
@ -302,7 +302,7 @@ class TestRealtimeOverride:
|
|||
class TestOverrideOnNullGlobal:
|
||||
def test_override_stt_when_global_is_none(self):
|
||||
"""When global has no STT config, override creates one from scratch."""
|
||||
config = UserConfiguration(
|
||||
config = EffectiveAIModelConfiguration(
|
||||
llm=OpenAILLMService(provider="openai", api_key="sk-key", model="gpt-4.1"),
|
||||
stt=None,
|
||||
tts=None,
|
||||
|
|
@ -325,7 +325,7 @@ class TestOverrideOnNullGlobal:
|
|||
|
||||
def test_override_realtime_when_global_is_none(self):
|
||||
"""Realtime section can be created from override even if global has none."""
|
||||
config = UserConfiguration(
|
||||
config = EffectiveAIModelConfiguration(
|
||||
llm=OpenAILLMService(provider="openai", api_key="sk-key", model="gpt-4.1"),
|
||||
is_realtime=False,
|
||||
realtime=None,
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from api.services.pricing.run_usage_response import format_public_usage_info
|
||||
from api.services.workflow.run_usage_response import format_public_usage_info
|
||||
|
||||
|
||||
def test_format_public_usage_info():
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
from unittest.mock import ANY, AsyncMock, Mock, patch
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
|
@ -54,7 +54,7 @@ def test_initiate_call_executes_as_workflow_owner_for_shared_org_workflow():
|
|||
with (
|
||||
patch("api.routes.telephony.db_client") as mock_db,
|
||||
patch(
|
||||
"api.routes.telephony.check_dograh_quota_by_user_id",
|
||||
"api.routes.telephony.authorize_workflow_run_start",
|
||||
new=quota_mock,
|
||||
),
|
||||
patch(
|
||||
|
|
@ -88,7 +88,11 @@ def test_initiate_call_executes_as_workflow_owner_for_shared_org_workflow():
|
|||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
quota_mock.assert_awaited_once_with(workflow.user_id, workflow_id=workflow.id)
|
||||
quota_mock.assert_awaited_once_with(
|
||||
workflow_id=workflow.id,
|
||||
workflow_run_id=501,
|
||||
actor_user=ANY,
|
||||
)
|
||||
mock_db.get_workflow.assert_awaited_once_with(workflow.id, organization_id=11)
|
||||
|
||||
create_call = mock_db.create_workflow_run.await_args
|
||||
|
|
@ -103,6 +107,61 @@ def test_initiate_call_executes_as_workflow_owner_for_shared_org_workflow():
|
|||
assert initiate_kwargs["workflow_id"] == workflow.id
|
||||
assert initiate_kwargs["user_id"] == workflow.user_id
|
||||
assert "user_id=99" in initiate_kwargs["webhook_url"]
|
||||
mock_db.get_user_configurations.assert_not_called()
|
||||
|
||||
|
||||
def test_initiate_call_uses_organization_preference_phone_number():
    """The call is placed to the number returned by ``get_configuration``.

    ``get_user_configurations`` is stubbed with a different number; the test
    expects it never to be called and the provider to be dialled with
    ``+15557654321``.
    """
    http = TestClient(_make_test_app())

    workflow = _workflow()
    provider = _provider()
    quota_stub = AsyncMock(
        return_value=SimpleNamespace(has_quota=True, error_message="")
    )
    created_run = SimpleNamespace(
        id=501,
        name="WR-TEL-OUT-00000001",
        initial_context={},
    )
    org_preference = SimpleNamespace(value={"test_phone_number": "+15557654321"})

    with (
        patch("api.routes.telephony.db_client") as db_stub,
        patch(
            "api.routes.telephony.authorize_workflow_run_start",
            new=quota_stub,
        ),
        patch(
            "api.routes.telephony.get_default_telephony_provider",
            new=AsyncMock(return_value=provider),
        ),
        patch(
            "api.routes.telephony.get_backend_endpoints",
            new=AsyncMock(return_value=("https://api.example.com", "wss://ignored")),
        ),
    ):
        db_stub.get_user_configurations = AsyncMock(
            return_value=SimpleNamespace(test_phone_number="+15550000000")
        )
        db_stub.get_configuration = Mock(return_value=org_preference)
        db_stub.get_default_telephony_configuration = AsyncMock(
            return_value=SimpleNamespace(id=55)
        )
        db_stub.get_workflow = AsyncMock(return_value=workflow)
        db_stub.create_workflow_run = AsyncMock(return_value=created_run)
        db_stub.update_workflow_run = AsyncMock()

        response = http.post(
            "/telephony/initiate-call",
            json={"workflow_id": workflow.id},
        )

    assert response.status_code == 200
    dialled_kwargs = provider.initiate_call.await_args.kwargs
    assert dialled_kwargs["to_number"] == "+15557654321"
    db_stub.get_user_configurations.assert_not_called()
|
||||
|
||||
|
||||
def test_initiate_call_rejects_existing_run_for_different_workflow():
|
||||
|
|
@ -118,7 +177,7 @@ def test_initiate_call_rejects_existing_run_for_different_workflow():
|
|||
with (
|
||||
patch("api.routes.telephony.db_client") as mock_db,
|
||||
patch(
|
||||
"api.routes.telephony.check_dograh_quota_by_user_id",
|
||||
"api.routes.telephony.authorize_workflow_run_start",
|
||||
new=quota_mock,
|
||||
),
|
||||
patch(
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from pipecat.processors.frame_processor import FrameDirection
|
|||
from websockets.exceptions import ConnectionClosedError
|
||||
from websockets.frames import Close
|
||||
|
||||
from api.schemas.user_configuration import UserConfiguration
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.configuration.registry import UltravoxRealtimeLLMConfiguration
|
||||
from api.services.pipecat.realtime.ultravox_realtime import (
|
||||
_RESUMPTION_USER_MESSAGE,
|
||||
|
|
@ -430,7 +430,7 @@ async def test_receive_messages_reports_unexpected_websocket_close():
|
|||
|
||||
|
||||
def test_factory_creates_dograh_ultravox_realtime_service():
|
||||
user_config = UserConfiguration(
|
||||
effective_config = EffectiveAIModelConfiguration(
|
||||
is_realtime=True,
|
||||
realtime=UltravoxRealtimeLLMConfiguration(
|
||||
provider="ultravox_realtime",
|
||||
|
|
@ -441,7 +441,7 @@ def test_factory_creates_dograh_ultravox_realtime_service():
|
|||
)
|
||||
|
||||
service = create_realtime_llm_service(
|
||||
user_config,
|
||||
effective_config,
|
||||
audio_config=SimpleNamespace(),
|
||||
)
|
||||
|
||||
|
|
|
|||
212
api/tests/test_workflow_run_billing.py
Normal file
212
api/tests/test_workflow_run_billing.py
Normal file
|
|
@ -0,0 +1,212 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from api.services import workflow_run_billing as workflow_run_billing_mod
|
||||
from api.services.workflow_run_billing import (
|
||||
report_completed_workflow_run_platform_usage,
|
||||
report_workflow_run_platform_usage,
|
||||
)
|
||||
|
||||
|
||||
def _make_workflow_run():
|
||||
return SimpleNamespace(
|
||||
id=123,
|
||||
workflow_id=456,
|
||||
is_completed=True,
|
||||
initial_context={"mps_correlation_id": "mps-corr-123"},
|
||||
usage_info={"call_duration_seconds": 87},
|
||||
workflow=SimpleNamespace(
|
||||
organization_id=42,
|
||||
user=SimpleNamespace(selected_organization_id=42),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_report_workflow_run_platform_usage_reports_hosted_completion(
    monkeypatch,
):
    """A completed run with a correlation id is reported by correlation.

    Expects ``report_platform_usage`` to be awaited once with the run's
    correlation id, ``duration_seconds=None`` and
    ``duration_source="mps_correlation"``.
    """
    run = _make_workflow_run()
    status_lookup = AsyncMock(return_value={"billing_mode": "v2"})
    usage_report = AsyncMock(return_value={"metered": True})

    monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
    mps_client = workflow_run_billing_mod.mps_service_key_client
    monkeypatch.setattr(mps_client, "get_billing_account_status", status_lookup)
    monkeypatch.setattr(mps_client, "report_platform_usage", usage_report)

    await report_workflow_run_platform_usage(run)

    expected_metadata = {
        "source": "workflow_run_completion",
        "workflow_id": run.workflow_id,
        "duration_source": "mps_correlation",
    }
    usage_report.assert_awaited_once_with(
        organization_id=42,
        correlation_id="mps-corr-123",
        duration_seconds=None,
        workflow_run_id=run.id,
        metadata=expected_metadata,
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_report_workflow_run_platform_usage_reports_duration_without_correlation(
    monkeypatch,
):
    """Without a correlation id the run's own call duration is reported.

    Expects ``report_platform_usage`` to be awaited once with
    ``correlation_id=None``, ``duration_seconds=87.0`` and
    ``duration_source="dograh_usage_info"``.
    """
    run = _make_workflow_run()
    run.initial_context = {}
    status_lookup = AsyncMock(return_value={"billing_mode": "v2"})
    usage_report = AsyncMock(return_value={"metered": True})

    monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
    mps_client = workflow_run_billing_mod.mps_service_key_client
    monkeypatch.setattr(mps_client, "get_billing_account_status", status_lookup)
    monkeypatch.setattr(mps_client, "report_platform_usage", usage_report)

    await report_workflow_run_platform_usage(run)

    expected_metadata = {
        "source": "workflow_run_completion",
        "workflow_id": run.workflow_id,
        "duration_source": "dograh_usage_info",
    }
    usage_report.assert_awaited_once_with(
        organization_id=42,
        correlation_id=None,
        duration_seconds=87.0,
        workflow_run_id=run.id,
        metadata=expected_metadata,
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_report_workflow_run_platform_usage_skips_non_v2_account(monkeypatch):
    """A ``v1`` billing account is looked up but no usage is reported.

    Expects the status lookup to be awaited once for organization 42 and
    ``report_platform_usage`` never to be awaited.
    """
    run = _make_workflow_run()
    status_lookup = AsyncMock(return_value={"billing_mode": "v1"})
    usage_report = AsyncMock()

    monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
    mps_client = workflow_run_billing_mod.mps_service_key_client
    monkeypatch.setattr(mps_client, "get_billing_account_status", status_lookup)
    monkeypatch.setattr(mps_client, "report_platform_usage", usage_report)

    await report_workflow_run_platform_usage(run)

    status_lookup.assert_awaited_once_with(organization_id=42)
    usage_report.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_report_workflow_run_platform_usage_skips_missing_duration_without_correlation(
    monkeypatch,
):
    """With neither a correlation id nor a duration, nothing is looked up or reported.

    Expects both the billing status lookup and ``report_platform_usage`` never
    to be awaited.
    """
    run = _make_workflow_run()
    run.initial_context = {}
    run.usage_info = {}
    status_lookup = AsyncMock(return_value={"billing_mode": "v2"})
    usage_report = AsyncMock()

    monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
    mps_client = workflow_run_billing_mod.mps_service_key_client
    monkeypatch.setattr(mps_client, "get_billing_account_status", status_lookup)
    monkeypatch.setattr(mps_client, "report_platform_usage", usage_report)

    await report_workflow_run_platform_usage(run)

    status_lookup.assert_not_awaited()
    usage_report.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_report_workflow_run_platform_usage_skips_oss(monkeypatch):
    """With ``DEPLOYMENT_MODE`` set to ``oss`` no platform usage is reported."""
    run = _make_workflow_run()
    usage_report = AsyncMock()

    monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "oss")
    mps_client = workflow_run_billing_mod.mps_service_key_client
    monkeypatch.setattr(mps_client, "report_platform_usage", usage_report)

    await report_workflow_run_platform_usage(run)

    usage_report.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_report_workflow_run_platform_usage_skips_incomplete(monkeypatch):
    """A run whose ``is_completed`` flag is false is not reported."""
    run = _make_workflow_run()
    run.is_completed = False
    usage_report = AsyncMock()

    monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
    mps_client = workflow_run_billing_mod.mps_service_key_client
    monkeypatch.setattr(mps_client, "report_platform_usage", usage_report)

    await report_workflow_run_platform_usage(run)

    usage_report.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_report_completed_workflow_run_platform_usage_loads_run(monkeypatch):
    """The by-id entry point loads the run from the DB and then reports usage.

    Expects ``get_workflow_run_by_id`` to be awaited once with the run id and
    ``report_platform_usage`` to be awaited once.
    """
    run = _make_workflow_run()
    run_lookup = AsyncMock(return_value=run)
    status_lookup = AsyncMock(return_value={"billing_mode": "v2"})
    usage_report = AsyncMock(return_value={"metered": True})

    monkeypatch.setattr(workflow_run_billing_mod, "DEPLOYMENT_MODE", "saas")
    replacements = (
        (workflow_run_billing_mod.db_client, "get_workflow_run_by_id", run_lookup),
        (
            workflow_run_billing_mod.mps_service_key_client,
            "get_billing_account_status",
            status_lookup,
        ),
        (
            workflow_run_billing_mod.mps_service_key_client,
            "report_platform_usage",
            usage_report,
        ),
    )
    for target, attribute, replacement in replacements:
        monkeypatch.setattr(target, attribute, replacement)

    await report_completed_workflow_run_platform_usage(run.id)

    run_lookup.assert_awaited_once_with(run.id)
    usage_report.assert_awaited_once()
|
||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue