Merge remote-tracking branch 'origin/main' into feat/user-onboarding

This commit is contained in:
Abhishek Kumar 2026-06-12 18:54:48 +05:30
commit 093e888ce4
148 changed files with 10908 additions and 2815 deletions

View file

@ -1,7 +1,7 @@
from typing import Annotated, Optional
import httpx
from fastapi import Header, HTTPException, Query, WebSocket
from fastapi import Depends, Header, HTTPException, Query, WebSocket
from loguru import logger
from pydantic import ValidationError
@ -9,9 +9,10 @@ from api.constants import AUTH_PROVIDER, DOGRAH_MPS_SECRET_KEY, MPS_API_URL
from api.db import db_client
from api.db.models import UserModel
from api.enums import PostHogEvent
from api.schemas.user_configuration import UserConfiguration
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
from api.services.auth.stack_auth import stackauth
from api.services.configuration.registry import ServiceProviders
from api.services.mps_billing import ensure_hosted_mps_billing_account_v2
from api.services.posthog_client import capture_event
from api.utils.auth import decode_jwt_token
@ -110,6 +111,19 @@ async def get_user(
# This prevents race conditions where multiple concurrent requests
# might try to create configurations
if org_was_created:
try:
await ensure_hosted_mps_billing_account_v2(
organization.id,
created_by=str(stack_user["id"]),
)
except Exception:
logger.warning(
"Failed to initialize hosted MPS billing account for "
"organization {}",
organization.id,
exc_info=True,
)
existing_cfg = await db_client.get_user_configurations(user_model.id)
if not (existing_cfg.llm or existing_cfg.tts or existing_cfg.stt):
mps_config = await create_user_configuration_with_mps_key(
@ -119,6 +133,19 @@ async def get_user(
await db_client.update_user_configuration(
user_model.id, mps_config
)
from api.enums import OrganizationConfigurationKey
from api.services.configuration.ai_model_configuration import (
convert_legacy_ai_model_configuration_to_v2,
)
model_config_v2 = convert_legacy_ai_model_configuration_to_v2(
mps_config
)
await db_client.upsert_configuration(
organization.id,
OrganizationConfigurationKey.MODEL_CONFIGURATION_V2.value,
model_config_v2.model_dump(mode="json", exclude_none=True),
)
except Exception as exc:
raise HTTPException(
@ -129,6 +156,14 @@ async def get_user(
return user_model
async def get_user_with_selected_organization(
user: Annotated[UserModel, Depends(get_user)],
) -> UserModel:
if not user.selected_organization_id:
raise HTTPException(status_code=400, detail="No organization selected")
return user
async def _handle_oss_auth(authorization: str | None) -> UserModel:
"""
Handle authentication for OSS deployment mode.
@ -192,7 +227,7 @@ async def _handle_api_key_auth(api_key: str) -> UserModel:
async def create_user_configuration_with_mps_key(
user_id: int, organization_id: int, user_provider_id: str
) -> Optional[UserConfiguration]:
) -> Optional[EffectiveAIModelConfiguration]:
"""Create user configuration using MPS service key.
Args:
@ -201,7 +236,7 @@ async def create_user_configuration_with_mps_key(
user_provider_id: The user's provider ID (for created_by field)
Returns:
UserConfiguration with MPS-provided API keys or None if failed
EffectiveAIModelConfiguration with MPS-provided API keys or None if failed
"""
async with httpx.AsyncClient() as client:
@ -211,7 +246,7 @@ async def create_user_configuration_with_mps_key(
response = await client.post(
f"{MPS_API_URL}/api/v1/service-keys/",
json={
"name": f"Default Dograh Model Service Key",
"name": "Default Dograh Model Service Key",
"description": "Auto-generated key for OSS user",
"expires_in_days": 7, # Short-lived for OSS
"created_by": user_provider_id,
@ -229,7 +264,7 @@ async def create_user_configuration_with_mps_key(
response = await client.post(
f"{MPS_API_URL}/api/v1/service-keys/",
json={
"name": f"Default Dograh Model Service Key",
"name": "Default Dograh Model Service Key",
"description": f"Auto-generated key for organization {organization_id}",
"organization_id": organization_id,
"expires_in_days": 90, # Longer-lived for authenticated users
@ -264,8 +299,8 @@ async def create_user_configuration_with_mps_key(
"model": "default",
},
}
user_config = UserConfiguration(**configuration)
return user_config
effective_config = EffectiveAIModelConfiguration(**configuration)
return effective_config
else:
logger.warning(
f"Failed to get MPS service key: {response.status_code} - {response.text}"

View file

@ -15,6 +15,7 @@ from api.services.campaign.errors import (
PhoneNumberPoolExhaustedError,
)
from api.services.campaign.rate_limiter import rate_limiter
from api.services.quota_service import authorize_workflow_run_start
from api.utils.common import get_backend_endpoints
if TYPE_CHECKING:
@ -339,6 +340,41 @@ class CampaignCallDispatcher:
},
)
quota_result = await authorize_workflow_run_start(
workflow_id=campaign.workflow_id,
workflow_run_id=workflow_run.id,
)
if not quota_result.has_quota:
error_message = quota_result.error_message or "Quota exceeded"
logger.warning(
f"Campaign {campaign.id} quota check failed for workflow run "
f"{workflow_run.id}: {error_message}"
)
await db_client.update_workflow_run(
run_id=workflow_run.id,
is_completed=True,
state=WorkflowRunState.COMPLETED.value,
gathered_context={"error": error_message},
)
mapping = await rate_limiter.get_workflow_slot_mapping(workflow_run.id)
if mapping:
org_id, mapped_slot_id = mapping
await rate_limiter.release_concurrent_slot(org_id, mapped_slot_id)
await rate_limiter.delete_workflow_slot_mapping(workflow_run.id)
from_number_mapping = await rate_limiter.get_workflow_from_number_mapping(
workflow_run.id
)
if from_number_mapping:
fn_org_id, fn_number, fn_tcid = from_number_mapping
await rate_limiter.release_from_number(
fn_org_id, fn_number, telephony_configuration_id=fn_tcid
)
await rate_limiter.delete_workflow_from_number_mapping(workflow_run.id)
raise ValueError(error_message)
# Initiate call via telephony provider
try:
# Construct webhook URL with parameters

View file

@ -0,0 +1,484 @@
from __future__ import annotations
import copy
from dataclasses import dataclass
from typing import Literal
from loguru import logger
from pydantic import ValidationError
from sqlalchemy import select, update
from sqlalchemy.orm import selectinload
from api.constants import MPS_API_URL
from api.db import db_client
from api.db.models import WorkflowDefinitionModel, WorkflowModel
from api.enums import OrganizationConfigurationKey
from api.schemas.ai_model_configuration import (
DOGRAH_DEFAULT_LANGUAGE,
DOGRAH_DEFAULT_VOICE,
DOGRAH_SPEED_OPTIONS,
BYOKAIModelConfiguration,
BYOKPipelineAIModelConfiguration,
BYOKRealtimeAIModelConfiguration,
DograhManagedAIModelConfiguration,
EffectiveAIModelConfiguration,
OrganizationAIModelConfigurationV2,
compile_ai_model_configuration_v2,
)
from api.services.configuration.masking import (
SERVICE_SECRET_FIELDS,
contains_masked_key,
mask_key,
resolve_masked_api_keys,
)
from api.services.configuration.registry import ServiceProviders
from api.services.configuration.resolve import resolve_effective_config
AIModelConfigurationSource = Literal["organization_v2", "legacy_user_v1", "empty"]
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY = "model_configuration_v2_override"
@dataclass
class ResolvedAIModelConfiguration:
effective: EffectiveAIModelConfiguration
source: AIModelConfigurationSource
organization_configuration: OrganizationAIModelConfigurationV2 | None = None
@dataclass
class WorkflowAIModelConfigurationMigrationResult:
workflow_count: int = 0
definition_count: int = 0
workflow_ids: list[int] | None = None
async def get_resolved_ai_model_configuration(
*,
user_id: int | None,
organization_id: int | None,
) -> ResolvedAIModelConfiguration:
organization_configuration = await get_organization_ai_model_configuration_v2(
organization_id
)
if organization_configuration is not None:
return ResolvedAIModelConfiguration(
effective=compile_ai_model_configuration_v2(organization_configuration),
source="organization_v2",
organization_configuration=organization_configuration,
)
if user_id is None:
return ResolvedAIModelConfiguration(
effective=EffectiveAIModelConfiguration(),
source="empty",
)
legacy = await db_client.get_user_configurations(user_id)
return ResolvedAIModelConfiguration(
effective=legacy,
source="legacy_user_v1" if _has_model_services(legacy) else "empty",
)
async def get_effective_ai_model_configuration_for_workflow(
*,
user_id: int | None,
organization_id: int | None,
workflow_configurations: dict | None,
) -> EffectiveAIModelConfiguration:
workflow_configurations = workflow_configurations or {}
v2_override = workflow_configurations.get(
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY
)
if v2_override:
return compile_ai_model_configuration_v2(
OrganizationAIModelConfigurationV2.model_validate(v2_override)
)
resolved_config = await get_resolved_ai_model_configuration(
user_id=user_id,
organization_id=organization_id,
)
return resolve_effective_config(
resolved_config.effective,
workflow_configurations.get("model_overrides"),
)
async def get_organization_ai_model_configuration_v2(
organization_id: int | None,
) -> OrganizationAIModelConfigurationV2 | None:
if organization_id is None:
return None
row = await db_client.get_configuration(
organization_id,
OrganizationConfigurationKey.MODEL_CONFIGURATION_V2.value,
)
if row is None or not row.value:
return None
try:
return OrganizationAIModelConfigurationV2.model_validate(row.value)
except ValidationError as exc:
logger.warning(
"Invalid org AI model configuration v2 for organization "
f"{organization_id}: {exc}. Falling back to legacy configuration."
)
return None
async def upsert_organization_ai_model_configuration_v2(
organization_id: int,
configuration: OrganizationAIModelConfigurationV2,
) -> OrganizationAIModelConfigurationV2:
await db_client.upsert_configuration(
organization_id,
OrganizationConfigurationKey.MODEL_CONFIGURATION_V2.value,
configuration.model_dump(mode="json", exclude_none=True),
)
return configuration
async def migrate_workflow_model_configurations_to_v2(
*,
organization_id: int,
fallback_user_config: EffectiveAIModelConfiguration,
) -> WorkflowAIModelConfigurationMigrationResult:
workflows = await _list_workflows_for_model_configuration_migration(organization_id)
owner_configs: dict[int, EffectiveAIModelConfiguration] = {}
workflow_updates: list[tuple[int, dict]] = []
definition_updates: list[tuple[int, dict]] = []
migrated_workflow_ids: set[int] = set()
for workflow in workflows:
base_config = fallback_user_config
if workflow.user_id is not None:
if workflow.user_id not in owner_configs:
owner_configs[
workflow.user_id
] = await db_client.get_user_configurations(workflow.user_id)
base_config = owner_configs[workflow.user_id]
workflow_configs, workflow_changed = (
migrate_workflow_configuration_model_override_to_v2(
workflow.workflow_configurations,
base_config,
)
)
if workflow_changed:
workflow_updates.append((workflow.id, workflow_configs))
migrated_workflow_ids.add(workflow.id)
for definition in workflow.definitions:
definition_configs, definition_changed = (
migrate_workflow_configuration_model_override_to_v2(
definition.workflow_configurations,
base_config,
)
)
if definition_changed:
definition_updates.append((definition.id, definition_configs))
migrated_workflow_ids.add(workflow.id)
if workflow_updates or definition_updates:
async with db_client.async_session() as session:
for workflow_id, workflow_configs in workflow_updates:
await session.execute(
update(WorkflowModel)
.where(WorkflowModel.id == workflow_id)
.values(workflow_configurations=workflow_configs)
)
for definition_id, definition_configs in definition_updates:
await session.execute(
update(WorkflowDefinitionModel)
.where(WorkflowDefinitionModel.id == definition_id)
.values(workflow_configurations=definition_configs)
)
await session.commit()
return WorkflowAIModelConfigurationMigrationResult(
workflow_count=len(migrated_workflow_ids),
definition_count=len(definition_updates),
workflow_ids=sorted(migrated_workflow_ids),
)
def migrate_workflow_configuration_model_override_to_v2(
workflow_configurations: dict | None,
base_config: EffectiveAIModelConfiguration,
) -> tuple[dict, bool]:
if not isinstance(workflow_configurations, dict):
return {}, False
migrated = copy.deepcopy(workflow_configurations)
model_overrides = migrated.get("model_overrides")
existing_v2_override = migrated.get(WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY)
if not isinstance(model_overrides, dict):
if "model_overrides" in migrated:
migrated.pop("model_overrides", None)
return migrated, True
return migrated, False
if not existing_v2_override:
effective = resolve_effective_config(base_config, model_overrides)
v2_override = convert_legacy_ai_model_configuration_to_v2(effective)
migrated[WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY] = v2_override.model_dump(
mode="json", exclude_none=True
)
migrated.pop("model_overrides", None)
return migrated, True
def merge_ai_model_configuration_v2_secrets(
incoming: OrganizationAIModelConfigurationV2,
existing: OrganizationAIModelConfigurationV2 | None,
) -> OrganizationAIModelConfigurationV2:
if existing is None:
return incoming
incoming_dict = incoming.model_dump(mode="json", exclude_none=True)
existing_dict = existing.model_dump(mode="json", exclude_none=True)
if incoming_dict.get("mode") == "dograh" and existing_dict.get("mode") == "dograh":
incoming_dograh = incoming_dict.get("dograh") or {}
existing_dograh = existing_dict.get("dograh") or {}
incoming_key = incoming_dograh.get("api_key")
existing_key = existing_dograh.get("api_key")
if incoming_key and existing_key and contains_masked_key(incoming_key):
incoming_dograh["api_key"] = resolve_masked_api_keys(
incoming_key,
existing_key,
)
if incoming_dict.get("mode") == "byok" and existing_dict.get("mode") == "byok":
_merge_byok_secret_fields(incoming_dict.get("byok"), existing_dict.get("byok"))
return OrganizationAIModelConfigurationV2.model_validate(incoming_dict)
def check_for_masked_keys_in_ai_model_configuration_v2(
configuration: OrganizationAIModelConfigurationV2,
) -> None:
data = configuration.model_dump(mode="json", exclude_none=True)
_raise_if_masked_secret(data)
def mask_ai_model_configuration_v2(
configuration: OrganizationAIModelConfigurationV2 | None,
) -> dict | None:
if configuration is None:
return None
data = configuration.model_dump(mode="json", exclude_none=True)
_mask_secret_fields(data)
return data
def convert_legacy_ai_model_configuration_to_v2(
configuration: EffectiveAIModelConfiguration,
) -> OrganizationAIModelConfigurationV2:
dograh_key = _first_dograh_api_key(configuration)
if dograh_key:
return _convert_any_dograh_legacy_configuration(configuration, dograh_key)
if configuration.is_realtime:
if configuration.realtime is None or configuration.llm is None:
raise ValueError("Realtime legacy configuration is incomplete")
return OrganizationAIModelConfigurationV2(
mode="byok",
byok=BYOKAIModelConfiguration(
mode="realtime",
realtime=BYOKRealtimeAIModelConfiguration(
realtime=configuration.realtime,
llm=configuration.llm,
embeddings=configuration.embeddings,
),
),
)
if (
configuration.llm is None
or configuration.tts is None
or configuration.stt is None
):
raise ValueError("Pipeline legacy configuration is incomplete")
return OrganizationAIModelConfigurationV2(
mode="byok",
byok=BYOKAIModelConfiguration(
mode="pipeline",
pipeline=BYOKPipelineAIModelConfiguration(
llm=configuration.llm,
tts=configuration.tts,
stt=configuration.stt,
embeddings=configuration.embeddings,
),
),
)
def dograh_embeddings_base_url() -> str:
return f"{MPS_API_URL}/api/v1/llm"
def apply_managed_embeddings_base_url(
*,
provider: str | None,
base_url: str | None,
) -> str | None:
if provider == ServiceProviders.DOGRAH.value or provider == ServiceProviders.DOGRAH:
return dograh_embeddings_base_url()
return base_url
def _merge_byok_secret_fields(incoming_byok: dict | None, existing_byok: dict | None):
if not isinstance(incoming_byok, dict) or not isinstance(existing_byok, dict):
return
incoming_mode = incoming_byok.get("mode")
existing_mode = existing_byok.get("mode")
if incoming_mode != existing_mode:
return
section_names = (
("llm", "tts", "stt", "embeddings")
if incoming_mode == "pipeline"
else ("realtime", "llm", "embeddings")
)
incoming_container = incoming_byok.get(incoming_mode)
existing_container = existing_byok.get(existing_mode)
if not isinstance(incoming_container, dict) or not isinstance(
existing_container, dict
):
return
for section_name in section_names:
incoming_section = incoming_container.get(section_name)
existing_section = existing_container.get(section_name)
if isinstance(incoming_section, dict) and isinstance(existing_section, dict):
_merge_service_secret_fields(incoming_section, existing_section)
async def _list_workflows_for_model_configuration_migration(
organization_id: int,
) -> list[WorkflowModel]:
async with db_client.async_session() as session:
result = await session.execute(
select(WorkflowModel)
.options(selectinload(WorkflowModel.definitions))
.where(WorkflowModel.organization_id == organization_id)
)
return list(result.scalars().unique().all())
def _merge_service_secret_fields(incoming: dict, existing: dict):
if (
incoming.get("provider") is not None
and existing.get("provider") is not None
and incoming.get("provider") != existing.get("provider")
):
return
for secret_field in SERVICE_SECRET_FIELDS:
if secret_field not in existing:
continue
incoming_secret = incoming.get(secret_field)
existing_secret = existing[secret_field]
if incoming_secret is None:
incoming[secret_field] = existing_secret
elif contains_masked_key(incoming_secret):
incoming[secret_field] = resolve_masked_api_keys(
incoming_secret,
existing_secret,
)
def _raise_if_masked_secret(value):
if isinstance(value, dict):
for key, nested in value.items():
if key in SERVICE_SECRET_FIELDS and contains_masked_key(nested):
raise ValueError(
f"The {key} appears to be masked. Please provide the actual "
"value, not the masked value."
)
_raise_if_masked_secret(nested)
elif isinstance(value, list):
for item in value:
_raise_if_masked_secret(item)
def _mask_secret_fields(value):
if isinstance(value, dict):
for key, nested in list(value.items()):
if key in SERVICE_SECRET_FIELDS and nested:
value[key] = _mask_secret_value(nested)
else:
_mask_secret_fields(nested)
elif isinstance(value, list):
for item in value:
_mask_secret_fields(item)
def _mask_secret_value(value):
if isinstance(value, list):
return [mask_key(item) for item in value]
return mask_key(value)
def _has_model_services(configuration: EffectiveAIModelConfiguration) -> bool:
return any(
service is not None
for service in (
configuration.llm,
configuration.tts,
configuration.stt,
configuration.embeddings,
configuration.realtime,
)
)
def _convert_any_dograh_legacy_configuration(
configuration: EffectiveAIModelConfiguration,
dograh_key: str,
) -> OrganizationAIModelConfigurationV2:
speed = getattr(configuration.tts, "speed", 1.0)
if speed not in DOGRAH_SPEED_OPTIONS:
speed = 1.0
return OrganizationAIModelConfigurationV2(
mode="dograh",
dograh=DograhManagedAIModelConfiguration(
api_key=dograh_key,
voice=getattr(configuration.tts, "voice", DOGRAH_DEFAULT_VOICE)
or DOGRAH_DEFAULT_VOICE,
speed=speed,
language=getattr(configuration.stt, "language", DOGRAH_DEFAULT_LANGUAGE)
or DOGRAH_DEFAULT_LANGUAGE,
),
)
def _first_dograh_api_key(configuration: EffectiveAIModelConfiguration) -> str | None:
for service in (
configuration.llm,
configuration.tts,
configuration.stt,
configuration.embeddings,
configuration.realtime,
):
if service is None or _provider(service) != ServiceProviders.DOGRAH:
continue
try:
return _single_api_key(service)
except ValueError:
continue
return None
def _provider(service):
return getattr(service, "provider", None)
def _single_api_key(service) -> str:
if hasattr(service, "get_all_api_keys"):
keys = service.get_all_api_keys()
if len(keys) != 1:
raise ValueError("Expected exactly one API key")
return keys[0]
key = getattr(service, "api_key", None)
if not key:
raise ValueError("Expected an API key")
return key

View file

@ -8,8 +8,8 @@ from groq import Groq
# from pyneuphonic import Neuphonic
# except ImportError:
# Neuphonic = None
from api.schemas.user_configuration import (
UserConfiguration,
from api.schemas.ai_model_configuration import (
EffectiveAIModelConfiguration,
)
from api.services.configuration.registry import ServiceConfig, ServiceProviders
from api.services.mps_service_key_client import mps_service_key_client
@ -64,7 +64,7 @@ class UserConfigurationValidator:
async def validate(
self,
configuration: UserConfiguration,
configuration: EffectiveAIModelConfiguration,
organization_id: Optional[int] = None,
created_by: Optional[str] = None,
) -> APIKeyStatusResponse:
@ -75,21 +75,21 @@ class UserConfigurationValidator:
status_list = []
status_list.extend(self._validate_service(configuration.llm, "llm"))
status_list.extend(self._validate_service(configuration.stt, "stt"))
status_list.extend(self._validate_service(configuration.tts, "tts"))
# Embeddings is optional - only validate if configured
status_list.extend(
self._validate_service(
configuration.embeddings, "embeddings", required=False
)
)
# Realtime is optional - only validate if is_realtime is enabled
if configuration.is_realtime:
status_list.extend(
self._validate_service(
configuration.realtime, "realtime", required=True
)
)
else:
status_list.extend(self._validate_service(configuration.stt, "stt"))
status_list.extend(self._validate_service(configuration.tts, "tts"))
# Embeddings is optional - only validate if configured
status_list.extend(
self._validate_service(
configuration.embeddings, "embeddings", required=False
)
)
if status_list:
raise ValueError(status_list)

View file

@ -12,7 +12,7 @@ The rules are simple:
import copy
from typing import Any, Dict, Optional
from api.schemas.user_configuration import UserConfiguration
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
from api.services.configuration.registry import ServiceConfig
from api.services.integrations import get_node_secret_fields
@ -31,7 +31,7 @@ def contains_masked_key(value: str | list[str] | None) -> bool:
return any(MASK_MARKER in k for k in keys)
def check_for_masked_keys(config: "UserConfiguration") -> None:
def check_for_masked_keys(config: "EffectiveAIModelConfiguration") -> None:
"""Raise ValueError if any service in *config* still has a masked secret."""
for field in ("llm", "tts", "stt", "embeddings", "realtime"):
service = getattr(config, field, None)
@ -111,7 +111,7 @@ def resolve_masked_api_keys(
# ---------------------------------------------------------------------------
# High-level helpers for UserConfiguration objects
# High-level helpers for EffectiveAIModelConfiguration objects
# ---------------------------------------------------------------------------
@ -129,7 +129,7 @@ def _mask_service(service_cfg: Optional[ServiceConfig]) -> Optional[Dict[str, An
return data
def mask_user_config(config: UserConfiguration) -> Dict[str, Any]:
def mask_user_config(config: EffectiveAIModelConfiguration) -> Dict[str, Any]:
"""Return a JSON-serialisable dict of *config* with every api_key masked."""
return {
@ -155,21 +155,35 @@ def mask_workflow_configurations(config: Optional[Dict]) -> Optional[Dict]:
masked = copy.deepcopy(config)
model_overrides = masked.get("model_overrides")
if not isinstance(model_overrides, dict):
return masked
if isinstance(model_overrides, dict):
for section in MODEL_OVERRIDE_FIELDS:
override = model_overrides.get(section)
if not isinstance(override, dict):
continue
for secret_field in SERVICE_SECRET_FIELDS:
raw = override.get(secret_field)
if raw:
override[secret_field] = _mask_secret_value(raw)
for section in MODEL_OVERRIDE_FIELDS:
override = model_overrides.get(section)
if not isinstance(override, dict):
continue
for secret_field in SERVICE_SECRET_FIELDS:
raw = override.get(secret_field)
if raw:
override[secret_field] = _mask_secret_value(raw)
v2_override = masked.get("model_configuration_v2_override")
if isinstance(v2_override, dict):
_mask_nested_service_secrets(v2_override)
return masked
def _mask_nested_service_secrets(value):
if isinstance(value, dict):
for key, nested in list(value.items()):
if key in SERVICE_SECRET_FIELDS and nested:
value[key] = _mask_secret_value(nested)
else:
_mask_nested_service_secrets(nested)
elif isinstance(value, list):
for item in value:
_mask_nested_service_secrets(item)
# ---------------------------------------------------------------------------
# Workflow definition helpers mask / merge node API keys
# ---------------------------------------------------------------------------

View file

@ -7,7 +7,7 @@ stored, while honouring masked API keys.
import copy
from typing import Dict
from api.schemas.user_configuration import UserConfiguration
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
from api.services.configuration.masking import (
MODEL_OVERRIDE_FIELDS,
SERVICE_SECRET_FIELDS,
@ -66,9 +66,9 @@ def _merge_service_secret_fields(
def merge_user_configurations(
existing: UserConfiguration, incoming_partial: Dict[str, dict]
) -> UserConfiguration:
"""Merge *incoming_partial* onto *existing* and return a new UserConfiguration.
existing: EffectiveAIModelConfiguration, incoming_partial: Dict[str, dict]
) -> EffectiveAIModelConfiguration:
"""Merge *incoming_partial* onto *existing* and return a new EffectiveAIModelConfiguration.
*incoming_partial* is the body of the PUT request (already `model_dump()`ed or
extracted via Pydantic `model_dump`).
@ -113,14 +113,14 @@ def merge_user_configurations(
if "timezone" in incoming_partial:
merged["timezone"] = incoming_partial["timezone"]
# Onboarding gate flags — overwrite only when supplied (set once on submit/skip).
# Onboarding gate flags: overwrite only when supplied.
if "onboarding_completed_at" in incoming_partial:
merged["onboarding_completed_at"] = incoming_partial["onboarding_completed_at"]
if "onboarding_skipped" in incoming_partial:
merged["onboarding_skipped"] = incoming_partial["onboarding_skipped"]
return UserConfiguration.model_validate(merged)
return EffectiveAIModelConfiguration.model_validate(merged)
def merge_workflow_configuration_secrets(

View file

@ -911,7 +911,7 @@ class DograhTTSService(BaseTTSConfiguration):
speed: float = Field(default=1.0, ge=0.5, le=2.0, description="Speed of the voice.")
CARTESIA_TTS_MODELS = ["sonic-3"]
CARTESIA_TTS_MODELS = ["sonic-3.5", "sonic-3"]
@register_tts
@ -919,7 +919,7 @@ class CartesiaTTSConfiguration(BaseTTSConfiguration):
model_config = CARTESIA_PROVIDER_MODEL_CONFIG
provider: Literal[ServiceProviders.CARTESIA] = ServiceProviders.CARTESIA
model: str = Field(
default="sonic-3",
default="sonic-3.5",
description="Cartesia TTS model.",
json_schema_extra={"examples": CARTESIA_TTS_MODELS},
)
@ -1472,11 +1472,26 @@ class AzureOpenAIEmbeddingsConfiguration(BaseEmbeddingsConfiguration):
)
DOGRAH_EMBEDDING_MODELS = ["default"]
@register_embeddings
class DograhEmbeddingsConfiguration(BaseEmbeddingsConfiguration):
model_config = DOGRAH_PROVIDER_MODEL_CONFIG
provider: Literal[ServiceProviders.DOGRAH] = ServiceProviders.DOGRAH
model: str = Field(
default="default",
description="Dograh-managed embedding model.",
json_schema_extra={"examples": DOGRAH_EMBEDDING_MODELS},
)
EmbeddingsConfig = Annotated[
Union[
OpenAIEmbeddingsConfiguration,
OpenRouterEmbeddingsConfiguration,
AzureOpenAIEmbeddingsConfiguration,
DograhEmbeddingsConfiguration,
],
Field(discriminator="provider"),
]

View file

@ -4,13 +4,13 @@ from __future__ import annotations
import copy
from api.schemas.user_configuration import UserConfiguration
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
from api.services.configuration.registry import (
REGISTRY,
ServiceType,
)
# Maps override key → (UserConfiguration field, ServiceType for registry lookup)
# Maps override key → (EffectiveAIModelConfiguration field, ServiceType for registry lookup)
_SECTION_MAP: dict[str, ServiceType] = {
"llm": ServiceType.LLM,
"tts": ServiceType.TTS,
@ -36,7 +36,7 @@ _SECRET_FIELDS = ("api_key", "credentials", "aws_access_key", "aws_secret_key")
def enrich_overrides_with_api_keys(
model_overrides: dict,
user_config: UserConfiguration,
user_config: EffectiveAIModelConfiguration,
) -> dict:
"""Copy API keys from the global config into model_overrides where missing.
@ -74,9 +74,9 @@ def enrich_overrides_with_api_keys(
def resolve_effective_config(
user_config: UserConfiguration,
user_config: EffectiveAIModelConfiguration,
model_overrides: dict | None,
) -> UserConfiguration:
) -> EffectiveAIModelConfiguration:
"""Deep-merge workflow model_overrides onto global user config.
- If model_overrides is None or empty, returns a copy of user_config unchanged.

View file

@ -38,6 +38,7 @@ class OpenAIEmbeddingService(BaseEmbeddingService):
api_key: Optional[str] = None,
model_id: str = DEFAULT_MODEL_ID,
base_url: Optional[str] = None,
default_headers: Optional[Dict[str, str]] = None,
):
"""Initialize the OpenAI embedding service.
@ -60,6 +61,8 @@ class OpenAIEmbeddingService(BaseEmbeddingService):
field_name="base_url",
)
client_kwargs["base_url"] = base_url
if default_headers:
client_kwargs["default_headers"] = default_headers
self.client = AsyncOpenAI(**client_kwargs)
logger.info(f"OpenAI embedding service initialized with model: {model_id}")
else:

View file

@ -0,0 +1,78 @@
from __future__ import annotations
from typing import Any
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
from api.services.configuration.registry import ServiceProviders
MPS_CORRELATION_ID_CONTEXT_KEY = "mps_correlation_id"
def uses_managed_model_services_v2(
ai_model_config: EffectiveAIModelConfiguration | None,
) -> bool:
if (
ai_model_config is None
or getattr(ai_model_config, "managed_service_version", None) != 2
):
return False
return any(
_is_dograh_service(getattr(ai_model_config, section_name, None))
for section_name in ("llm", "tts", "stt", "embeddings")
)
def get_mps_correlation_id(initial_context: dict[str, Any] | None) -> str | None:
if not initial_context:
return None
correlation_id = initial_context.get(MPS_CORRELATION_ID_CONTEXT_KEY)
if correlation_id is None:
return None
return str(correlation_id)
async def ensure_mps_correlation_id(
*,
ai_model_config: EffectiveAIModelConfiguration,
workflow_run_id: int,
initial_context: dict[str, Any] | None,
) -> str | None:
existing = get_mps_correlation_id(initial_context)
if existing:
return existing
if not uses_managed_model_services_v2(ai_model_config):
return None
raise ValueError(
"Managed model services v2 requires workflow run authorization before "
f"the run starts. Missing correlation id for workflow_run_id={workflow_run_id}."
)
def _is_dograh_service(service: Any) -> bool:
provider = getattr(service, "provider", None)
return (
provider == ServiceProviders.DOGRAH or provider == ServiceProviders.DOGRAH.value
)
def get_dograh_service_api_key(
ai_model_config: EffectiveAIModelConfiguration,
) -> str | None:
for section_name in ("llm", "tts", "stt", "embeddings"):
service = getattr(ai_model_config, section_name, None)
if not _is_dograh_service(service):
continue
if hasattr(service, "get_all_api_keys"):
keys = service.get_all_api_keys()
if keys:
return keys[0]
api_key = getattr(service, "api_key", None)
if isinstance(api_key, str) and api_key:
return api_key
return None

View file

@ -0,0 +1,23 @@
from typing import Optional
from api.constants import DEPLOYMENT_MODE
from api.services.mps_service_key_client import mps_service_key_client
async def ensure_hosted_mps_billing_account_v2(
organization_id: int,
*,
created_by: Optional[str] = None,
) -> Optional[dict]:
"""Ensure hosted orgs have an MPS billing v2 account.
OSS deployments use legacy per-key quota accounting and do not create MPS
billing accounts.
"""
if DEPLOYMENT_MODE == "oss":
return None
return await mps_service_key_client.ensure_billing_account_v2(
organization_id=organization_id,
created_by=created_by,
)

View file

@ -4,6 +4,7 @@ This client communicates with the Model Proxy Service (MPS) for service key mana
Service keys are stored and managed entirely in MPS, not in the local database.
"""
import asyncio
from typing import List, Optional
import httpx
@ -353,6 +354,278 @@ class MPSServiceKeyClient:
response=response,
)
async def create_credit_purchase_url(
self,
organization_id: int,
created_by: Optional[str] = None,
return_url: Optional[str] = None,
billing_details: Optional[dict] = None,
) -> dict:
"""Create a short-lived MPS checkout URL for adding organization credits."""
payload = {
"created_by": created_by,
"return_url": return_url,
"billing_details": billing_details or {},
}
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.post(
f"{self.base_url}/api/v1/billing/accounts/{organization_id}/checkout-sessions",
json=payload,
headers=self._get_headers(
organization_id=organization_id,
created_by=created_by,
),
)
if response.status_code == 200:
return response.json()
logger.error(
"Failed to create MPS credit purchase URL: "
f"{response.status_code} - {response.text}"
)
raise httpx.HTTPStatusError(
f"Failed to create MPS credit purchase URL: {response.text}",
request=response.request,
response=response,
)
async def get_credit_ledger(
self,
organization_id: int,
page: int = 1,
limit: int = 50,
created_by: Optional[str] = None,
) -> dict:
"""Get the MPS v2 billing account balance and recent credit ledger."""
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.get(
f"{self.base_url}/api/v1/billing/accounts/{organization_id}/ledger",
params={"page": page, "limit": limit},
headers=self._get_headers(
organization_id=organization_id,
created_by=created_by,
),
)
if response.status_code == 200:
return response.json()
logger.error(
"Failed to get MPS credit ledger: "
f"{response.status_code} - {response.text}"
)
raise httpx.HTTPStatusError(
f"Failed to get MPS credit ledger: {response.text}",
request=response.request,
response=response,
)
async def get_billing_account_status(
self,
organization_id: int,
created_by: Optional[str] = None,
) -> Optional[dict]:
"""Get an existing MPS v2 billing account without creating one."""
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.get(
f"{self.base_url}/api/v1/billing/accounts/{organization_id}/status",
headers=self._get_headers(
organization_id=organization_id,
created_by=created_by,
),
)
if response.status_code == 200:
return response.json()
logger.error(
"Failed to get MPS billing account status: "
f"{response.status_code} - {response.text}"
)
raise httpx.HTTPStatusError(
f"Failed to get MPS billing account status: {response.text}",
request=response.request,
response=response,
)
async def ensure_billing_account_v2(
self,
organization_id: int,
created_by: Optional[str] = None,
) -> dict:
"""Create or return the MPS v2 billing account for an organization."""
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.get(
f"{self.base_url}/api/v1/billing/accounts/{organization_id}/balance",
headers=self._get_headers(
organization_id=organization_id,
created_by=created_by,
),
)
if response.status_code == 200:
return response.json()
logger.error(
"Failed to ensure MPS billing account v2: "
f"{response.status_code} - {response.text}"
)
raise httpx.HTTPStatusError(
f"Failed to ensure MPS billing account v2: {response.text}",
request=response.request,
response=response,
)
async def authorize_workflow_run_start(
self,
*,
organization_id: int,
workflow_run_id: int | None = None,
service_key: Optional[str] = None,
require_correlation_id: bool = False,
minimum_credits: float | None = None,
metadata: Optional[dict] = None,
created_by: Optional[str] = None,
) -> dict:
"""Authorize a hosted workflow run and optionally mint its MPS correlation."""
payload = {
"workflow_run_id": workflow_run_id,
"service_key": service_key,
"require_correlation_id": require_correlation_id,
"metadata": metadata or {},
}
if minimum_credits is not None:
payload["minimum_credits"] = minimum_credits
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.post(
f"{self.base_url}/api/v1/billing/accounts/{organization_id}/run-authorization",
json=payload,
headers=self._get_headers(
organization_id=organization_id,
created_by=created_by,
),
)
if response.status_code == 200:
return response.json()
logger.error(
"Failed to authorize MPS workflow run start: "
f"{response.status_code} - {response.text}"
)
raise httpx.HTTPStatusError(
f"Failed to authorize MPS workflow run start: {response.text}",
request=response.request,
response=response,
)
async def create_correlation_id(
self,
*,
service_key: str,
workflow_run_id: int | None = None,
) -> dict:
"""Mint a server-generated correlation ID for managed model services."""
payload: dict[str, int] = {}
if workflow_run_id is not None:
payload["workflow_run_id"] = workflow_run_id
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.post(
f"{self.base_url}/api/v1/service-keys/correlation-id/self",
json=payload,
headers={
"Authorization": f"Bearer {service_key}",
"Content-Type": "application/json",
},
)
if response.status_code == 200:
return response.json()
logger.error(
"Failed to create correlation ID: "
f"{response.status_code} - {response.text}"
)
raise httpx.HTTPStatusError(
f"Failed to create correlation ID: {response.text}",
request=response.request,
response=response,
)
async def report_platform_usage(
self,
*,
organization_id: int,
correlation_id: Optional[str] = None,
duration_seconds: Optional[float] = None,
workflow_run_id: int | None = None,
metadata: Optional[dict] = None,
max_attempts: int = 3,
) -> dict:
"""Report hosted Dograh platform usage for a completed workflow run."""
if DEPLOYMENT_MODE == "oss":
raise ValueError("OSS deployments must not report platform usage to MPS")
if not correlation_id and duration_seconds is None:
raise ValueError(
"Platform usage reports require correlation_id or duration_seconds"
)
payload: dict = {
"metadata": metadata or {},
}
if correlation_id:
payload["correlation_id"] = correlation_id
if duration_seconds is not None:
payload["duration_seconds"] = duration_seconds
if workflow_run_id is not None:
payload["workflow_run_id"] = workflow_run_id
max_attempts = max(1, max_attempts)
last_response: httpx.Response | None = None
async with httpx.AsyncClient(timeout=self.timeout) as client:
for attempt in range(1, max_attempts + 1):
response = await client.post(
(
f"{self.base_url}/api/v1/billing/accounts/"
f"{organization_id}/platform-usage"
),
json=payload,
headers=self._get_headers(organization_id=organization_id),
)
last_response = response
if response.status_code == 200:
return response.json()
should_retry = (
response.status_code == 409
and "usage_not_ready" in response.text
and attempt < max_attempts
)
if should_retry:
await asyncio.sleep(attempt)
continue
logger.error(
"Failed to report platform usage: "
f"{response.status_code} - {response.text}"
)
raise httpx.HTTPStatusError(
f"Failed to report platform usage: {response.text}",
request=response.request,
response=response,
)
raise httpx.HTTPStatusError(
"Failed to report platform usage",
request=last_response.request,
response=last_response,
)
async def transcribe_audio(
self,
audio_data: bytes,

View file

@ -0,0 +1,50 @@
from typing import Literal, Optional
from pydantic import BaseModel
from api.db import db_client
from api.db.models import UserModel
from api.services.configuration.ai_model_configuration import (
get_resolved_ai_model_configuration,
)
class OrganizationModelServicesContext(BaseModel):
config_source: Literal["organization_v2", "legacy_user_v1", "empty"]
has_model_configuration_v2: bool
managed_service_version: Optional[int] = None
uses_managed_service_v2: bool
class OrganizationContextResponse(BaseModel):
organization_id: Optional[int] = None
organization_provider_id: Optional[str] = None
model_services: OrganizationModelServicesContext
async def get_organization_context(user: UserModel) -> OrganizationContextResponse:
organization_id = user.selected_organization_id
organization = (
await db_client.get_organization_by_id(organization_id)
if organization_id
else None
)
resolved = await get_resolved_ai_model_configuration(
user_id=user.id,
organization_id=organization_id,
)
managed_service_version = resolved.effective.managed_service_version
return OrganizationContextResponse(
organization_id=organization_id,
organization_provider_id=organization.provider_id if organization else None,
model_services=OrganizationModelServicesContext(
config_source=resolved.source,
has_model_configuration_v2=resolved.source == "organization_v2",
managed_service_version=managed_service_version,
uses_managed_service_v2=(
resolved.source == "organization_v2" and managed_service_version == 2
),
),
)

View file

@ -0,0 +1,62 @@
from inspect import isawaitable
from loguru import logger
from pydantic import ValidationError
from api.db import db_client
from api.enums import OrganizationConfigurationKey
from api.schemas.organization_preferences import OrganizationPreferences
async def get_organization_preferences(
organization_id: int | None,
db=None,
) -> OrganizationPreferences:
if organization_id is None:
return OrganizationPreferences()
db = db or db_client
row = await _get_configuration(
db,
organization_id,
OrganizationConfigurationKey.ORGANIZATION_PREFERENCES.value,
)
if row is None:
row = await _get_configuration(
db,
organization_id,
OrganizationConfigurationKey.MODEL_CONFIGURATION_PREFERENCES.value,
)
return _parse_preferences(row.value if row is not None else None, organization_id)
async def upsert_organization_preferences(
organization_id: int,
preferences: OrganizationPreferences,
) -> OrganizationPreferences:
await db_client.upsert_configuration(
organization_id,
OrganizationConfigurationKey.ORGANIZATION_PREFERENCES.value,
preferences.model_dump(mode="json", exclude_none=True),
)
return preferences
async def _get_configuration(db, organization_id: int, key: str):
row = db.get_configuration(organization_id, key)
if isawaitable(row):
row = await row
return row
def _parse_preferences(value, organization_id: int) -> OrganizationPreferences:
if not value or not isinstance(value, dict):
return OrganizationPreferences()
try:
return OrganizationPreferences.model_validate(value)
except ValidationError as exc:
logger.warning(
"Invalid organization preferences for organization "
f"{organization_id}: {exc}. Returning defaults."
)
return OrganizationPreferences()

View file

@ -15,6 +15,29 @@ from api.utils.credential_auth import build_auth_header
PRE_CALL_FETCH_TIMEOUT_SECONDS = 10
def _extract_initial_context(response_data: Dict[str, Any]) -> Dict[str, Any]:
"""Pull the context variables out of a pre-call fetch response.
The canonical key is ``initial_context``. The legacy ``dynamic_variables``
key is still accepted for backward compatibility, so existing endpoints
keep working; ``initial_context`` takes precedence when both are present.
Either key may appear at the top level or nested under ``call_inbound``:
{"call_inbound": {"initial_context": {...}}} | {"initial_context": {...}}
{"call_inbound": {"dynamic_variables": {...}}} | {"dynamic_variables": {...}}
"""
container = response_data.get("call_inbound")
if not isinstance(container, dict):
container = response_data
for key in ("initial_context", "dynamic_variables"):
value = container.get(key)
if isinstance(value, dict):
return value
return {}
async def execute_pre_call_fetch(
*,
url: str,
@ -77,24 +100,16 @@ async def execute_pre_call_fetch(
)
return {}
# Extract dynamic_variables from Retell-compatible response
# Supports: {call_inbound: {dynamic_variables: {...}}}
# or: {dynamic_variables: {...}}
dynamic_vars = {}
call_inbound = response_data.get("call_inbound")
if isinstance(call_inbound, dict):
dynamic_vars = call_inbound.get("dynamic_variables", {})
elif "dynamic_variables" in response_data:
dynamic_vars = response_data["dynamic_variables"]
if not isinstance(dynamic_vars, dict):
dynamic_vars = {}
# Extract the variables to merge into initial_context. Prefers
# the canonical `initial_context` key, falling back to the
# legacy `dynamic_variables` key for backward compatibility.
initial_context_vars = _extract_initial_context(response_data)
logger.info(
f"Pre-call fetch: success ({response.status_code}), "
f"dynamic_variables keys: {list(dynamic_vars.keys())}"
f"initial_context keys: {list(initial_context_vars.keys())}"
)
return dynamic_vars
return initial_context_vars
else:
logger.warning(
f"Pre-call fetch: HTTP {response.status_code} - "

View file

@ -162,15 +162,13 @@ async def run_pipeline_telephony(
workflow_id: Workflow being executed.
workflow_run_id: Workflow run row.
user_id: Owner of the workflow.
call_id: Provider call identifier (stored in cost_info for billing).
call_id: Provider call identifier.
transport_kwargs: Provider-specific kwargs forwarded to the transport
factory (e.g. stream_sid + call_sid for Twilio).
"""
logger.debug(f"Running {provider_name} pipeline for workflow_run {workflow_run_id}")
set_current_run_id(workflow_run_id)
await db_client.update_workflow_run(workflow_run_id, cost_info={"call_id": call_id})
workflow = await db_client.get_workflow(workflow_id, user_id)
if workflow:
set_current_org_id(workflow.organization_id)
@ -195,14 +193,17 @@ async def run_pipeline_telephony(
# Resolve effective user config here so the transport can tune its
# bot-stopped-speaking fallback based on is_realtime; pass the resolved
# values into _run_pipeline so it doesn't fetch them again.
from api.services.configuration.resolve import resolve_effective_config
from api.services.configuration.ai_model_configuration import (
get_effective_ai_model_configuration_for_workflow,
)
user_config = await db_client.get_user_configurations(user_id)
run_configs = (
(workflow_run.definition.workflow_configurations or {}) if workflow_run else {}
)
user_config = resolve_effective_config(
user_config, run_configs.get("model_overrides")
user_config = await get_effective_ai_model_configuration_for_workflow(
user_id=user_id,
organization_id=workflow.organization_id if workflow else None,
workflow_configurations=run_configs,
)
is_realtime = bool(user_config.is_realtime and user_config.realtime is not None)
@ -272,15 +273,18 @@ async def run_pipeline_smallwebrtc(
# Resolve workflow_run + effective user_config here so the transport can
# tune its bot-stopped-speaking fallback based on is_realtime. _run_pipeline
# reuses these via kwargs so we don't fetch twice.
from api.services.configuration.resolve import resolve_effective_config
from api.services.configuration.ai_model_configuration import (
get_effective_ai_model_configuration_for_workflow,
)
workflow_run = await db_client.get_workflow_run(workflow_run_id, user_id)
user_config = await db_client.get_user_configurations(user_id)
run_configs = (
(workflow_run.definition.workflow_configurations or {}) if workflow_run else {}
)
user_config = resolve_effective_config(
user_config, run_configs.get("model_overrides")
user_config = await get_effective_ai_model_configuration_for_workflow(
user_id=user_id,
organization_id=workflow.organization_id if workflow else None,
workflow_configurations=run_configs,
)
is_realtime = bool(user_config.is_realtime and user_config.realtime is not None)
@ -334,7 +338,7 @@ async def _run_pipeline(
if workflow_run.is_completed:
raise HTTPException(status_code=400, detail="Workflow run already completed")
merged_call_context_vars = workflow_run.initial_context
merged_call_context_vars = dict(workflow_run.initial_context or {})
# If there is some extra call_context_vars, fold them in. Persistence
# happens once below, after runtime_configuration is also resolved.
if call_context_vars:
@ -380,15 +384,31 @@ async def _run_pipeline(
# Resolve model overrides from the version onto global user config (skip
# when the caller already resolved it).
if resolved_user_config is None:
from api.services.configuration.resolve import resolve_effective_config
from api.services.configuration.ai_model_configuration import (
get_effective_ai_model_configuration_for_workflow,
)
user_config = await db_client.get_user_configurations(user_id)
user_config = resolve_effective_config(
user_config, run_configs.get("model_overrides")
user_config = await get_effective_ai_model_configuration_for_workflow(
user_id=user_id,
organization_id=workflow.organization_id,
workflow_configurations=run_configs,
)
else:
user_config = resolved_user_config
from api.services.managed_model_services import (
MPS_CORRELATION_ID_CONTEXT_KEY,
ensure_mps_correlation_id,
)
mps_correlation_id = await ensure_mps_correlation_id(
ai_model_config=user_config,
workflow_run_id=workflow_run_id,
initial_context=merged_call_context_vars,
)
if mps_correlation_id:
merged_call_context_vars[MPS_CORRELATION_ID_CONTEXT_KEY] = mps_correlation_id
# Detect realtime mode (speech-to-speech services like OpenAI Realtime, Gemini Live)
is_realtime = user_config.is_realtime and user_config.realtime is not None
@ -400,11 +420,23 @@ async def _run_pipeline(
# Realtime services don't implement run_inference, so create a
# separate text LLM for variable extraction and other out-of-band
# inference calls.
inference_llm = create_llm_service(user_config)
inference_llm = create_llm_service(
user_config,
correlation_id=mps_correlation_id,
)
else:
stt = create_stt_service(user_config, audio_config, keyterms=keyterms)
tts = create_tts_service(user_config, audio_config)
llm = create_llm_service(user_config)
stt = create_stt_service(
user_config,
audio_config,
keyterms=keyterms,
correlation_id=mps_correlation_id,
)
tts = create_tts_service(
user_config,
audio_config,
correlation_id=mps_correlation_id,
)
llm = create_llm_service(user_config, correlation_id=mps_correlation_id)
inference_llm = None
# Stamp the providers/models actually resolved for this run onto
@ -508,10 +540,17 @@ async def _run_pipeline(
embeddings_endpoint = None
embeddings_api_version = None
if user_config and user_config.embeddings:
from api.services.configuration.ai_model_configuration import (
apply_managed_embeddings_base_url,
)
embeddings_api_key = user_config.embeddings.api_key
embeddings_model = user_config.embeddings.model
embeddings_provider = getattr(user_config.embeddings, "provider", None)
embeddings_base_url = getattr(user_config.embeddings, "base_url", None)
embeddings_base_url = apply_managed_embeddings_base_url(
provider=embeddings_provider,
base_url=getattr(user_config.embeddings, "base_url", None),
)
embeddings_endpoint = getattr(user_config.embeddings, "endpoint", None)
embeddings_api_version = getattr(user_config.embeddings, "api_version", None)
@ -679,7 +718,10 @@ async def _run_pipeline(
# Create a separate LLM instance for the voicemail sub-pipeline
# (can't share with main pipeline as it would mess up frame linking)
if voicemail_config.get("use_workflow_llm", True):
voicemail_llm = create_llm_service(user_config)
voicemail_llm = create_llm_service(
user_config,
correlation_id=mps_correlation_id,
)
else:
voicemail_llm = create_llm_service_from_provider(
provider=voicemail_config.get("provider", "openai"),

View file

@ -78,7 +78,10 @@ def _validate_runtime_service_url(url: str, field_name: str) -> None:
def create_stt_service(
user_config, audio_config: "AudioConfig", keyterms: list[str] | None = None
user_config,
audio_config: "AudioConfig",
keyterms: list[str] | None = None,
correlation_id: str | None = None,
):
"""Create and return appropriate STT service based on user configuration
@ -160,6 +163,7 @@ def create_stt_service(
return DograhSTTService(
base_url=base_url,
api_key=user_config.stt.api_key,
correlation_id=correlation_id,
settings=DograhSTTSettings(
model=user_config.stt.model,
language=language,
@ -286,7 +290,9 @@ def create_stt_service(
)
def create_tts_service(user_config, audio_config: "AudioConfig"):
def create_tts_service(
user_config, audio_config: "AudioConfig", correlation_id: str | None = None
):
"""Create and return appropriate TTS service based on user configuration
Args:
@ -404,6 +410,7 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
return DograhTTSService(
base_url=base_url,
api_key=user_config.tts.api_key,
correlation_id=correlation_id,
settings=DograhTTSSettings(
model=user_config.tts.model,
voice=user_config.tts.voice,
@ -564,6 +571,7 @@ def create_llm_service_from_provider(
model: str,
api_key: str | None,
*,
correlation_id: str | None = None,
base_url: str | None = None,
endpoint: str | None = None,
aws_access_key: str | None = None,
@ -637,6 +645,7 @@ def create_llm_service_from_provider(
return DograhLLMService(
base_url=f"{MPS_API_URL}/api/v1/llm",
api_key=api_key,
correlation_id=correlation_id,
settings=OpenAILLMSettings(model=model),
)
elif provider == ServiceProviders.AWS_BEDROCK.value:
@ -851,7 +860,7 @@ def create_realtime_llm_service(user_config, audio_config: "AudioConfig"):
)
def create_llm_service(user_config):
def create_llm_service(user_config, correlation_id: str | None = None):
"""Create and return appropriate LLM service based on user configuration."""
provider = user_config.llm.provider
model = user_config.llm.model
@ -880,4 +889,10 @@ def create_llm_service(user_config):
elif provider == ServiceProviders.SARVAM.value:
kwargs["temperature"] = user_config.llm.temperature
return create_llm_service_from_provider(provider, model, api_key, **kwargs)
return create_llm_service_from_provider(
provider,
model,
api_key,
correlation_id=correlation_id,
**kwargs,
)

View file

@ -1,76 +0,0 @@
# Pricing Module
This module contains pricing models and registries for different AI services used in workflow cost calculations.
## Structure
```
pricing/
├── __init__.py # Main module exports
├── models.py # Base pricing model classes
├── llm.py # LLM pricing configurations
├── tts.py # TTS pricing configurations
├── stt.py # STT pricing configurations
├── registry.py # Combined pricing registry
└── README.md # This file
```
## Pricing Models
### TokenPricingModel
Used for LLM services that charge based on tokens:
- `prompt_token_price`: Cost per prompt token
- `completion_token_price`: Cost per completion token
- `cache_read_discount`: Discount for cache read tokens (default 50%)
- `cache_creation_multiplier`: Premium for cache creation tokens (default 25%)
### CharacterPricingModel
Used for TTS services that charge based on character count:
- `character_price`: Cost per character
### TimePricingModel
Used for STT services that charge based on time:
- `second_price`: Cost per second
## Adding New Pricing
### Adding a New LLM Model
Edit `llm.py` and add the model to the appropriate provider:
```python
ServiceProviders.OPENAI: {
"new-model": TokenPricingModel(
prompt_token_price=Decimal("2.00") / 1000000,
completion_token_price=Decimal("8.00") / 1000000,
),
# ... existing models
}
```
### Adding a New Provider
1. Add pricing configurations to the appropriate service file (llm.py, tts.py, stt.py)
2. The registry will automatically include them
### Adding a New Service Type
1. Create a new pricing file (e.g., `image.py`)
2. Define the pricing models
3. Import and add to `registry.py`
## Usage
The pricing registry is automatically imported and used by the cost calculator:
```python
from api.services.pricing import PRICING_REGISTRY
from api.services.workflow.cost_calculator import cost_calculator
# The cost calculator uses the pricing registry automatically
result = cost_calculator.calculate_total_cost(usage_info)
```
## Maintenance
- Update pricing when providers change their rates
- All prices should use `Decimal` for precision
- Include comments with current pricing from provider documentation
- Test changes with existing test suite

View file

@ -1,9 +0,0 @@
"""
Pricing module for workflow cost calculation.
This module contains pricing models and registries for different AI services.
"""
from .registry import PRICING_REGISTRY
__all__ = ["PRICING_REGISTRY"]

View file

@ -1,228 +0,0 @@
"""
Cost Calculator for Workflow Runs
This module provides a comprehensive cost calculation system for workflow runs based on usage metrics
from different AI service providers (OpenAI, Groq, Deepgram, etc.).
Features:
- Token-based pricing for LLM services with cache optimization support
- Character-based pricing for TTS services
- Time-based pricing for STT services
- Configurable pricing models that can be updated
- Support for multiple providers and models
- Automatic provider inference from model names
- JSON serialization support for database storage
Usage:
from api.tasks.cost_calculator import cost_calculator
usage_info = {
"llm": {
"processor_name|||gpt-4o": {
"prompt_tokens": 1000,
"completion_tokens": 500,
"total_tokens": 1500,
"cache_read_input_tokens": 0,
"cache_creation_input_tokens": 0
}
},
"tts": {
"processor_name|||aura-2-helena-en": 2000 # character count
}
}
cost_breakdown = cost_calculator.calculate_total_cost(usage_info)
print(f"Total cost: ${cost_breakdown['total']:.6f}")
"""
from decimal import Decimal
from typing import Any, Dict, Optional, Tuple
from api.services.configuration.registry import ServiceProviders
from api.services.pricing import PRICING_REGISTRY
from api.services.pricing.models import (
PricingModel,
)
class CostCalculator:
"""Main cost calculator class"""
def __init__(self, pricing_registry: Dict = None):
self.pricing_registry = pricing_registry or PRICING_REGISTRY
def get_pricing_model(
self, service_type: str, provider: str, model: str
) -> Optional[PricingModel]:
"""Get pricing model for a specific service, provider, and model"""
try:
service_pricing = self.pricing_registry.get(service_type, {})
# Try to get pricing for the specific provider
provider_pricing = service_pricing.get(provider, {})
pricing_model = provider_pricing.get(model) or provider_pricing.get(
"default"
)
if pricing_model:
return pricing_model
# If not found, try the "default" provider for this service type
default_provider_pricing = service_pricing.get("default", {})
return default_provider_pricing.get(model) or default_provider_pricing.get(
"default"
)
except (KeyError, AttributeError):
return None
def calculate_llm_cost(
self, provider: str, model: str, usage: Dict[str, int]
) -> Decimal:
"""Calculate cost for LLM usage"""
pricing_model = self.get_pricing_model("llm", provider, model)
if not pricing_model:
return Decimal("0")
return pricing_model.calculate_cost(usage)
def calculate_tts_cost(
self, provider: str, model: str, character_count: int
) -> Decimal:
"""Calculate cost for TTS usage"""
pricing_model = self.get_pricing_model("tts", provider, model)
if not pricing_model:
return Decimal("0")
return pricing_model.calculate_cost(character_count)
def calculate_stt_cost(self, provider: str, model: str, seconds: float) -> Decimal:
"""Calculate cost for STT usage"""
pricing_model = self.get_pricing_model("stt", provider, model)
if not pricing_model:
return Decimal("0")
return pricing_model.calculate_cost(seconds)
def calculate_total_cost(self, usage_info: Dict) -> Dict[str, Any]:
llm_cost_total = Decimal("0")
tts_cost_total = Decimal("0")
stt_cost_total = Decimal("0")
# Calculate LLM costs
llm_usage = usage_info.get("llm", {})
for key, usage in llm_usage.items():
processor, model = self._parse_key(key)
# Try to determine provider from processor name or model
provider = self._infer_provider_from_model(model, "llm")
cost = self.calculate_llm_cost(provider, model, usage)
llm_cost_total += cost
# Calculate TTS costs
tts_usage = usage_info.get("tts", {})
for key, character_count in tts_usage.items():
processor, model = self._parse_key(key)
# Handle the case where model is "None" - infer from processor
if model.lower() in ["none", "null", ""]:
provider = self._infer_provider_from_processor(processor, "tts")
model = "default" # Use default model for the provider
else:
provider = self._infer_provider_from_model(model, "tts")
cost = self.calculate_tts_cost(provider, model, character_count)
tts_cost_total += cost
# Calculate STT costs from explicit stt usage
stt_usage = usage_info.get("stt", {})
for key, seconds in stt_usage.items():
processor, model = self._parse_key(key)
provider = self._infer_provider_from_model(model, "stt")
cost = self.calculate_stt_cost(provider, model, seconds)
stt_cost_total += cost
total_cost = llm_cost_total + tts_cost_total + stt_cost_total
return {
"llm_cost": float(llm_cost_total),
"tts_cost": float(tts_cost_total),
"stt_cost": float(stt_cost_total),
"total": float(total_cost),
}
def _parse_key(self, key) -> Tuple[str, str]:
"""Parse key which is in format 'processor|||model'"""
if isinstance(key, str) and "|||" in key:
parts = key.split("|||", 1)
return parts[0], parts[1]
else:
# Fallback for backwards compatibility or malformed keys
return str(key), "unknown"
def _infer_provider_from_model(self, model: str, service_type: str) -> str:
"""Infer provider from model name"""
if not model:
return "unknown"
model_lower = model.lower()
# OpenAI models
if any(keyword in model_lower for keyword in ["gpt", "whisper", "openai"]):
return ServiceProviders.OPENAI
# Groq models
if any(keyword in model_lower for keyword in ["groq"]):
return ServiceProviders.GROQ
# Elevenlabs models
if any(keyword in model_lower for keyword in ["eleven"]):
return ServiceProviders.ELEVENLABS
# Deepgram models
if any(
keyword in model_lower
for keyword in ["deepgram", "nova", "phonecall", "general"]
):
return ServiceProviders.DEEPGRAM
# Default to first available provider for the service type
service_providers = self.pricing_registry.get(service_type, {})
if service_providers:
return list(service_providers.keys())[0]
return "unknown"
def _infer_provider_from_processor(self, processor: str, service_type: str) -> str:
"""Infer provider from processor name"""
if not processor:
return "unknown"
processor_lower = processor.lower()
# OpenAI processors
if any(keyword in processor_lower for keyword in ["openai", "gpt"]):
return ServiceProviders.OPENAI
# Groq processors
if any(keyword in processor_lower for keyword in ["groq"]):
return ServiceProviders.GROQ
# Deepgram processors
if any(keyword in processor_lower for keyword in ["deepgram"]):
return ServiceProviders.DEEPGRAM
# Default to first available provider for the service type
service_providers = self.pricing_registry.get(service_type, {})
if service_providers:
return list(service_providers.keys())[0]
return "unknown"
def update_pricing(
self, service_type: str, provider: str, model: str, pricing_model: PricingModel
):
"""Update pricing for a specific service/provider/model combination"""
if service_type not in self.pricing_registry:
self.pricing_registry[service_type] = {}
if provider not in self.pricing_registry[service_type]:
self.pricing_registry[service_type][provider] = {}
self.pricing_registry[service_type][provider][model] = pricing_model
# Global cost calculator instance
cost_calculator = CostCalculator()

View file

@ -1,44 +0,0 @@
"""
Embeddings pricing models for different providers.
Prices are per token for embedding models.
"""
from decimal import Decimal
from typing import Dict
from api.services.configuration.registry import ServiceProviders
from .models import PricingModel
class EmbeddingPricingModel(PricingModel):
"""Pricing model for token-based embedding services."""
def __init__(self, token_price: Decimal):
"""Initialize with price per token.
Args:
token_price: Cost per token for embedding
"""
self.token_price = token_price
def calculate_cost(self, token_count: int) -> Decimal:
"""Calculate cost for embedding token usage."""
return Decimal(token_count) * self.token_price
# Embeddings pricing registry
EMBEDDINGS_PRICING: Dict[str, Dict[str, EmbeddingPricingModel]] = {
ServiceProviders.OPENAI: {
"text-embedding-3-small": EmbeddingPricingModel(
token_price=Decimal("0.02") / 1_000_000, # $0.02 per 1M tokens
),
"text-embedding-3-large": EmbeddingPricingModel(
token_price=Decimal("0.13") / 1_000_000, # $0.13 per 1M tokens
),
"text-embedding-ada-002": EmbeddingPricingModel(
token_price=Decimal("0.10") / 1_000_000, # $0.10 per 1M tokens (legacy)
),
},
}

View file

@ -1,143 +0,0 @@
"""
LLM pricing models for different providers.
Prices are per 1000 tokens for most models, with some newer models priced per million tokens.
"""
from decimal import Decimal
from typing import Dict
from api.services.configuration.registry import ServiceProviders
from .models import TokenPricingModel
# LLM pricing registry
LLM_PRICING: Dict[str, Dict[str, TokenPricingModel]] = {
ServiceProviders.OPENAI: {
"gpt-3.5-turbo": TokenPricingModel(
prompt_token_price=Decimal("0.0015") / 1000, # $0.0015 per 1K tokens
completion_token_price=Decimal("0.002") / 1000, # $0.002 per 1K tokens
),
"gpt-4": TokenPricingModel(
prompt_token_price=Decimal("0.03") / 1000, # $0.03 per 1K tokens
completion_token_price=Decimal("0.06") / 1000, # $0.06 per 1K tokens
),
"gpt-4.1": TokenPricingModel(
prompt_token_price=Decimal("2.00") / 1000000, # $2.00 per 1M tokens
completion_token_price=Decimal("8.00") / 1000000, # $8.00 per 1M tokens
),
"gpt-4.1-mini": TokenPricingModel(
prompt_token_price=Decimal("0.40") / 1000000, # $0.40 per 1M tokens
completion_token_price=Decimal("1.60") / 1000000, # $1.60 per 1M tokens
),
"gpt-4.1-nano": TokenPricingModel(
prompt_token_price=Decimal("0.10") / 1000000, # $0.10 per 1M tokens
completion_token_price=Decimal("0.40") / 1000000, # $0.40 per 1M tokens
),
"gpt-4.5-preview": TokenPricingModel(
prompt_token_price=Decimal("75.00") / 1000000, # $75.00 per 1M tokens
completion_token_price=Decimal("150.00") / 1000000, # $150.00 per 1M tokens
),
"gpt-4o": TokenPricingModel(
prompt_token_price=Decimal("2.50") / 1000000, # $2.50 per 1M tokens - FIXED
completion_token_price=Decimal("10.00")
/ 1000000, # $10.00 per 1M tokens - FIXED
),
"gpt-4o-audio-preview": TokenPricingModel(
prompt_token_price=Decimal("2.50") / 1000000, # $2.50 per 1M tokens
completion_token_price=Decimal("10.00") / 1000000, # $10.00 per 1M tokens
),
"gpt-4o-realtime-preview": TokenPricingModel(
prompt_token_price=Decimal("5.00") / 1000000, # $5.00 per 1M tokens
completion_token_price=Decimal("20.00") / 1000000, # $20.00 per 1M tokens
),
"gpt-4o-mini": TokenPricingModel(
prompt_token_price=Decimal("0.15") / 1000000, # $0.15 per 1M tokens
completion_token_price=Decimal("0.60") / 1000000, # $0.60 per 1M tokens
),
"gpt-4o-mini-audio-preview": TokenPricingModel(
prompt_token_price=Decimal("0.15") / 1000000, # $0.15 per 1M tokens
completion_token_price=Decimal("0.60") / 1000000, # $0.60 per 1M tokens
),
"gpt-4o-mini-realtime-preview": TokenPricingModel(
prompt_token_price=Decimal("0.60") / 1000000, # $0.60 per 1M tokens
completion_token_price=Decimal("2.40") / 1000000, # $2.40 per 1M tokens
),
"gpt-4o-search-preview": TokenPricingModel(
prompt_token_price=Decimal("2.50") / 1000000, # $2.50 per 1M tokens
completion_token_price=Decimal("10.00") / 1000000, # $10.00 per 1M tokens
),
"gpt-4o-mini-search-preview": TokenPricingModel(
prompt_token_price=Decimal("0.15") / 1000000, # $0.15 per 1M tokens
completion_token_price=Decimal("0.60") / 1000000, # $0.60 per 1M tokens
),
"o1": TokenPricingModel(
prompt_token_price=Decimal("15.00") / 1000000, # $15.00 per 1M tokens
completion_token_price=Decimal("60.00") / 1000000, # $60.00 per 1M tokens
),
"o1-pro": TokenPricingModel(
prompt_token_price=Decimal("150.00") / 1000000, # $150.00 per 1M tokens
completion_token_price=Decimal("600.00") / 1000000, # $600.00 per 1M tokens
),
"o1-mini": TokenPricingModel(
prompt_token_price=Decimal("1.10") / 1000000, # $1.10 per 1M tokens
completion_token_price=Decimal("4.40") / 1000000, # $4.40 per 1M tokens
),
"o3": TokenPricingModel(
prompt_token_price=Decimal("10.00") / 1000000, # $10.00 per 1M tokens
completion_token_price=Decimal("40.00") / 1000000, # $40.00 per 1M tokens
),
"o3-mini": TokenPricingModel(
prompt_token_price=Decimal("1.10") / 1000000, # $1.10 per 1M tokens
completion_token_price=Decimal("4.40") / 1000000, # $4.40 per 1M tokens
),
"o4-mini": TokenPricingModel(
prompt_token_price=Decimal("1.10") / 1000000, # $1.10 per 1M tokens
completion_token_price=Decimal("4.40") / 1000000, # $4.40 per 1M tokens
),
"computer-use-preview": TokenPricingModel(
prompt_token_price=Decimal("3.00") / 1000000, # $3.00 per 1M tokens
completion_token_price=Decimal("12.00") / 1000000, # $12.00 per 1M tokens
),
"gpt-image-1": TokenPricingModel(
prompt_token_price=Decimal("5.00") / 1000000, # $5.00 per 1M tokens
completion_token_price=Decimal("0") / 1000000, # No output pricing shown
),
"codex-mini-latest": TokenPricingModel(
prompt_token_price=Decimal("1.50") / 1000000, # $1.50 per 1M tokens
completion_token_price=Decimal("6.00") / 1000000, # $6.00 per 1M tokens
),
# Transcription models
"gpt-4o-transcribe": TokenPricingModel(
prompt_token_price=Decimal("2.50") / 1000000, # $2.50 per 1M tokens
completion_token_price=Decimal("10.00") / 1000000, # $10.00 per 1M tokens
),
"gpt-4o-mini-transcribe": TokenPricingModel(
prompt_token_price=Decimal("1.25") / 1000000, # $1.25 per 1M tokens
completion_token_price=Decimal("5.00") / 1000000, # $5.00 per 1M tokens
),
# TTS models with token-based pricing
"gpt-4o-mini-tts": TokenPricingModel(
prompt_token_price=Decimal("0.60") / 1000000, # $0.60 per 1M tokens
completion_token_price=Decimal("0")
/ 1000000, # No completion tokens for TTS
),
},
ServiceProviders.GROQ: {
"llama-3.3-70b-versatile": TokenPricingModel(
prompt_token_price=Decimal("0.00059") / 1000, # $0.00059 per 1K tokens
completion_token_price=Decimal("0.00079") / 1000, # $0.00079 per 1K tokens
),
"deepseek-r1-distill-llama-70b": TokenPricingModel(
prompt_token_price=Decimal("0.00059") / 1000, # Assuming similar pricing
completion_token_price=Decimal("0.00079") / 1000,
),
},
ServiceProviders.AZURE: {
"gpt-4.1-mini": TokenPricingModel(
prompt_token_price=Decimal("0.44") / 1000000, # $0.40 per 1M tokens
completion_token_price=Decimal("8.80")
/ 1000000, # $1.60 per 1M tokens if using data zone
)
},
}

View file

@ -1,89 +0,0 @@
"""
Base pricing models for different service types.
"""
from decimal import Decimal
from enum import Enum
from typing import Any, Dict
class CostType(Enum):
LLM_TOKENS = "llm_tokens"
TTS_CHARACTERS = "tts_characters"
STT_SECONDS = "stt_seconds"
class PricingModel:
"""Base class for pricing models"""
def calculate_cost(self, usage: Any) -> Decimal:
"""Calculate cost based on usage"""
raise NotImplementedError
class TokenPricingModel(PricingModel):
"""Pricing model for token-based services (LLM)"""
def __init__(
self,
prompt_token_price: Decimal,
completion_token_price: Decimal,
cache_read_discount: Decimal = Decimal("0.5"), # 50% discount for cache reads
cache_creation_multiplier: Decimal = Decimal(
"1.25"
), # 25% premium for cache creation
):
self.prompt_token_price = prompt_token_price
self.completion_token_price = completion_token_price
self.cache_read_discount = cache_read_discount
self.cache_creation_multiplier = cache_creation_multiplier
def calculate_cost(self, usage: Dict[str, int]) -> Decimal:
"""Calculate cost for LLM token usage"""
prompt_tokens = usage.get("prompt_tokens", 0)
completion_tokens = usage.get("completion_tokens", 0)
cache_read_tokens = usage.get("cache_read_input_tokens") or 0
cache_creation_tokens = usage.get("cache_creation_input_tokens") or 0
# Base cost
prompt_cost = Decimal(prompt_tokens) * self.prompt_token_price
completion_cost = Decimal(completion_tokens) * self.completion_token_price
# Cache adjustments
cache_read_savings = (
Decimal(cache_read_tokens)
* self.prompt_token_price
* self.cache_read_discount
)
cache_creation_premium = (
Decimal(cache_creation_tokens)
* self.prompt_token_price
* (self.cache_creation_multiplier - 1)
)
total_cost = (
prompt_cost + completion_cost - cache_read_savings + cache_creation_premium
)
return max(total_cost, Decimal("0")) # Ensure non-negative
class CharacterPricingModel(PricingModel):
"""Pricing model for character-based services (TTS)"""
def __init__(self, character_price: Decimal):
self.character_price = character_price
def calculate_cost(self, character_count: int) -> Decimal:
"""Calculate cost for TTS character usage"""
return Decimal(character_count) * self.character_price
class TimePricingModel(PricingModel):
"""Pricing model for time-based services (STT)"""
def __init__(self, second_price: Decimal):
self.second_price = second_price
def calculate_cost(self, seconds: float) -> Decimal:
"""Calculate cost for STT time usage"""
return Decimal(str(seconds)) * self.second_price

View file

@ -1,18 +0,0 @@
"""
Main pricing registry that combines all service type pricing models.
"""
from typing import Dict
from .embeddings import EMBEDDINGS_PRICING
from .llm import LLM_PRICING
from .stt import STT_PRICING
from .tts import TTS_PRICING
# Combined pricing registry for all service types
PRICING_REGISTRY: Dict = {
"llm": LLM_PRICING,
"tts": TTS_PRICING,
"stt": STT_PRICING,
"embeddings": EMBEDDINGS_PRICING,
}

View file

@ -1,13 +0,0 @@
"""Format workflow run usage for public API responses."""
def format_public_usage_info(usage_info: dict | None) -> dict | None:
if not usage_info:
return None
return {
"llm": usage_info.get("llm") or {},
"tts": usage_info.get("tts") or {},
"stt": usage_info.get("stt") or {},
"call_duration_seconds": usage_info.get("call_duration_seconds"),
}

View file

@ -1,26 +0,0 @@
"""
STT (Speech-to-Text) pricing models for different providers.
Prices are per second for STT services.
"""
from decimal import Decimal
from typing import Dict
from api.services.configuration.registry import ServiceProviders
from .models import TimePricingModel
# STT pricing registry
STT_PRICING: Dict[str, Dict[str, TimePricingModel]] = {
ServiceProviders.DEEPGRAM: {
"nova-3-general": TimePricingModel(Decimal("0.0077") / 60),
"nova-2": TimePricingModel(Decimal("0.0058") / 60),
"default": TimePricingModel(Decimal("0.0077") / 60),
},
ServiceProviders.OPENAI: {
"gpt-4o-transcribe": TimePricingModel(Decimal("0.015") / 60),
"default": TimePricingModel(Decimal("0.015") / 60),
},
"default": {"default": TimePricingModel(Decimal("0.0077") / 60)},
}

View file

@ -1,30 +0,0 @@
"""
TTS (Text-to-Speech) pricing models for different providers.
Prices are per character for TTS services.
"""
from decimal import Decimal
from typing import Dict
from api.services.configuration.registry import ServiceProviders
from .models import CharacterPricingModel
# TTS pricing registry
TTS_PRICING: Dict[str, Dict[str, CharacterPricingModel]] = {
ServiceProviders.OPENAI: {
"gpt-4o-mini-tts": CharacterPricingModel(Decimal("0.6") / 1_00_00_000),
"default": CharacterPricingModel(Decimal("0.6") / 1_00_00_000),
},
ServiceProviders.DEEPGRAM: {
"aura-2": CharacterPricingModel(Decimal("0.030") / 1_000),
"aura-1": CharacterPricingModel(Decimal("0.015") / 1_000),
"default": CharacterPricingModel(Decimal("0.030") / 1_000),
},
ServiceProviders.ELEVENLABS: {
# 6400 usd per 250*1e6 characters
"default": CharacterPricingModel(Decimal("0.0256") / 1_000)
},
"default": {"default": CharacterPricingModel(Decimal("0.030") / 1_000)},
}

View file

@ -1,230 +0,0 @@
from decimal import Decimal
from loguru import logger
from api.db import db_client
from api.enums import WorkflowRunMode
from api.services.pricing.cost_calculator import cost_calculator
from api.services.telephony.factory import get_telephony_provider_for_run
async def _fetch_telephony_cost(workflow_run) -> dict | None:
"""Fetch telephony call cost. Returns a dict with cost_usd and provider_name, or None."""
if (
workflow_run.mode
not in [WorkflowRunMode.TWILIO.value, WorkflowRunMode.VONAGE.value]
or not workflow_run.cost_info
):
return None
call_id = workflow_run.cost_info.get("call_id")
if not call_id:
logger.warning(f"call_id not found in cost_info")
return None
provider_name = workflow_run.mode.lower() if workflow_run.mode else ""
workflow = await db_client.get_workflow_by_id(workflow_run.workflow_id)
if not workflow:
logger.warning("Workflow not found for workflow run")
raise Exception("Workflow not found")
provider = await get_telephony_provider_for_run(
workflow_run, workflow.organization_id
)
call_cost_info = await provider.get_call_cost(call_id)
if call_cost_info.get("status") == "error":
logger.error(
f"Failed to fetch {provider_name} call cost: {call_cost_info.get('error')}"
)
return None
cost_usd = call_cost_info.get("cost_usd", 0.0)
logger.info(
f"{provider_name.title()} call cost: ${cost_usd:.6f} USD for call {call_id}"
)
return {"cost_usd": cost_usd, "provider_name": provider_name}
async def _update_organization_usage(
org, dograh_tokens: float, duration_seconds: float, charge_usd: float | None
) -> None:
"""Update organization usage after a workflow run."""
org_id = org.id
await db_client.update_usage_after_run(
org_id, dograh_tokens, duration_seconds, charge_usd
)
if charge_usd is not None:
logger.info(
f"Updated organization usage with ${charge_usd:.2f} USD ({dograh_tokens} Dograh Tokens) and {duration_seconds}s duration for org {org_id}"
)
else:
logger.info(
f"Updated organization usage with {dograh_tokens} Dograh Tokens and {duration_seconds}s duration for org {org_id}"
)
async def _get_pricing_organization(workflow_run):
workflow = getattr(workflow_run, "workflow", None)
organization_id = getattr(workflow, "organization_id", None)
if organization_id is None and workflow and workflow.user:
organization_id = workflow.user.selected_organization_id
if organization_id is None:
return None
return await db_client.get_organization_by_id(organization_id)
async def _build_usage_cost_snapshot(
usage_info: dict | None,
*,
workflow_run=None,
include_telephony_cost: bool = False,
organization=None,
calculated_at: str | None = None,
) -> dict | None:
if not usage_info:
logger.warning("No usage info available for workflow run")
return None
cost_breakdown = cost_calculator.calculate_total_cost(usage_info)
if include_telephony_cost and workflow_run is not None:
try:
telephony_cost = await _fetch_telephony_cost(workflow_run)
if telephony_cost:
telephony_cost_usd = telephony_cost["cost_usd"]
provider_name = telephony_cost["provider_name"]
cost_breakdown["telephony_call"] = telephony_cost_usd
cost_breakdown[f"{provider_name}_call"] = telephony_cost_usd
cost_breakdown["total"] = (
float(cost_breakdown["total"]) + telephony_cost_usd
)
except Exception as e:
logger.error(f"Failed to fetch telephony call cost: {e}")
# Don't fail the whole cost calculation if telephony API fails
total_cost_usd = Decimal(str(cost_breakdown["total"]))
dograh_tokens = float(total_cost_usd * Decimal("100"))
if organization is None and workflow_run is not None:
organization = await _get_pricing_organization(workflow_run)
charge_usd = None
if organization and organization.price_per_second_usd:
duration_seconds = usage_info.get("call_duration_seconds", 0)
charge_usd = float(
Decimal(str(duration_seconds))
* Decimal(str(organization.price_per_second_usd))
)
cost_info = {
"cost_breakdown": cost_breakdown,
"total_cost_usd": float(total_cost_usd),
"dograh_token_usage": dograh_tokens,
"calculated_at": calculated_at
or (workflow_run.created_at.isoformat() if workflow_run is not None else None),
"call_duration_seconds": usage_info.get("call_duration_seconds", 0),
}
if charge_usd is not None:
cost_info["charge_usd"] = charge_usd
cost_info["price_per_second_usd"] = organization.price_per_second_usd
return cost_info
async def build_workflow_run_cost_info(workflow_run) -> dict | None:
cost_info = await _build_usage_cost_snapshot(
workflow_run.usage_info,
workflow_run=workflow_run,
include_telephony_cost=True,
calculated_at=workflow_run.created_at.isoformat(),
)
if cost_info is None:
return None
return {
**(workflow_run.cost_info or {}),
**cost_info,
}
async def save_workflow_run_cost_info(
workflow_run_id: int, cost_info: dict | None
) -> None:
if cost_info is None:
return
await db_client.update_workflow_run(run_id=workflow_run_id, cost_info=cost_info)
async def apply_workflow_run_usage_to_organization(
workflow_run, cost_info: dict | None
) -> None:
if cost_info is None:
return
org = await _get_pricing_organization(workflow_run)
if not org:
return
await _update_organization_usage(
org,
float(cost_info.get("dograh_token_usage") or 0),
float(cost_info.get("call_duration_seconds") or 0),
cost_info.get("charge_usd"),
)
async def apply_usage_delta_to_organization(
workflow_run, usage_info: dict | None
) -> dict | None:
org = await _get_pricing_organization(workflow_run)
if not org:
return None
cost_info = await _build_usage_cost_snapshot(usage_info, organization=org)
if cost_info is None:
return None
await _update_organization_usage(
org,
float(cost_info.get("dograh_token_usage") or 0),
float(cost_info.get("call_duration_seconds") or 0),
cost_info.get("charge_usd"),
)
return cost_info
async def calculate_workflow_run_cost(workflow_run_id: int):
logger.debug("Calculating cost for workflow run")
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
if not workflow_run:
logger.warning("Workflow run not found")
return
try:
cost_info = await build_workflow_run_cost_info(workflow_run)
if cost_info is None:
return
await save_workflow_run_cost_info(workflow_run_id, cost_info)
try:
await apply_workflow_run_usage_to_organization(workflow_run, cost_info)
except Exception as e:
org = await _get_pricing_organization(workflow_run)
if org:
logger.error(
f"Failed to update organization usage for org {org.id}: {e}"
)
else:
logger.error(f"Failed to update organization usage: {e}")
# Don't fail the whole cost calculation if usage update fails
logger.info(
f"Calculated cost for workflow run: ${cost_info['total_cost_usd']:.6f} USD ({cost_info['dograh_token_usage']} Dograh Tokens)"
)
except Exception as e:
logger.error(f"Error calculating cost for workflow run: {e}")
raise

View file

@ -5,15 +5,38 @@ across different endpoints (WebRTC signaling, telephony, public API triggers).
"""
from dataclasses import dataclass
from typing import Any
from loguru import logger
from api.constants import DEPLOYMENT_MODE
from api.db import db_client
from api.db.models import UserModel
from api.services.configuration.ai_model_configuration import (
get_effective_ai_model_configuration_for_workflow,
)
from api.services.configuration.registry import ServiceProviders
from api.services.configuration.resolve import resolve_effective_config
from api.services.managed_model_services import (
MPS_CORRELATION_ID_CONTEXT_KEY,
get_dograh_service_api_key,
uses_managed_model_services_v2,
)
from api.services.mps_service_key_client import mps_service_key_client
MINIMUM_DOGRAH_CREDITS_FOR_CALL = 0.10
LEGACY_QUOTA_EXCEEDED_MESSAGE = (
"You have exhausted your trial credits. "
"Please email founders@dograh.com for additional Dograh credits "
"or change providers in Models configurations."
)
BILLING_V2_QUOTA_EXCEEDED_MESSAGE = (
"You have exhausted your Dograh credits. "
"Please purchase more credits from /billing "
"or change providers in Models configurations."
)
@dataclass
class QuotaCheckResult:
@ -24,104 +47,359 @@ class QuotaCheckResult:
error_code: str = ""
async def check_dograh_quota(
user: UserModel, workflow_id: int | None = None
) -> QuotaCheckResult:
"""Check if user has sufficient Dograh quota for making a call.
This function checks if the user is using any Dograh services (LLM, STT, TTS)
and validates that they have sufficient credits remaining.
When ``workflow_id`` is provided, the workflow's per-workflow
``model_overrides`` are merged onto the user's global config so the quota
check runs against the credentials that will actually be used for the call
(rather than always falling back to the user's defaults).
Args:
user: The user to check quota for
workflow_id: Optional workflow whose ``model_overrides`` should be
applied when resolving the effective service config.
Returns:
QuotaCheckResult with has_quota=True if user has sufficient quota or
is not using Dograh services, or has_quota=False with error_message
if quota is insufficient.
"""
def _safe_float(value: Any, default: float = 0.0) -> float:
try:
# Get user configurations
user_config = await db_client.get_user_configurations(user.id)
return float(value)
except (TypeError, ValueError):
return default
if workflow_id is not None:
workflow = await db_client.get_workflow_by_id(workflow_id)
if workflow:
model_overrides = (workflow.workflow_configurations or {}).get(
"model_overrides"
def _insufficient_billing_v2_quota_result() -> QuotaCheckResult:
return QuotaCheckResult(
has_quota=False,
error_code="insufficient_credits",
error_message=BILLING_V2_QUOTA_EXCEEDED_MESSAGE,
)
def _insufficient_legacy_quota_result() -> QuotaCheckResult:
return QuotaCheckResult(
has_quota=False,
error_code="quota_exceeded",
error_message=LEGACY_QUOTA_EXCEEDED_MESSAGE,
)
def _service_uses_dograh(service: Any) -> bool:
provider = getattr(service, "provider", None)
return (
provider == ServiceProviders.DOGRAH or provider == ServiceProviders.DOGRAH.value
)
def _dograh_api_keys(user_config: Any) -> set[str]:
api_keys: set[str] = set()
for section_name in ("llm", "stt", "tts", "embeddings"):
service = getattr(user_config, section_name, None)
if not _service_uses_dograh(service):
continue
if hasattr(service, "get_all_api_keys"):
all_api_keys = [
api_key
for api_key in service.get_all_api_keys()
if isinstance(api_key, str) and api_key
]
if all_api_keys:
api_keys.update(all_api_keys)
continue
api_key = getattr(service, "api_key", None)
if api_key:
api_keys.add(api_key)
return api_keys
async def _store_run_correlation_id(
workflow_run_id: int | None,
correlation_id: str | None,
) -> None:
if not workflow_run_id or not correlation_id:
return
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
if not workflow_run:
logger.warning(
"Could not store MPS correlation id for missing workflow run {}",
workflow_run_id,
)
return
initial_context = dict(workflow_run.initial_context or {})
if initial_context.get(MPS_CORRELATION_ID_CONTEXT_KEY) == correlation_id:
return
initial_context[MPS_CORRELATION_ID_CONTEXT_KEY] = correlation_id
await db_client.update_workflow_run(
workflow_run_id,
initial_context=initial_context,
)
async def _authorize_hosted_workflow_run_start(
*,
workflow_owner: UserModel,
organization_id: int | None,
workflow_id: int | None,
workflow_run_id: int | None,
user_config: Any,
) -> tuple[QuotaCheckResult, bool]:
"""Authorize hosted v2 billing and return whether MPS handled enforcement."""
if DEPLOYMENT_MODE == "oss" or organization_id is None:
return QuotaCheckResult(has_quota=True), False
requires_correlation = bool(
workflow_run_id and uses_managed_model_services_v2(user_config)
)
service_key = (
get_dograh_service_api_key(user_config) if requires_correlation else None
)
if requires_correlation and not service_key:
return (
QuotaCheckResult(
has_quota=False,
error_code="invalid_service_key",
error_message=(
"You have invalid keys in your model configuration. "
"Please validate the service keys."
),
),
True,
)
try:
authorization = await mps_service_key_client.authorize_workflow_run_start(
organization_id=organization_id,
workflow_run_id=workflow_run_id,
service_key=service_key,
require_correlation_id=requires_correlation,
minimum_credits=MINIMUM_DOGRAH_CREDITS_FOR_CALL,
created_by=(
str(workflow_owner.provider_id)
if workflow_owner.provider_id is not None
else None
),
metadata={
"dograh_user_id": str(workflow_owner.id),
"workflow_id": workflow_id,
},
)
except Exception as e:
logger.error(
"Failed to authorize workflow start with MPS for org {}: {}",
organization_id,
e,
)
return (
QuotaCheckResult(
has_quota=False,
error_code="quota_check_failed",
error_message="Could not verify Dograh credits. Please try again.",
),
True,
)
billing_mode = authorization.get("billing_mode")
if billing_mode != "v2":
return QuotaCheckResult(has_quota=True), False
remaining = _safe_float(authorization.get("remaining_credits"))
if (
not authorization.get("allowed", False)
or remaining < MINIMUM_DOGRAH_CREDITS_FOR_CALL
):
logger.warning(
"Insufficient Dograh billing v2 credits for org {}: {:.2f} credits remaining",
organization_id,
remaining,
)
return _insufficient_billing_v2_quota_result(), True
try:
await _store_run_correlation_id(
workflow_run_id,
authorization.get("correlation_id"),
)
except Exception as e:
logger.error(
"Failed to store MPS correlation id for workflow_run_id {}: {}",
workflow_run_id,
e,
)
return (
QuotaCheckResult(
has_quota=False,
error_code="quota_check_failed",
error_message="Could not verify Dograh credits. Please try again.",
),
True,
)
logger.info(
"Dograh billing v2 run authorization passed for org {}: {:.2f} credits remaining",
organization_id,
remaining,
)
return QuotaCheckResult(has_quota=True), True
async def _authorize_legacy_dograh_keys(
*,
dograh_api_keys: set[str],
organization_id: int | None,
workflow_owner: UserModel,
) -> QuotaCheckResult:
for api_key in dograh_api_keys:
try:
usage = await mps_service_key_client.check_service_key_usage(
api_key,
organization_id=organization_id,
created_by=workflow_owner.provider_id,
)
remaining = usage.get("remaining_credits", 0.0)
# Require at least $0.10 for a short call
if remaining < MINIMUM_DOGRAH_CREDITS_FOR_CALL:
logger.warning(
f"Insufficient Dograh credits for key ...{api_key[-8:]}: "
f"${remaining:.2f} remaining"
)
if model_overrides:
user_config = resolve_effective_config(user_config, model_overrides)
return _insufficient_legacy_quota_result()
# Check if user is using any Dograh service
using_dograh = False
dograh_api_keys = set()
if user_config.llm and user_config.llm.provider == ServiceProviders.DOGRAH:
using_dograh = True
dograh_api_keys.add(user_config.llm.api_key)
if user_config.stt and user_config.stt.provider == ServiceProviders.DOGRAH:
using_dograh = True
dograh_api_keys.add(user_config.stt.api_key)
if user_config.tts and user_config.tts.provider == ServiceProviders.DOGRAH:
using_dograh = True
dograh_api_keys.add(user_config.tts.api_key)
# If not using Dograh, quota check passes
if not using_dograh:
return QuotaCheckResult(has_quota=True)
# Check quota for ALL Dograh keys
for api_key in dograh_api_keys:
try:
usage = await mps_service_key_client.check_service_key_usage(
api_key, created_by=user.provider_id
)
remaining = usage.get("remaining_credits", 0.0)
# Require at least $0.10 for a short call
if remaining < 0.10:
logger.warning(
f"Insufficient Dograh credits for key ...{api_key[-8:]}: "
f"${remaining:.2f} remaining"
)
return QuotaCheckResult(
has_quota=False,
error_code="quota_exceeded",
error_message=(
"You have exhausted your trial credits. "
"Please email founders@dograh.com for additional Dograh credits "
"or change providers in Models configurations."
),
)
logger.info(
f"Dograh quota check passed for key ...{api_key[-8:]}: "
f"{remaining:.2f} credits remaining"
)
except Exception as e:
logger.error(f"Failed to check quota for Dograh key: {str(e)}")
error_str = str(e)
if "404" in error_str or "not found" in error_str.lower():
return QuotaCheckResult(
has_quota=False,
error_code="invalid_service_key",
error_message="You have invalid keys in your model configuration. Please validate the service keys.",
)
logger.info(
f"Dograh quota check passed for key ...{api_key[-8:]}: "
f"{remaining:.2f} credits remaining"
)
except Exception as e:
logger.error(f"Failed to check quota for Dograh key: {str(e)}")
error_str = str(e)
if "404" in error_str or "not found" in error_str.lower():
return QuotaCheckResult(
has_quota=False,
error_code="quota_check_failed",
error_message="Could not verify Dograh credits. Please try again.",
error_code="invalid_service_key",
error_message="You have invalid keys in your model configuration. Please validate the service keys.",
)
return QuotaCheckResult(
has_quota=False,
error_code="quota_check_failed",
error_message="Could not verify Dograh credits. Please try again.",
)
return QuotaCheckResult(has_quota=True)
async def _authorize_oss_managed_v2_correlation(
*,
workflow_id: int,
workflow_run_id: int | None,
user_config: Any,
) -> QuotaCheckResult:
if not workflow_run_id or not uses_managed_model_services_v2(user_config):
return QuotaCheckResult(has_quota=True)
service_key = get_dograh_service_api_key(user_config)
if not service_key:
return QuotaCheckResult(
has_quota=False,
error_code="invalid_service_key",
error_message=(
"You have invalid keys in your model configuration. "
"Please validate the service keys."
),
)
try:
response = await mps_service_key_client.create_correlation_id(
service_key=service_key,
workflow_run_id=workflow_run_id,
)
await _store_run_correlation_id(
workflow_run_id,
response.get("correlation_id"),
)
except Exception as e:
logger.error(
"Failed to authorize OSS managed v2 workflow start for workflow {} run {}: {}",
workflow_id,
workflow_run_id,
e,
)
return QuotaCheckResult(
has_quota=False,
error_code="quota_check_failed",
error_message="Could not verify Dograh credits. Please try again.",
)
return QuotaCheckResult(has_quota=True)
async def authorize_workflow_run_start(
*,
workflow_id: int,
workflow_run_id: int | None = None,
actor_user: UserModel | None = None,
) -> QuotaCheckResult:
"""Authorize a workflow run before any billable call/text runtime starts.
The workflow organization is the billing subject for hosted v2. The workflow
owner is used only to resolve the effective model configuration and legacy
service-key metadata.
"""
try:
workflow = await db_client.get_workflow_by_id(workflow_id)
if not workflow:
return QuotaCheckResult(
has_quota=False,
error_code="workflow_not_found",
error_message="Workflow not found",
)
actor_org_id = getattr(actor_user, "selected_organization_id", None)
if actor_org_id is not None and actor_org_id != workflow.organization_id:
logger.warning(
"Workflow start authorization denied: actor org {} does not match workflow {} org {}",
actor_org_id,
workflow_id,
workflow.organization_id,
)
return QuotaCheckResult(
has_quota=False,
error_code="workflow_not_found",
error_message="Workflow not found",
)
workflow_owner = await db_client.get_user_by_id(workflow.user_id)
if not workflow_owner:
return QuotaCheckResult(
has_quota=False,
error_code="user_not_found",
error_message="User not found",
)
user_config = await get_effective_ai_model_configuration_for_workflow(
user_id=workflow_owner.id,
organization_id=workflow.organization_id,
workflow_configurations=workflow.workflow_configurations,
)
if DEPLOYMENT_MODE != "oss":
hosted_result, hosted_enforced = await _authorize_hosted_workflow_run_start(
workflow_owner=workflow_owner,
organization_id=workflow.organization_id,
workflow_id=workflow.id,
workflow_run_id=workflow_run_id,
user_config=user_config,
)
if hosted_enforced or not hosted_result.has_quota:
return hosted_result
dograh_api_keys = _dograh_api_keys(user_config)
if not dograh_api_keys:
return QuotaCheckResult(has_quota=True)
legacy_result = await _authorize_legacy_dograh_keys(
dograh_api_keys=dograh_api_keys,
organization_id=(
None if DEPLOYMENT_MODE == "oss" else workflow.organization_id
),
workflow_owner=workflow_owner,
)
if not legacy_result.has_quota:
return legacy_result
if DEPLOYMENT_MODE == "oss":
return await _authorize_oss_managed_v2_correlation(
workflow_id=workflow.id,
workflow_run_id=workflow_run_id,
user_config=user_config,
)
return QuotaCheckResult(has_quota=True)
@ -129,30 +407,3 @@ async def check_dograh_quota(
logger.error(f"Error during quota check: {str(e)}")
# On unexpected error, allow the call to proceed
return QuotaCheckResult(has_quota=True)
async def check_dograh_quota_by_user_id(
user_id: int, workflow_id: int | None = None
) -> QuotaCheckResult:
"""Check Dograh quota by user ID.
Convenience function that fetches the user and then checks quota. When
``workflow_id`` is provided, the workflow's ``model_overrides`` are
applied so the quota check evaluates the credentials that will actually
be used for the call.
Args:
user_id: The ID of the user to check quota for
workflow_id: Optional workflow whose per-workflow overrides should
be applied to the user's config before checking quota.
Returns:
QuotaCheckResult with quota status
"""
user = await db_client.get_user_by_id(user_id)
if not user:
return QuotaCheckResult(
has_quota=False,
error_message="User not found",
)
return await check_dograh_quota(user, workflow_id=workflow_id)

View file

@ -53,7 +53,7 @@ def build_run_report_csv(runs: List[Any]) -> io.StringIO:
for run in runs:
initial = run.initial_context or {}
gathered = run.gathered_context or {}
cost = run.cost_info or {}
usage = run.usage_info or {}
call_tags = gathered.get("call_tags", [])
if isinstance(call_tags, list):
@ -67,7 +67,7 @@ def build_run_report_csv(runs: List[Any]) -> io.StringIO:
run.created_at.isoformat() if run.created_at else "",
initial.get("phone_number", ""),
gathered.get("mapped_call_disposition", ""),
cost.get("call_duration_seconds", ""),
usage.get("call_duration_seconds", ""),
]
extracted = gathered.get("extracted_variables", {})

View file

@ -26,7 +26,7 @@ from loguru import logger
from api.constants import REDIS_URL
from api.db import db_client
from api.enums import CallType, WorkflowRunMode
from api.services.quota_service import check_dograh_quota_by_user_id
from api.services.quota_service import authorize_workflow_run_start
from api.services.telephony.call_transfer_manager import get_call_transfer_manager
from api.services.telephony.transfer_event_protocol import (
TransferEvent,
@ -564,19 +564,7 @@ class ARIConnection:
user_id = workflow.user_id
# 3. Check quota (apply per-workflow model_overrides).
quota_result = await check_dograh_quota_by_user_id(
user_id, workflow_id=inbound_workflow_id
)
if not quota_result.has_quota:
logger.warning(
f"[ARI org={self.organization_id}] Quota exceeded for user {user_id} "
f"— hanging up inbound call {channel_id}"
)
await self._delete_channel(channel_id)
return
# 4. Create workflow run
# 3. Create workflow run
call_id = channel_id
workflow_run = await db_client.create_workflow_run(
name=f"ARI Inbound {caller_number}",
@ -602,6 +590,20 @@ class ARIConnection:
f"(caller={caller_number}, called={called_number})"
)
# 4. Check quota after the run exists so hosted v2 can mint and
# store the MPS correlation id before the pipeline starts.
quota_result = await authorize_workflow_run_start(
workflow_id=inbound_workflow_id,
workflow_run_id=workflow_run.id,
)
if not quota_result.has_quota:
logger.warning(
f"[ARI org={self.organization_id}] Quota exceeded for user {user_id} "
f"— hanging up inbound call {channel_id}"
)
await self._delete_channel(channel_id)
return
# 5. Answer the inbound channel
await self._answer_channel(channel_id)

View file

@ -103,7 +103,8 @@ async def handle_cloudonix_cdr(request: Request):
return {"status": "error", "message": "Missing domain field"}
# Extract call_id to find workflow run
call_id = cdr_data.get("session").get("token")
session = cdr_data.get("session")
call_id = session.get("token") if isinstance(session, dict) else None
logger.info(f"Cloudonix CDR data for call id {call_id} - {cdr_data}")
if not call_id:
logger.warning("Cloudonix CDR missing call_id field")

View file

@ -6,9 +6,8 @@ provider registry — see ProviderSpec.router.
import json
from datetime import UTC, datetime
from typing import Optional
from fastapi import APIRouter, Header, Request
from fastapi import APIRouter, HTTPException, Request
from loguru import logger
from pipecat.utils.run_context import set_current_run_id
from starlette.responses import HTMLResponse
@ -29,6 +28,30 @@ from api.utils.telephony_helper import (
router = APIRouter()
async def _verify_vobiz_callback(
provider,
webhook_url: str,
callback_data: dict,
headers: dict,
raw_body: str,
*,
log_prefix: str,
) -> None:
"""Verify a Vobiz callback signature, failing closed.
Vobiz signs every callback, so a missing signature header is an invalid
request ``provider.verify_inbound_signature`` returns ``False`` for both
missing and forged signatures. Reject with HTTP 403 (per Vobiz's
callback-validation docs) so the caller never reaches status processing.
"""
is_valid = await provider.verify_inbound_signature(
webhook_url, callback_data, headers, raw_body
)
if not is_valid:
logger.warning(f"{log_prefix} Invalid or missing Vobiz callback signature")
raise HTTPException(status_code=403, detail="Invalid webhook signature")
@router.post("/vobiz-xml", include_in_schema=False)
async def handle_vobiz_xml_webhook(
workflow_id: int, user_id: int, workflow_run_id: int, organization_id: int
@ -65,8 +88,6 @@ async def handle_vobiz_xml_webhook(
async def handle_vobiz_hangup_callback(
workflow_run_id: int,
request: Request,
x_vobiz_signature: Optional[str] = Header(None),
x_vobiz_timestamp: Optional[str] = Header(None),
):
"""Handle Vobiz hangup callback (sent when call ends).
@ -75,82 +96,23 @@ async def handle_vobiz_hangup_callback(
"""
set_current_run_id(workflow_run_id)
# Logging all headers and body to understand what Vobiz actually sends
all_headers = dict(request.headers)
logger.info(
f"[run {workflow_run_id}] Vobiz hangup callback - Headers: {json.dumps(all_headers)}"
)
# Parse the callback data from the raw body so signed webhooks can verify
# the exact bytes Vobiz sent without draining the request stream first.
callback_data, raw_body = await parse_webhook_request(request)
# TODO: Remove this debug logging after Vobiz team clarifies webhook authentication
logger.info(
f"[run {workflow_run_id}] Vobiz hangup callback - Body: {json.dumps(callback_data)}"
)
logger.info(
f"[run {workflow_run_id}] Received Vobiz hangup callback {json.dumps(callback_data)}"
)
# Verify signature if Vobiz provided any supported signature header.
has_vobiz_signature = any(
header in all_headers
for header in (
"x-vobiz-signature-v3",
"x-vobiz-signature-ma-v3",
"x-vobiz-signature-v2",
"x-vobiz-signature-ma-v2",
)
)
if has_vobiz_signature:
# We need the workflow run to get organization for provider credentials
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
if not workflow_run:
logger.warning(
f"[run {workflow_run_id}] Workflow run not found for signature verification"
)
return {"status": "error", "reason": "workflow_run_not_found"}
workflow = await db_client.get_workflow_by_id(workflow_run.workflow_id)
if not workflow:
logger.warning(
f"[run {workflow_run_id}] Workflow not found for signature verification"
)
return {"status": "error", "reason": "workflow_not_found"}
provider = await get_telephony_provider_for_run(
workflow_run, workflow.organization_id
)
# Verify signature
backend_endpoint, _ = await get_backend_endpoints()
webhook_url = f"{backend_endpoint}/api/v1/telephony/vobiz/hangup-callback/{workflow_run_id}"
is_valid = await provider.verify_inbound_signature(
webhook_url,
callback_data,
all_headers,
raw_body,
)
if not is_valid:
logger.warning(
f"[run {workflow_run_id}] Invalid Vobiz hangup callback signature"
)
return {"status": "error", "reason": "invalid_signature"}
logger.info(f"[run {workflow_run_id}] Vobiz hangup callback signature verified")
else:
# Get workflow run for processing (signature verification already got it if needed)
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
if not workflow_run:
logger.warning(
f"[run {workflow_run_id}] Workflow run not found for Vobiz hangup callback"
)
return {"status": "ignored", "reason": "workflow_run_not_found"}
# Get workflow and provider
workflow = await db_client.get_workflow_by_id(workflow_run.workflow_id)
if not workflow:
logger.warning(f"[run {workflow_run_id}] Workflow not found")
@ -160,6 +122,21 @@ async def handle_vobiz_hangup_callback(
workflow_run, workflow.organization_id
)
# Fail closed: Vobiz signs every callback, so reject unsigned/forged ones
# before they can mutate call state.
backend_endpoint, _ = await get_backend_endpoints()
webhook_url = (
f"{backend_endpoint}/api/v1/telephony/vobiz/hangup-callback/{workflow_run_id}"
)
await _verify_vobiz_callback(
provider,
webhook_url,
callback_data,
all_headers,
raw_body,
log_prefix=f"[run {workflow_run_id}]",
)
logger.debug(
f"[run {workflow_run_id}] Processing Vobiz hangup with provider: {provider.PROVIDER_NAME}"
)
@ -167,10 +144,6 @@ async def handle_vobiz_hangup_callback(
# Parse the callback data into generic format
parsed_data = provider.parse_status_callback(callback_data)
logger.debug(
f"[run {workflow_run_id}] Parsed Vobiz callback data: {json.dumps(parsed_data)}"
)
# Create StatusCallbackRequest from parsed data
status_update = StatusCallbackRequest(
call_id=parsed_data["call_id"],
@ -194,8 +167,6 @@ async def handle_vobiz_hangup_callback(
async def handle_vobiz_ring_callback(
workflow_run_id: int,
request: Request,
x_vobiz_signature: Optional[str] = Header(None),
x_vobiz_timestamp: Optional[str] = Header(None),
):
"""Handle Vobiz ring callback (sent when call starts ringing).
@ -204,84 +175,46 @@ async def handle_vobiz_ring_callback(
"""
set_current_run_id(workflow_run_id)
# Logging all headers and body to understand what Vobiz actually sends
all_headers = dict(request.headers)
logger.info(
f"[run {workflow_run_id}] Vobiz ring callback - Headers: {json.dumps(all_headers)}"
)
# Parse the callback data from the raw body so signed webhooks can verify
# the exact bytes Vobiz sent without draining the request stream first.
callback_data, raw_body = await parse_webhook_request(request)
# TODO: Remove this debug logging after Vobiz team clarifies webhook authentication
logger.info(
f"[run {workflow_run_id}] Vobiz ring callback - Body: {json.dumps(callback_data)}"
)
logger.info(
f"[run {workflow_run_id}] Received Vobiz ring callback {json.dumps(callback_data)}"
)
# Verify signature if Vobiz provided any supported signature header.
has_vobiz_signature = any(
header in all_headers
for header in (
"x-vobiz-signature-v3",
"x-vobiz-signature-ma-v3",
"x-vobiz-signature-v2",
"x-vobiz-signature-ma-v2",
)
)
if has_vobiz_signature:
# We need the workflow run to get organization for provider credentials
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
if not workflow_run:
logger.warning(
f"[run {workflow_run_id}] Workflow run not found for signature verification"
)
return {"status": "error", "reason": "workflow_run_not_found"}
workflow = await db_client.get_workflow_by_id(workflow_run.workflow_id)
if not workflow:
logger.warning(
f"[run {workflow_run_id}] Workflow not found for signature verification"
)
return {"status": "error", "reason": "workflow_not_found"}
provider = await get_telephony_provider_for_run(
workflow_run, workflow.organization_id
)
# Verify signature
backend_endpoint, _ = await get_backend_endpoints()
webhook_url = (
f"{backend_endpoint}/api/v1/telephony/vobiz/ring-callback/{workflow_run_id}"
)
is_valid = await provider.verify_inbound_signature(
webhook_url,
callback_data,
all_headers,
raw_body,
)
if not is_valid:
logger.warning(
f"[run {workflow_run_id}] Invalid Vobiz ring callback signature"
)
return {"status": "error", "reason": "invalid_signature"}
logger.info(f"[run {workflow_run_id}] Vobiz ring callback signature verified")
else:
# Get workflow run for processing (signature verification already got it if needed)
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
if not workflow_run:
logger.warning(
f"[run {workflow_run_id}] Workflow run not found for Vobiz ring callback"
)
return {"status": "ignored", "reason": "workflow_run_not_found"}
workflow = await db_client.get_workflow_by_id(workflow_run.workflow_id)
if not workflow:
logger.warning(f"[run {workflow_run_id}] Workflow not found")
return {"status": "ignored", "reason": "workflow_not_found"}
provider = await get_telephony_provider_for_run(
workflow_run, workflow.organization_id
)
# Fail closed: reject unsigned/forged ring callbacks before logging them.
backend_endpoint, _ = await get_backend_endpoints()
webhook_url = (
f"{backend_endpoint}/api/v1/telephony/vobiz/ring-callback/{workflow_run_id}"
)
await _verify_vobiz_callback(
provider,
webhook_url,
callback_data,
all_headers,
raw_body,
log_prefix=f"[run {workflow_run_id}]",
)
# Log the ringing event
telephony_callback_logs = workflow_run.logs.get("telephony_status_callbacks", [])
ring_log = {
@ -308,15 +241,10 @@ async def handle_vobiz_ring_callback(
async def handle_vobiz_hangup_callback_by_workflow(
workflow_id: int,
request: Request,
x_vobiz_signature: Optional[str] = Header(None),
x_vobiz_timestamp: Optional[str] = Header(None),
):
"""Handle Vobiz hangup callback with workflow_id - finds workflow run by call_id."""
all_headers = dict(request.headers)
logger.info(
f"[workflow {workflow_id}] Vobiz hangup callback - Headers: {json.dumps(all_headers)}"
)
try:
callback_data, raw_body = await parse_webhook_request(request)
@ -364,35 +292,18 @@ async def handle_vobiz_hangup_callback_by_workflow(
workflow_run, workflow.organization_id
)
has_vobiz_signature = any(
header in all_headers
for header in (
"x-vobiz-signature-v3",
"x-vobiz-signature-ma-v3",
"x-vobiz-signature-v2",
"x-vobiz-signature-ma-v2",
)
# Fail closed: Vobiz signs every callback, so reject unsigned/forged ones
# before they can mutate call state.
backend_endpoint, _ = await get_backend_endpoints()
webhook_url = f"{backend_endpoint}/api/v1/telephony/vobiz/hangup-callback/workflow/{workflow_id}"
await _verify_vobiz_callback(
provider,
webhook_url,
callback_data,
all_headers,
raw_body,
log_prefix=f"[workflow {workflow_id}]",
)
if has_vobiz_signature:
backend_endpoint, _ = await get_backend_endpoints()
webhook_url = f"{backend_endpoint}/api/v1/telephony/vobiz/hangup-callback/workflow/{workflow_id}"
is_valid = await provider.verify_inbound_signature(
webhook_url,
callback_data,
all_headers,
raw_body,
)
if not is_valid:
logger.warning(
f"[workflow {workflow_id}] Invalid Vobiz hangup callback signature"
)
return {"status": "error", "message": "invalid_signature"}
logger.info(
f"[workflow {workflow_id}] Vobiz hangup callback signature verified"
)
try:
parsed_data = provider.parse_status_callback(callback_data)

View file

@ -66,34 +66,6 @@ async def handle_vonage_events(
logger.error(f"[run {workflow_run_id}] Workflow run not found")
return {"status": "error", "message": "Workflow run not found"}
# For a completed call that includes cost info, capture it immediately
if event_data.get("status") == "completed":
# Vonage sometimes includes price info in the webhook
if "price" in event_data or "rate" in event_data:
try:
if workflow_run.cost_info:
# Store immediate cost info if available
cost_info = workflow_run.cost_info.copy()
if "price" in event_data:
cost_info["vonage_webhook_price"] = float(event_data["price"])
if "rate" in event_data:
cost_info["vonage_webhook_rate"] = float(event_data["rate"])
if "duration" in event_data:
cost_info["vonage_webhook_duration"] = int(
event_data["duration"]
)
await db_client.update_workflow_run(
run_id=workflow_run_id, cost_info=cost_info
)
logger.info(
f"[run {workflow_run_id}] Captured Vonage cost info from webhook"
)
except Exception as e:
logger.error(
f"[run {workflow_run_id}] Failed to capture Vonage cost from webhook: {e}"
)
# Get workflow and provider
workflow = await db_client.get_workflow_by_id(workflow_run.workflow_id)
if not workflow:

View file

@ -114,11 +114,13 @@ class StatusCallbackRequest(BaseModel):
"NOANSWER": "no-answer",
}
disposition = data.get("disposition", "")
disposition = data.get("disposition") or ""
status = disposition_map.get(disposition.upper(), disposition.lower())
session = data.get("session")
call_id = session.get("token") if isinstance(session, dict) else ""
return cls(
call_id=data.get("session").get("token"),
call_id=call_id or "",
status=status,
from_number=data.get("from"),
to_number=data.get("to"),

View file

@ -35,6 +35,7 @@ import asyncio
from loguru import logger
from api.services.managed_model_services import MPS_CORRELATION_ID_CONTEXT_KEY
from api.services.workflow import pipecat_engine_callbacks as engine_callbacks
from api.services.workflow.mcp_tool_session import McpToolSession
from api.services.workflow.pipecat_engine_context_composer import (
@ -382,6 +383,9 @@ class PipecatEngine:
embeddings_provider=self._embeddings_provider,
embeddings_endpoint=self._embeddings_endpoint,
embeddings_api_version=self._embeddings_api_version,
correlation_id=self._call_context_vars.get(
MPS_CORRELATION_ID_CONTEXT_KEY
),
tracing_context=self._get_otel_context(),
)

View file

@ -2,7 +2,6 @@
import random
from api.db import db_client
from api.db.models import WorkflowRunModel
from api.services.workflow.dto import QANodeData
@ -43,7 +42,7 @@ async def resolve_llm_config(
async def resolve_user_llm_config(
workflow_run: WorkflowRunModel,
) -> tuple[str, str, str, dict]:
"""Resolve the user's configured LLM (from UserConfiguration).
"""Resolve the user's configured LLM (from EffectiveAIModelConfiguration).
Returns:
(provider, model, api_key, service_kwargs) tuple
@ -54,7 +53,27 @@ async def resolve_user_llm_config(
llm_config: dict = {}
if user_id:
user_configuration = await db_client.get_user_configurations(user_id)
from api.services.configuration.ai_model_configuration import (
get_effective_ai_model_configuration_for_workflow,
)
workflow_configurations = {}
if workflow_run.definition:
workflow_configurations = (
workflow_run.definition.workflow_configurations or {}
)
elif workflow_run.workflow:
workflow_configurations = (
workflow_run.workflow.workflow_configurations or {}
)
user_configuration = await get_effective_ai_model_configuration_for_workflow(
user_id=user_id,
organization_id=workflow_run.workflow.organization_id
if workflow_run.workflow
else None,
workflow_configurations=workflow_configurations,
)
llm_config = user_configuration.model_dump(exclude_none=True).get("llm", {})
provider = llm_config.get("provider", "openai")

View file

@ -0,0 +1,41 @@
"""Format workflow run usage for public API responses."""
def format_public_usage_info(usage_info: dict | None) -> dict | None:
if not usage_info:
return None
return {
"llm": usage_info.get("llm") or {},
"tts": usage_info.get("tts") or {},
"stt": usage_info.get("stt") or {},
"call_duration_seconds": usage_info.get("call_duration_seconds"),
}
def format_public_cost_info(
cost_info: dict | None, usage_info: dict | None
) -> dict | None:
"""Return the legacy response shape without doing local cost accounting."""
duration = None
if usage_info and usage_info.get("call_duration_seconds") is not None:
duration = int(round(usage_info.get("call_duration_seconds") or 0))
elif cost_info and cost_info.get("call_duration_seconds") is not None:
duration = int(round(cost_info.get("call_duration_seconds") or 0))
dograh_token_usage = 0
if cost_info:
if "dograh_token_usage" in cost_info:
dograh_token_usage = cost_info.get("dograh_token_usage") or 0
elif "total_cost_usd" in cost_info:
dograh_token_usage = round(
float(cost_info.get("total_cost_usd", 0)) * 100, 2
)
if duration is None and dograh_token_usage == 0:
return None
return {
"dograh_token_usage": dograh_token_usage,
"call_duration_seconds": duration,
}

View file

@ -32,7 +32,6 @@ from pipecat.utils.run_context import set_current_org_id
from api.db import db_client
from api.enums import WorkflowRunMode, WorkflowRunState
from api.services.configuration.resolve import resolve_effective_config
from api.services.pipecat.audio_config import create_audio_config
from api.services.pipecat.pipeline_builder import create_pipeline_task
from api.services.pipecat.pipeline_metrics_aggregator import (
@ -410,14 +409,31 @@ async def execute_text_chat_pending_turn(
run_definition = workflow_run.definition
run_configs = run_definition.workflow_configurations or {}
user_config = await db_client.get_user_configurations(workflow_run.workflow.user.id)
user_config = resolve_effective_config(
user_config, run_configs.get("model_overrides")
from api.services.configuration.ai_model_configuration import (
get_effective_ai_model_configuration_for_workflow,
)
user_config = await get_effective_ai_model_configuration_for_workflow(
user_id=workflow_run.workflow.user.id,
organization_id=workflow.organization_id,
workflow_configurations=run_configs,
)
if user_config.llm is None:
raise ValueError("Text chat requires an LLM configuration")
llm = create_llm_service(user_config)
from api.services.managed_model_services import (
MPS_CORRELATION_ID_CONTEXT_KEY,
ensure_mps_correlation_id,
)
base_initial_context = dict(workflow_run.initial_context or {})
mps_correlation_id = await ensure_mps_correlation_id(
ai_model_config=user_config,
workflow_run_id=workflow_run_id,
initial_context=base_initial_context,
)
llm = create_llm_service(user_config, correlation_id=mps_correlation_id)
inference_llm = llm
runtime_configuration = {
@ -425,9 +441,15 @@ async def execute_text_chat_pending_turn(
"llm_model": user_config.llm.model,
}
initial_context = {
**(workflow_run.initial_context or {}),
**base_initial_context,
"runtime_configuration": runtime_configuration,
}
if mps_correlation_id:
initial_context[MPS_CORRELATION_ID_CONTEXT_KEY] = mps_correlation_id
await db_client.update_workflow_run(
workflow_run_id,
initial_context=initial_context,
)
workflow_graph = WorkflowGraph(
ReactFlowDTO.model_validate(run_definition.workflow_json)
@ -466,9 +488,17 @@ async def execute_text_chat_pending_turn(
embeddings_model = None
embeddings_base_url = None
if user_config.embeddings:
from api.services.configuration.ai_model_configuration import (
apply_managed_embeddings_base_url,
)
embeddings_api_key = user_config.embeddings.api_key
embeddings_model = user_config.embeddings.model
embeddings_base_url = getattr(user_config.embeddings, "base_url", None)
embeddings_provider = getattr(user_config.embeddings, "provider", None)
embeddings_base_url = apply_managed_embeddings_base_url(
provider=embeddings_provider,
base_url=getattr(user_config.embeddings, "base_url", None),
)
has_recordings = await db_client.has_active_recordings(workflow.organization_id)
context_compaction_enabled = (workflow.workflow_configurations or {}).get(
@ -606,8 +636,10 @@ async def execute_text_chat_pending_turn(
"Transportless text chat pipeline failed while closing run {}",
workflow_run_id,
)
await engine.close_mcp_sessions()
await engine.cleanup()
raise
await engine.close_mcp_sessions()
await engine.cleanup()
gathered_context = await engine.get_gathered_context()

View file

@ -4,17 +4,11 @@ from datetime import UTC, datetime
from typing import Any
from uuid import uuid4
from loguru import logger
from api.db import db_client
from api.db.models import WorkflowRunTextSessionModel
from api.db.workflow_run_text_session_client import (
WorkflowRunTextSessionRevisionConflictError,
)
from api.services.pricing.workflow_run_cost import (
apply_usage_delta_to_organization,
build_workflow_run_cost_info,
)
from api.services.workflow.text_chat_logs import (
build_text_chat_realtime_feedback_events,
)
@ -261,20 +255,6 @@ async def execute_pending_text_chat_turn(
state=execution.state,
is_completed=execution.is_completed,
)
workflow_run = await db_client.get_workflow_run_by_id(run_id)
if workflow_run:
try:
# Apply the per-turn delta so org usage tracks cumulative run cost
# without replaying the full session totals on every turn.
await apply_usage_delta_to_organization(workflow_run, execution.usage)
except Exception as e:
logger.error(
f"Failed to update organization usage for text chat run {run_id}: {e}"
)
cost_info = await build_workflow_run_cost_info(workflow_run)
if cost_info is not None:
await db_client.update_workflow_run(run_id, cost_info=cost_info)
return await _reload_text_chat_session(run_id)

View file

@ -29,6 +29,7 @@ async def retrieve_from_knowledge_base(
embeddings_provider: Optional[str] = None,
embeddings_endpoint: Optional[str] = None,
embeddings_api_version: Optional[str] = None,
correlation_id: Optional[str] = None,
tracing_context=None,
) -> Dict[str, Any]:
"""Retrieve relevant information from the knowledge base using vector similarity search.
@ -75,6 +76,7 @@ async def retrieve_from_knowledge_base(
embeddings_provider,
embeddings_endpoint,
embeddings_api_version,
correlation_id,
)
# Create span with parent context
@ -115,6 +117,7 @@ async def retrieve_from_knowledge_base(
embeddings_provider,
embeddings_endpoint,
embeddings_api_version,
correlation_id,
)
# Add result metadata to span
@ -192,6 +195,7 @@ async def retrieve_from_knowledge_base(
embeddings_provider,
embeddings_endpoint,
embeddings_api_version,
correlation_id,
)
else:
# Tracing is disabled - perform retrieval without tracing
@ -206,6 +210,7 @@ async def retrieve_from_knowledge_base(
embeddings_provider,
embeddings_endpoint,
embeddings_api_version,
correlation_id,
)
@ -220,6 +225,7 @@ async def _perform_retrieval(
embeddings_provider: Optional[str] = None,
embeddings_endpoint: Optional[str] = None,
embeddings_api_version: Optional[str] = None,
correlation_id: Optional[str] = None,
) -> Dict[str, Any]:
"""Internal function to perform the actual retrieval operation.
@ -272,11 +278,20 @@ async def _perform_retrieval(
api_version=embeddings_api_version or "2024-02-15-preview",
)
else:
default_headers = None
if (
embeddings_provider == ServiceProviders.DOGRAH.value
and correlation_id
):
default_headers = {
"X-Dograh-Correlation-Id": correlation_id,
}
embedding_service = OpenAIEmbeddingService(
db_client=db_client,
api_key=embeddings_api_key,
model_id=embeddings_model or "text-embedding-3-small",
base_url=embeddings_base_url,
default_headers=default_headers,
)
results = await embedding_service.search_similar_chunks(

View file

@ -0,0 +1,111 @@
"""Workflow-run billing hooks.
Dograh does not rate or deduct credits locally. MPS owns credit accounting.
For hosted deployments, Dograh reports completed platform usage to MPS.
When a server-minted MPS correlation id exists, MPS uses model-service usage
as the canonical duration. Otherwise Dograh reports the completed run duration.
"""
from typing import Any
from loguru import logger
from api.constants import DEPLOYMENT_MODE
from api.db import db_client
from api.services.managed_model_services import get_mps_correlation_id
from api.services.mps_service_key_client import mps_service_key_client
def _workflow_run_organization_id(workflow_run) -> int | None:
workflow = getattr(workflow_run, "workflow", None)
return getattr(workflow, "organization_id", None)
def _duration_seconds_from_usage_info(workflow_run) -> float | None:
usage_info: dict[str, Any] = getattr(workflow_run, "usage_info", None) or {}
duration = usage_info.get("call_duration_seconds")
try:
duration_seconds = float(duration)
except (TypeError, ValueError):
return None
return duration_seconds if duration_seconds > 0 else None
async def _organization_uses_mps_billing_v2(organization_id: int) -> bool:
account = await mps_service_key_client.get_billing_account_status(
organization_id=organization_id
)
return bool(account and account.get("billing_mode") == "v2")
async def report_workflow_run_platform_usage(workflow_run) -> None:
"""Report hosted platform usage for a completed workflow run to MPS."""
if DEPLOYMENT_MODE == "oss":
return
if not getattr(workflow_run, "is_completed", False):
return
organization_id = _workflow_run_organization_id(workflow_run)
if organization_id is None:
logger.warning(
"Skipping platform usage report for workflow run {}: no organization_id",
workflow_run.id,
)
return
correlation_id = get_mps_correlation_id(
getattr(workflow_run, "initial_context", None)
)
duration_seconds = (
None if correlation_id else _duration_seconds_from_usage_info(workflow_run)
)
if not correlation_id and duration_seconds is None:
logger.warning(
"Skipping platform usage report for workflow run {}: no billable duration",
workflow_run.id,
)
return
try:
if not await _organization_uses_mps_billing_v2(organization_id):
return
result = await mps_service_key_client.report_platform_usage(
organization_id=organization_id,
correlation_id=correlation_id,
duration_seconds=duration_seconds,
workflow_run_id=workflow_run.id,
metadata={
"source": "workflow_run_completion",
"workflow_id": getattr(workflow_run, "workflow_id", None),
"duration_source": (
"mps_correlation" if correlation_id else "dograh_usage_info"
),
},
)
logger.info(
"Reported platform usage for workflow run {} to MPS: {}",
workflow_run.id,
result,
)
except Exception as e:
logger.error(
"Failed to report platform usage for workflow run {}: {}",
workflow_run.id,
e,
)
async def report_completed_workflow_run_platform_usage(workflow_run_id: int) -> None:
"""Load a completed workflow run and report platform usage to MPS."""
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
if not workflow_run:
logger.warning(
"Skipping platform usage report: workflow run {} not found",
workflow_run_id,
)
return
await report_workflow_run_platform_usage(workflow_run)