mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-13 08:15:21 +02:00
Merge remote-tracking branch 'origin/main' into feat/user-onboarding
This commit is contained in:
commit
093e888ce4
148 changed files with 10908 additions and 2815 deletions
|
|
@ -1,7 +1,7 @@
|
|||
from typing import Annotated, Optional
|
||||
|
||||
import httpx
|
||||
from fastapi import Header, HTTPException, Query, WebSocket
|
||||
from fastapi import Depends, Header, HTTPException, Query, WebSocket
|
||||
from loguru import logger
|
||||
from pydantic import ValidationError
|
||||
|
||||
|
|
@ -9,9 +9,10 @@ from api.constants import AUTH_PROVIDER, DOGRAH_MPS_SECRET_KEY, MPS_API_URL
|
|||
from api.db import db_client
|
||||
from api.db.models import UserModel
|
||||
from api.enums import PostHogEvent
|
||||
from api.schemas.user_configuration import UserConfiguration
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.auth.stack_auth import stackauth
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from api.services.mps_billing import ensure_hosted_mps_billing_account_v2
|
||||
from api.services.posthog_client import capture_event
|
||||
from api.utils.auth import decode_jwt_token
|
||||
|
||||
|
|
@ -110,6 +111,19 @@ async def get_user(
|
|||
# This prevents race conditions where multiple concurrent requests
|
||||
# might try to create configurations
|
||||
if org_was_created:
|
||||
try:
|
||||
await ensure_hosted_mps_billing_account_v2(
|
||||
organization.id,
|
||||
created_by=str(stack_user["id"]),
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to initialize hosted MPS billing account for "
|
||||
"organization {}",
|
||||
organization.id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
existing_cfg = await db_client.get_user_configurations(user_model.id)
|
||||
if not (existing_cfg.llm or existing_cfg.tts or existing_cfg.stt):
|
||||
mps_config = await create_user_configuration_with_mps_key(
|
||||
|
|
@ -119,6 +133,19 @@ async def get_user(
|
|||
await db_client.update_user_configuration(
|
||||
user_model.id, mps_config
|
||||
)
|
||||
from api.enums import OrganizationConfigurationKey
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
convert_legacy_ai_model_configuration_to_v2,
|
||||
)
|
||||
|
||||
model_config_v2 = convert_legacy_ai_model_configuration_to_v2(
|
||||
mps_config
|
||||
)
|
||||
await db_client.upsert_configuration(
|
||||
organization.id,
|
||||
OrganizationConfigurationKey.MODEL_CONFIGURATION_V2.value,
|
||||
model_config_v2.model_dump(mode="json", exclude_none=True),
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
raise HTTPException(
|
||||
|
|
@ -129,6 +156,14 @@ async def get_user(
|
|||
return user_model
|
||||
|
||||
|
||||
async def get_user_with_selected_organization(
|
||||
user: Annotated[UserModel, Depends(get_user)],
|
||||
) -> UserModel:
|
||||
if not user.selected_organization_id:
|
||||
raise HTTPException(status_code=400, detail="No organization selected")
|
||||
return user
|
||||
|
||||
|
||||
async def _handle_oss_auth(authorization: str | None) -> UserModel:
|
||||
"""
|
||||
Handle authentication for OSS deployment mode.
|
||||
|
|
@ -192,7 +227,7 @@ async def _handle_api_key_auth(api_key: str) -> UserModel:
|
|||
|
||||
async def create_user_configuration_with_mps_key(
|
||||
user_id: int, organization_id: int, user_provider_id: str
|
||||
) -> Optional[UserConfiguration]:
|
||||
) -> Optional[EffectiveAIModelConfiguration]:
|
||||
"""Create user configuration using MPS service key.
|
||||
|
||||
Args:
|
||||
|
|
@ -201,7 +236,7 @@ async def create_user_configuration_with_mps_key(
|
|||
user_provider_id: The user's provider ID (for created_by field)
|
||||
|
||||
Returns:
|
||||
UserConfiguration with MPS-provided API keys or None if failed
|
||||
EffectiveAIModelConfiguration with MPS-provided API keys or None if failed
|
||||
"""
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
|
|
@ -211,7 +246,7 @@ async def create_user_configuration_with_mps_key(
|
|||
response = await client.post(
|
||||
f"{MPS_API_URL}/api/v1/service-keys/",
|
||||
json={
|
||||
"name": f"Default Dograh Model Service Key",
|
||||
"name": "Default Dograh Model Service Key",
|
||||
"description": "Auto-generated key for OSS user",
|
||||
"expires_in_days": 7, # Short-lived for OSS
|
||||
"created_by": user_provider_id,
|
||||
|
|
@ -229,7 +264,7 @@ async def create_user_configuration_with_mps_key(
|
|||
response = await client.post(
|
||||
f"{MPS_API_URL}/api/v1/service-keys/",
|
||||
json={
|
||||
"name": f"Default Dograh Model Service Key",
|
||||
"name": "Default Dograh Model Service Key",
|
||||
"description": f"Auto-generated key for organization {organization_id}",
|
||||
"organization_id": organization_id,
|
||||
"expires_in_days": 90, # Longer-lived for authenticated users
|
||||
|
|
@ -264,8 +299,8 @@ async def create_user_configuration_with_mps_key(
|
|||
"model": "default",
|
||||
},
|
||||
}
|
||||
user_config = UserConfiguration(**configuration)
|
||||
return user_config
|
||||
effective_config = EffectiveAIModelConfiguration(**configuration)
|
||||
return effective_config
|
||||
else:
|
||||
logger.warning(
|
||||
f"Failed to get MPS service key: {response.status_code} - {response.text}"
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from api.services.campaign.errors import (
|
|||
PhoneNumberPoolExhaustedError,
|
||||
)
|
||||
from api.services.campaign.rate_limiter import rate_limiter
|
||||
from api.services.quota_service import authorize_workflow_run_start
|
||||
from api.utils.common import get_backend_endpoints
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -339,6 +340,41 @@ class CampaignCallDispatcher:
|
|||
},
|
||||
)
|
||||
|
||||
quota_result = await authorize_workflow_run_start(
|
||||
workflow_id=campaign.workflow_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
)
|
||||
if not quota_result.has_quota:
|
||||
error_message = quota_result.error_message or "Quota exceeded"
|
||||
logger.warning(
|
||||
f"Campaign {campaign.id} quota check failed for workflow run "
|
||||
f"{workflow_run.id}: {error_message}"
|
||||
)
|
||||
await db_client.update_workflow_run(
|
||||
run_id=workflow_run.id,
|
||||
is_completed=True,
|
||||
state=WorkflowRunState.COMPLETED.value,
|
||||
gathered_context={"error": error_message},
|
||||
)
|
||||
|
||||
mapping = await rate_limiter.get_workflow_slot_mapping(workflow_run.id)
|
||||
if mapping:
|
||||
org_id, mapped_slot_id = mapping
|
||||
await rate_limiter.release_concurrent_slot(org_id, mapped_slot_id)
|
||||
await rate_limiter.delete_workflow_slot_mapping(workflow_run.id)
|
||||
|
||||
from_number_mapping = await rate_limiter.get_workflow_from_number_mapping(
|
||||
workflow_run.id
|
||||
)
|
||||
if from_number_mapping:
|
||||
fn_org_id, fn_number, fn_tcid = from_number_mapping
|
||||
await rate_limiter.release_from_number(
|
||||
fn_org_id, fn_number, telephony_configuration_id=fn_tcid
|
||||
)
|
||||
await rate_limiter.delete_workflow_from_number_mapping(workflow_run.id)
|
||||
|
||||
raise ValueError(error_message)
|
||||
|
||||
# Initiate call via telephony provider
|
||||
try:
|
||||
# Construct webhook URL with parameters
|
||||
|
|
|
|||
484
api/services/configuration/ai_model_configuration.py
Normal file
484
api/services/configuration/ai_model_configuration.py
Normal file
|
|
@ -0,0 +1,484 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from api.constants import MPS_API_URL
|
||||
from api.db import db_client
|
||||
from api.db.models import WorkflowDefinitionModel, WorkflowModel
|
||||
from api.enums import OrganizationConfigurationKey
|
||||
from api.schemas.ai_model_configuration import (
|
||||
DOGRAH_DEFAULT_LANGUAGE,
|
||||
DOGRAH_DEFAULT_VOICE,
|
||||
DOGRAH_SPEED_OPTIONS,
|
||||
BYOKAIModelConfiguration,
|
||||
BYOKPipelineAIModelConfiguration,
|
||||
BYOKRealtimeAIModelConfiguration,
|
||||
DograhManagedAIModelConfiguration,
|
||||
EffectiveAIModelConfiguration,
|
||||
OrganizationAIModelConfigurationV2,
|
||||
compile_ai_model_configuration_v2,
|
||||
)
|
||||
from api.services.configuration.masking import (
|
||||
SERVICE_SECRET_FIELDS,
|
||||
contains_masked_key,
|
||||
mask_key,
|
||||
resolve_masked_api_keys,
|
||||
)
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from api.services.configuration.resolve import resolve_effective_config
|
||||
|
||||
AIModelConfigurationSource = Literal["organization_v2", "legacy_user_v1", "empty"]
|
||||
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY = "model_configuration_v2_override"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResolvedAIModelConfiguration:
|
||||
effective: EffectiveAIModelConfiguration
|
||||
source: AIModelConfigurationSource
|
||||
organization_configuration: OrganizationAIModelConfigurationV2 | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class WorkflowAIModelConfigurationMigrationResult:
|
||||
workflow_count: int = 0
|
||||
definition_count: int = 0
|
||||
workflow_ids: list[int] | None = None
|
||||
|
||||
|
||||
async def get_resolved_ai_model_configuration(
|
||||
*,
|
||||
user_id: int | None,
|
||||
organization_id: int | None,
|
||||
) -> ResolvedAIModelConfiguration:
|
||||
organization_configuration = await get_organization_ai_model_configuration_v2(
|
||||
organization_id
|
||||
)
|
||||
if organization_configuration is not None:
|
||||
return ResolvedAIModelConfiguration(
|
||||
effective=compile_ai_model_configuration_v2(organization_configuration),
|
||||
source="organization_v2",
|
||||
organization_configuration=organization_configuration,
|
||||
)
|
||||
|
||||
if user_id is None:
|
||||
return ResolvedAIModelConfiguration(
|
||||
effective=EffectiveAIModelConfiguration(),
|
||||
source="empty",
|
||||
)
|
||||
|
||||
legacy = await db_client.get_user_configurations(user_id)
|
||||
return ResolvedAIModelConfiguration(
|
||||
effective=legacy,
|
||||
source="legacy_user_v1" if _has_model_services(legacy) else "empty",
|
||||
)
|
||||
|
||||
|
||||
async def get_effective_ai_model_configuration_for_workflow(
|
||||
*,
|
||||
user_id: int | None,
|
||||
organization_id: int | None,
|
||||
workflow_configurations: dict | None,
|
||||
) -> EffectiveAIModelConfiguration:
|
||||
workflow_configurations = workflow_configurations or {}
|
||||
v2_override = workflow_configurations.get(
|
||||
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY
|
||||
)
|
||||
if v2_override:
|
||||
return compile_ai_model_configuration_v2(
|
||||
OrganizationAIModelConfigurationV2.model_validate(v2_override)
|
||||
)
|
||||
|
||||
resolved_config = await get_resolved_ai_model_configuration(
|
||||
user_id=user_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
return resolve_effective_config(
|
||||
resolved_config.effective,
|
||||
workflow_configurations.get("model_overrides"),
|
||||
)
|
||||
|
||||
|
||||
async def get_organization_ai_model_configuration_v2(
|
||||
organization_id: int | None,
|
||||
) -> OrganizationAIModelConfigurationV2 | None:
|
||||
if organization_id is None:
|
||||
return None
|
||||
row = await db_client.get_configuration(
|
||||
organization_id,
|
||||
OrganizationConfigurationKey.MODEL_CONFIGURATION_V2.value,
|
||||
)
|
||||
if row is None or not row.value:
|
||||
return None
|
||||
try:
|
||||
return OrganizationAIModelConfigurationV2.model_validate(row.value)
|
||||
except ValidationError as exc:
|
||||
logger.warning(
|
||||
"Invalid org AI model configuration v2 for organization "
|
||||
f"{organization_id}: {exc}. Falling back to legacy configuration."
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
async def upsert_organization_ai_model_configuration_v2(
|
||||
organization_id: int,
|
||||
configuration: OrganizationAIModelConfigurationV2,
|
||||
) -> OrganizationAIModelConfigurationV2:
|
||||
await db_client.upsert_configuration(
|
||||
organization_id,
|
||||
OrganizationConfigurationKey.MODEL_CONFIGURATION_V2.value,
|
||||
configuration.model_dump(mode="json", exclude_none=True),
|
||||
)
|
||||
return configuration
|
||||
|
||||
|
||||
async def migrate_workflow_model_configurations_to_v2(
|
||||
*,
|
||||
organization_id: int,
|
||||
fallback_user_config: EffectiveAIModelConfiguration,
|
||||
) -> WorkflowAIModelConfigurationMigrationResult:
|
||||
workflows = await _list_workflows_for_model_configuration_migration(organization_id)
|
||||
owner_configs: dict[int, EffectiveAIModelConfiguration] = {}
|
||||
workflow_updates: list[tuple[int, dict]] = []
|
||||
definition_updates: list[tuple[int, dict]] = []
|
||||
migrated_workflow_ids: set[int] = set()
|
||||
|
||||
for workflow in workflows:
|
||||
base_config = fallback_user_config
|
||||
if workflow.user_id is not None:
|
||||
if workflow.user_id not in owner_configs:
|
||||
owner_configs[
|
||||
workflow.user_id
|
||||
] = await db_client.get_user_configurations(workflow.user_id)
|
||||
base_config = owner_configs[workflow.user_id]
|
||||
|
||||
workflow_configs, workflow_changed = (
|
||||
migrate_workflow_configuration_model_override_to_v2(
|
||||
workflow.workflow_configurations,
|
||||
base_config,
|
||||
)
|
||||
)
|
||||
if workflow_changed:
|
||||
workflow_updates.append((workflow.id, workflow_configs))
|
||||
migrated_workflow_ids.add(workflow.id)
|
||||
|
||||
for definition in workflow.definitions:
|
||||
definition_configs, definition_changed = (
|
||||
migrate_workflow_configuration_model_override_to_v2(
|
||||
definition.workflow_configurations,
|
||||
base_config,
|
||||
)
|
||||
)
|
||||
if definition_changed:
|
||||
definition_updates.append((definition.id, definition_configs))
|
||||
migrated_workflow_ids.add(workflow.id)
|
||||
|
||||
if workflow_updates or definition_updates:
|
||||
async with db_client.async_session() as session:
|
||||
for workflow_id, workflow_configs in workflow_updates:
|
||||
await session.execute(
|
||||
update(WorkflowModel)
|
||||
.where(WorkflowModel.id == workflow_id)
|
||||
.values(workflow_configurations=workflow_configs)
|
||||
)
|
||||
for definition_id, definition_configs in definition_updates:
|
||||
await session.execute(
|
||||
update(WorkflowDefinitionModel)
|
||||
.where(WorkflowDefinitionModel.id == definition_id)
|
||||
.values(workflow_configurations=definition_configs)
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
return WorkflowAIModelConfigurationMigrationResult(
|
||||
workflow_count=len(migrated_workflow_ids),
|
||||
definition_count=len(definition_updates),
|
||||
workflow_ids=sorted(migrated_workflow_ids),
|
||||
)
|
||||
|
||||
|
||||
def migrate_workflow_configuration_model_override_to_v2(
|
||||
workflow_configurations: dict | None,
|
||||
base_config: EffectiveAIModelConfiguration,
|
||||
) -> tuple[dict, bool]:
|
||||
if not isinstance(workflow_configurations, dict):
|
||||
return {}, False
|
||||
|
||||
migrated = copy.deepcopy(workflow_configurations)
|
||||
model_overrides = migrated.get("model_overrides")
|
||||
existing_v2_override = migrated.get(WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY)
|
||||
if not isinstance(model_overrides, dict):
|
||||
if "model_overrides" in migrated:
|
||||
migrated.pop("model_overrides", None)
|
||||
return migrated, True
|
||||
return migrated, False
|
||||
|
||||
if not existing_v2_override:
|
||||
effective = resolve_effective_config(base_config, model_overrides)
|
||||
v2_override = convert_legacy_ai_model_configuration_to_v2(effective)
|
||||
migrated[WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY] = v2_override.model_dump(
|
||||
mode="json", exclude_none=True
|
||||
)
|
||||
migrated.pop("model_overrides", None)
|
||||
return migrated, True
|
||||
|
||||
|
||||
def merge_ai_model_configuration_v2_secrets(
|
||||
incoming: OrganizationAIModelConfigurationV2,
|
||||
existing: OrganizationAIModelConfigurationV2 | None,
|
||||
) -> OrganizationAIModelConfigurationV2:
|
||||
if existing is None:
|
||||
return incoming
|
||||
|
||||
incoming_dict = incoming.model_dump(mode="json", exclude_none=True)
|
||||
existing_dict = existing.model_dump(mode="json", exclude_none=True)
|
||||
|
||||
if incoming_dict.get("mode") == "dograh" and existing_dict.get("mode") == "dograh":
|
||||
incoming_dograh = incoming_dict.get("dograh") or {}
|
||||
existing_dograh = existing_dict.get("dograh") or {}
|
||||
incoming_key = incoming_dograh.get("api_key")
|
||||
existing_key = existing_dograh.get("api_key")
|
||||
if incoming_key and existing_key and contains_masked_key(incoming_key):
|
||||
incoming_dograh["api_key"] = resolve_masked_api_keys(
|
||||
incoming_key,
|
||||
existing_key,
|
||||
)
|
||||
|
||||
if incoming_dict.get("mode") == "byok" and existing_dict.get("mode") == "byok":
|
||||
_merge_byok_secret_fields(incoming_dict.get("byok"), existing_dict.get("byok"))
|
||||
|
||||
return OrganizationAIModelConfigurationV2.model_validate(incoming_dict)
|
||||
|
||||
|
||||
def check_for_masked_keys_in_ai_model_configuration_v2(
|
||||
configuration: OrganizationAIModelConfigurationV2,
|
||||
) -> None:
|
||||
data = configuration.model_dump(mode="json", exclude_none=True)
|
||||
_raise_if_masked_secret(data)
|
||||
|
||||
|
||||
def mask_ai_model_configuration_v2(
|
||||
configuration: OrganizationAIModelConfigurationV2 | None,
|
||||
) -> dict | None:
|
||||
if configuration is None:
|
||||
return None
|
||||
data = configuration.model_dump(mode="json", exclude_none=True)
|
||||
_mask_secret_fields(data)
|
||||
return data
|
||||
|
||||
|
||||
def convert_legacy_ai_model_configuration_to_v2(
|
||||
configuration: EffectiveAIModelConfiguration,
|
||||
) -> OrganizationAIModelConfigurationV2:
|
||||
dograh_key = _first_dograh_api_key(configuration)
|
||||
if dograh_key:
|
||||
return _convert_any_dograh_legacy_configuration(configuration, dograh_key)
|
||||
|
||||
if configuration.is_realtime:
|
||||
if configuration.realtime is None or configuration.llm is None:
|
||||
raise ValueError("Realtime legacy configuration is incomplete")
|
||||
return OrganizationAIModelConfigurationV2(
|
||||
mode="byok",
|
||||
byok=BYOKAIModelConfiguration(
|
||||
mode="realtime",
|
||||
realtime=BYOKRealtimeAIModelConfiguration(
|
||||
realtime=configuration.realtime,
|
||||
llm=configuration.llm,
|
||||
embeddings=configuration.embeddings,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
if (
|
||||
configuration.llm is None
|
||||
or configuration.tts is None
|
||||
or configuration.stt is None
|
||||
):
|
||||
raise ValueError("Pipeline legacy configuration is incomplete")
|
||||
return OrganizationAIModelConfigurationV2(
|
||||
mode="byok",
|
||||
byok=BYOKAIModelConfiguration(
|
||||
mode="pipeline",
|
||||
pipeline=BYOKPipelineAIModelConfiguration(
|
||||
llm=configuration.llm,
|
||||
tts=configuration.tts,
|
||||
stt=configuration.stt,
|
||||
embeddings=configuration.embeddings,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def dograh_embeddings_base_url() -> str:
|
||||
return f"{MPS_API_URL}/api/v1/llm"
|
||||
|
||||
|
||||
def apply_managed_embeddings_base_url(
|
||||
*,
|
||||
provider: str | None,
|
||||
base_url: str | None,
|
||||
) -> str | None:
|
||||
if provider == ServiceProviders.DOGRAH.value or provider == ServiceProviders.DOGRAH:
|
||||
return dograh_embeddings_base_url()
|
||||
return base_url
|
||||
|
||||
|
||||
def _merge_byok_secret_fields(incoming_byok: dict | None, existing_byok: dict | None):
|
||||
if not isinstance(incoming_byok, dict) or not isinstance(existing_byok, dict):
|
||||
return
|
||||
incoming_mode = incoming_byok.get("mode")
|
||||
existing_mode = existing_byok.get("mode")
|
||||
if incoming_mode != existing_mode:
|
||||
return
|
||||
section_names = (
|
||||
("llm", "tts", "stt", "embeddings")
|
||||
if incoming_mode == "pipeline"
|
||||
else ("realtime", "llm", "embeddings")
|
||||
)
|
||||
incoming_container = incoming_byok.get(incoming_mode)
|
||||
existing_container = existing_byok.get(existing_mode)
|
||||
if not isinstance(incoming_container, dict) or not isinstance(
|
||||
existing_container, dict
|
||||
):
|
||||
return
|
||||
for section_name in section_names:
|
||||
incoming_section = incoming_container.get(section_name)
|
||||
existing_section = existing_container.get(section_name)
|
||||
if isinstance(incoming_section, dict) and isinstance(existing_section, dict):
|
||||
_merge_service_secret_fields(incoming_section, existing_section)
|
||||
|
||||
|
||||
async def _list_workflows_for_model_configuration_migration(
|
||||
organization_id: int,
|
||||
) -> list[WorkflowModel]:
|
||||
async with db_client.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(WorkflowModel)
|
||||
.options(selectinload(WorkflowModel.definitions))
|
||||
.where(WorkflowModel.organization_id == organization_id)
|
||||
)
|
||||
return list(result.scalars().unique().all())
|
||||
|
||||
|
||||
def _merge_service_secret_fields(incoming: dict, existing: dict):
|
||||
if (
|
||||
incoming.get("provider") is not None
|
||||
and existing.get("provider") is not None
|
||||
and incoming.get("provider") != existing.get("provider")
|
||||
):
|
||||
return
|
||||
for secret_field in SERVICE_SECRET_FIELDS:
|
||||
if secret_field not in existing:
|
||||
continue
|
||||
incoming_secret = incoming.get(secret_field)
|
||||
existing_secret = existing[secret_field]
|
||||
if incoming_secret is None:
|
||||
incoming[secret_field] = existing_secret
|
||||
elif contains_masked_key(incoming_secret):
|
||||
incoming[secret_field] = resolve_masked_api_keys(
|
||||
incoming_secret,
|
||||
existing_secret,
|
||||
)
|
||||
|
||||
|
||||
def _raise_if_masked_secret(value):
|
||||
if isinstance(value, dict):
|
||||
for key, nested in value.items():
|
||||
if key in SERVICE_SECRET_FIELDS and contains_masked_key(nested):
|
||||
raise ValueError(
|
||||
f"The {key} appears to be masked. Please provide the actual "
|
||||
"value, not the masked value."
|
||||
)
|
||||
_raise_if_masked_secret(nested)
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
_raise_if_masked_secret(item)
|
||||
|
||||
|
||||
def _mask_secret_fields(value):
|
||||
if isinstance(value, dict):
|
||||
for key, nested in list(value.items()):
|
||||
if key in SERVICE_SECRET_FIELDS and nested:
|
||||
value[key] = _mask_secret_value(nested)
|
||||
else:
|
||||
_mask_secret_fields(nested)
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
_mask_secret_fields(item)
|
||||
|
||||
|
||||
def _mask_secret_value(value):
|
||||
if isinstance(value, list):
|
||||
return [mask_key(item) for item in value]
|
||||
return mask_key(value)
|
||||
|
||||
|
||||
def _has_model_services(configuration: EffectiveAIModelConfiguration) -> bool:
|
||||
return any(
|
||||
service is not None
|
||||
for service in (
|
||||
configuration.llm,
|
||||
configuration.tts,
|
||||
configuration.stt,
|
||||
configuration.embeddings,
|
||||
configuration.realtime,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _convert_any_dograh_legacy_configuration(
|
||||
configuration: EffectiveAIModelConfiguration,
|
||||
dograh_key: str,
|
||||
) -> OrganizationAIModelConfigurationV2:
|
||||
speed = getattr(configuration.tts, "speed", 1.0)
|
||||
if speed not in DOGRAH_SPEED_OPTIONS:
|
||||
speed = 1.0
|
||||
return OrganizationAIModelConfigurationV2(
|
||||
mode="dograh",
|
||||
dograh=DograhManagedAIModelConfiguration(
|
||||
api_key=dograh_key,
|
||||
voice=getattr(configuration.tts, "voice", DOGRAH_DEFAULT_VOICE)
|
||||
or DOGRAH_DEFAULT_VOICE,
|
||||
speed=speed,
|
||||
language=getattr(configuration.stt, "language", DOGRAH_DEFAULT_LANGUAGE)
|
||||
or DOGRAH_DEFAULT_LANGUAGE,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _first_dograh_api_key(configuration: EffectiveAIModelConfiguration) -> str | None:
|
||||
for service in (
|
||||
configuration.llm,
|
||||
configuration.tts,
|
||||
configuration.stt,
|
||||
configuration.embeddings,
|
||||
configuration.realtime,
|
||||
):
|
||||
if service is None or _provider(service) != ServiceProviders.DOGRAH:
|
||||
continue
|
||||
try:
|
||||
return _single_api_key(service)
|
||||
except ValueError:
|
||||
continue
|
||||
return None
|
||||
|
||||
|
||||
def _provider(service):
|
||||
return getattr(service, "provider", None)
|
||||
|
||||
|
||||
def _single_api_key(service) -> str:
|
||||
if hasattr(service, "get_all_api_keys"):
|
||||
keys = service.get_all_api_keys()
|
||||
if len(keys) != 1:
|
||||
raise ValueError("Expected exactly one API key")
|
||||
return keys[0]
|
||||
key = getattr(service, "api_key", None)
|
||||
if not key:
|
||||
raise ValueError("Expected an API key")
|
||||
return key
|
||||
|
|
@ -8,8 +8,8 @@ from groq import Groq
|
|||
# from pyneuphonic import Neuphonic
|
||||
# except ImportError:
|
||||
# Neuphonic = None
|
||||
from api.schemas.user_configuration import (
|
||||
UserConfiguration,
|
||||
from api.schemas.ai_model_configuration import (
|
||||
EffectiveAIModelConfiguration,
|
||||
)
|
||||
from api.services.configuration.registry import ServiceConfig, ServiceProviders
|
||||
from api.services.mps_service_key_client import mps_service_key_client
|
||||
|
|
@ -64,7 +64,7 @@ class UserConfigurationValidator:
|
|||
|
||||
async def validate(
|
||||
self,
|
||||
configuration: UserConfiguration,
|
||||
configuration: EffectiveAIModelConfiguration,
|
||||
organization_id: Optional[int] = None,
|
||||
created_by: Optional[str] = None,
|
||||
) -> APIKeyStatusResponse:
|
||||
|
|
@ -75,21 +75,21 @@ class UserConfigurationValidator:
|
|||
status_list = []
|
||||
|
||||
status_list.extend(self._validate_service(configuration.llm, "llm"))
|
||||
status_list.extend(self._validate_service(configuration.stt, "stt"))
|
||||
status_list.extend(self._validate_service(configuration.tts, "tts"))
|
||||
# Embeddings is optional - only validate if configured
|
||||
status_list.extend(
|
||||
self._validate_service(
|
||||
configuration.embeddings, "embeddings", required=False
|
||||
)
|
||||
)
|
||||
# Realtime is optional - only validate if is_realtime is enabled
|
||||
if configuration.is_realtime:
|
||||
status_list.extend(
|
||||
self._validate_service(
|
||||
configuration.realtime, "realtime", required=True
|
||||
)
|
||||
)
|
||||
else:
|
||||
status_list.extend(self._validate_service(configuration.stt, "stt"))
|
||||
status_list.extend(self._validate_service(configuration.tts, "tts"))
|
||||
# Embeddings is optional - only validate if configured
|
||||
status_list.extend(
|
||||
self._validate_service(
|
||||
configuration.embeddings, "embeddings", required=False
|
||||
)
|
||||
)
|
||||
|
||||
if status_list:
|
||||
raise ValueError(status_list)
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ The rules are simple:
|
|||
import copy
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from api.schemas.user_configuration import UserConfiguration
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.configuration.registry import ServiceConfig
|
||||
from api.services.integrations import get_node_secret_fields
|
||||
|
||||
|
|
@ -31,7 +31,7 @@ def contains_masked_key(value: str | list[str] | None) -> bool:
|
|||
return any(MASK_MARKER in k for k in keys)
|
||||
|
||||
|
||||
def check_for_masked_keys(config: "UserConfiguration") -> None:
|
||||
def check_for_masked_keys(config: "EffectiveAIModelConfiguration") -> None:
|
||||
"""Raise ValueError if any service in *config* still has a masked secret."""
|
||||
for field in ("llm", "tts", "stt", "embeddings", "realtime"):
|
||||
service = getattr(config, field, None)
|
||||
|
|
@ -111,7 +111,7 @@ def resolve_masked_api_keys(
|
|||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# High-level helpers for UserConfiguration objects
|
||||
# High-level helpers for EffectiveAIModelConfiguration objects
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
|
|
@ -129,7 +129,7 @@ def _mask_service(service_cfg: Optional[ServiceConfig]) -> Optional[Dict[str, An
|
|||
return data
|
||||
|
||||
|
||||
def mask_user_config(config: UserConfiguration) -> Dict[str, Any]:
|
||||
def mask_user_config(config: EffectiveAIModelConfiguration) -> Dict[str, Any]:
|
||||
"""Return a JSON-serialisable dict of *config* with every api_key masked."""
|
||||
|
||||
return {
|
||||
|
|
@ -155,21 +155,35 @@ def mask_workflow_configurations(config: Optional[Dict]) -> Optional[Dict]:
|
|||
|
||||
masked = copy.deepcopy(config)
|
||||
model_overrides = masked.get("model_overrides")
|
||||
if not isinstance(model_overrides, dict):
|
||||
return masked
|
||||
if isinstance(model_overrides, dict):
|
||||
for section in MODEL_OVERRIDE_FIELDS:
|
||||
override = model_overrides.get(section)
|
||||
if not isinstance(override, dict):
|
||||
continue
|
||||
for secret_field in SERVICE_SECRET_FIELDS:
|
||||
raw = override.get(secret_field)
|
||||
if raw:
|
||||
override[secret_field] = _mask_secret_value(raw)
|
||||
|
||||
for section in MODEL_OVERRIDE_FIELDS:
|
||||
override = model_overrides.get(section)
|
||||
if not isinstance(override, dict):
|
||||
continue
|
||||
for secret_field in SERVICE_SECRET_FIELDS:
|
||||
raw = override.get(secret_field)
|
||||
if raw:
|
||||
override[secret_field] = _mask_secret_value(raw)
|
||||
v2_override = masked.get("model_configuration_v2_override")
|
||||
if isinstance(v2_override, dict):
|
||||
_mask_nested_service_secrets(v2_override)
|
||||
|
||||
return masked
|
||||
|
||||
|
||||
def _mask_nested_service_secrets(value):
|
||||
if isinstance(value, dict):
|
||||
for key, nested in list(value.items()):
|
||||
if key in SERVICE_SECRET_FIELDS and nested:
|
||||
value[key] = _mask_secret_value(nested)
|
||||
else:
|
||||
_mask_nested_service_secrets(nested)
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
_mask_nested_service_secrets(item)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Workflow definition helpers – mask / merge node API keys
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ stored, while honouring masked API keys.
|
|||
import copy
|
||||
from typing import Dict
|
||||
|
||||
from api.schemas.user_configuration import UserConfiguration
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.configuration.masking import (
|
||||
MODEL_OVERRIDE_FIELDS,
|
||||
SERVICE_SECRET_FIELDS,
|
||||
|
|
@ -66,9 +66,9 @@ def _merge_service_secret_fields(
|
|||
|
||||
|
||||
def merge_user_configurations(
|
||||
existing: UserConfiguration, incoming_partial: Dict[str, dict]
|
||||
) -> UserConfiguration:
|
||||
"""Merge *incoming_partial* onto *existing* and return a new UserConfiguration.
|
||||
existing: EffectiveAIModelConfiguration, incoming_partial: Dict[str, dict]
|
||||
) -> EffectiveAIModelConfiguration:
|
||||
"""Merge *incoming_partial* onto *existing* and return a new EffectiveAIModelConfiguration.
|
||||
|
||||
*incoming_partial* is the body of the PUT request (already `model_dump()`ed or
|
||||
extracted via Pydantic `model_dump`).
|
||||
|
|
@ -113,14 +113,14 @@ def merge_user_configurations(
|
|||
if "timezone" in incoming_partial:
|
||||
merged["timezone"] = incoming_partial["timezone"]
|
||||
|
||||
# Onboarding gate flags — overwrite only when supplied (set once on submit/skip).
|
||||
# Onboarding gate flags: overwrite only when supplied.
|
||||
if "onboarding_completed_at" in incoming_partial:
|
||||
merged["onboarding_completed_at"] = incoming_partial["onboarding_completed_at"]
|
||||
|
||||
if "onboarding_skipped" in incoming_partial:
|
||||
merged["onboarding_skipped"] = incoming_partial["onboarding_skipped"]
|
||||
|
||||
return UserConfiguration.model_validate(merged)
|
||||
return EffectiveAIModelConfiguration.model_validate(merged)
|
||||
|
||||
|
||||
def merge_workflow_configuration_secrets(
|
||||
|
|
|
|||
|
|
@ -911,7 +911,7 @@ class DograhTTSService(BaseTTSConfiguration):
|
|||
speed: float = Field(default=1.0, ge=0.5, le=2.0, description="Speed of the voice.")
|
||||
|
||||
|
||||
CARTESIA_TTS_MODELS = ["sonic-3"]
|
||||
CARTESIA_TTS_MODELS = ["sonic-3.5", "sonic-3"]
|
||||
|
||||
|
||||
@register_tts
|
||||
|
|
@ -919,7 +919,7 @@ class CartesiaTTSConfiguration(BaseTTSConfiguration):
|
|||
model_config = CARTESIA_PROVIDER_MODEL_CONFIG
|
||||
provider: Literal[ServiceProviders.CARTESIA] = ServiceProviders.CARTESIA
|
||||
model: str = Field(
|
||||
default="sonic-3",
|
||||
default="sonic-3.5",
|
||||
description="Cartesia TTS model.",
|
||||
json_schema_extra={"examples": CARTESIA_TTS_MODELS},
|
||||
)
|
||||
|
|
@ -1472,11 +1472,26 @@ class AzureOpenAIEmbeddingsConfiguration(BaseEmbeddingsConfiguration):
|
|||
)
|
||||
|
||||
|
||||
DOGRAH_EMBEDDING_MODELS = ["default"]
|
||||
|
||||
|
||||
@register_embeddings
|
||||
class DograhEmbeddingsConfiguration(BaseEmbeddingsConfiguration):
|
||||
model_config = DOGRAH_PROVIDER_MODEL_CONFIG
|
||||
provider: Literal[ServiceProviders.DOGRAH] = ServiceProviders.DOGRAH
|
||||
model: str = Field(
|
||||
default="default",
|
||||
description="Dograh-managed embedding model.",
|
||||
json_schema_extra={"examples": DOGRAH_EMBEDDING_MODELS},
|
||||
)
|
||||
|
||||
|
||||
EmbeddingsConfig = Annotated[
|
||||
Union[
|
||||
OpenAIEmbeddingsConfiguration,
|
||||
OpenRouterEmbeddingsConfiguration,
|
||||
AzureOpenAIEmbeddingsConfiguration,
|
||||
DograhEmbeddingsConfiguration,
|
||||
],
|
||||
Field(discriminator="provider"),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -4,13 +4,13 @@ from __future__ import annotations
|
|||
|
||||
import copy
|
||||
|
||||
from api.schemas.user_configuration import UserConfiguration
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.configuration.registry import (
|
||||
REGISTRY,
|
||||
ServiceType,
|
||||
)
|
||||
|
||||
# Maps override key → (UserConfiguration field, ServiceType for registry lookup)
|
||||
# Maps override key → (EffectiveAIModelConfiguration field, ServiceType for registry lookup)
|
||||
_SECTION_MAP: dict[str, ServiceType] = {
|
||||
"llm": ServiceType.LLM,
|
||||
"tts": ServiceType.TTS,
|
||||
|
|
@ -36,7 +36,7 @@ _SECRET_FIELDS = ("api_key", "credentials", "aws_access_key", "aws_secret_key")
|
|||
|
||||
def enrich_overrides_with_api_keys(
|
||||
model_overrides: dict,
|
||||
user_config: UserConfiguration,
|
||||
user_config: EffectiveAIModelConfiguration,
|
||||
) -> dict:
|
||||
"""Copy API keys from the global config into model_overrides where missing.
|
||||
|
||||
|
|
@ -74,9 +74,9 @@ def enrich_overrides_with_api_keys(
|
|||
|
||||
|
||||
def resolve_effective_config(
|
||||
user_config: UserConfiguration,
|
||||
user_config: EffectiveAIModelConfiguration,
|
||||
model_overrides: dict | None,
|
||||
) -> UserConfiguration:
|
||||
) -> EffectiveAIModelConfiguration:
|
||||
"""Deep-merge workflow model_overrides onto global user config.
|
||||
|
||||
- If model_overrides is None or empty, returns a copy of user_config unchanged.
|
||||
|
|
|
|||
|
|
@ -38,6 +38,7 @@ class OpenAIEmbeddingService(BaseEmbeddingService):
|
|||
api_key: Optional[str] = None,
|
||||
model_id: str = DEFAULT_MODEL_ID,
|
||||
base_url: Optional[str] = None,
|
||||
default_headers: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
"""Initialize the OpenAI embedding service.
|
||||
|
||||
|
|
@ -60,6 +61,8 @@ class OpenAIEmbeddingService(BaseEmbeddingService):
|
|||
field_name="base_url",
|
||||
)
|
||||
client_kwargs["base_url"] = base_url
|
||||
if default_headers:
|
||||
client_kwargs["default_headers"] = default_headers
|
||||
self.client = AsyncOpenAI(**client_kwargs)
|
||||
logger.info(f"OpenAI embedding service initialized with model: {model_id}")
|
||||
else:
|
||||
|
|
|
|||
78
api/services/managed_model_services.py
Normal file
78
api/services/managed_model_services.py
Normal file
|
|
@ -0,0 +1,78 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
|
||||
MPS_CORRELATION_ID_CONTEXT_KEY = "mps_correlation_id"
|
||||
|
||||
|
||||
def uses_managed_model_services_v2(
|
||||
ai_model_config: EffectiveAIModelConfiguration | None,
|
||||
) -> bool:
|
||||
if (
|
||||
ai_model_config is None
|
||||
or getattr(ai_model_config, "managed_service_version", None) != 2
|
||||
):
|
||||
return False
|
||||
|
||||
return any(
|
||||
_is_dograh_service(getattr(ai_model_config, section_name, None))
|
||||
for section_name in ("llm", "tts", "stt", "embeddings")
|
||||
)
|
||||
|
||||
|
||||
def get_mps_correlation_id(initial_context: dict[str, Any] | None) -> str | None:
|
||||
if not initial_context:
|
||||
return None
|
||||
correlation_id = initial_context.get(MPS_CORRELATION_ID_CONTEXT_KEY)
|
||||
if correlation_id is None:
|
||||
return None
|
||||
return str(correlation_id)
|
||||
|
||||
|
||||
async def ensure_mps_correlation_id(
|
||||
*,
|
||||
ai_model_config: EffectiveAIModelConfiguration,
|
||||
workflow_run_id: int,
|
||||
initial_context: dict[str, Any] | None,
|
||||
) -> str | None:
|
||||
existing = get_mps_correlation_id(initial_context)
|
||||
if existing:
|
||||
return existing
|
||||
|
||||
if not uses_managed_model_services_v2(ai_model_config):
|
||||
return None
|
||||
|
||||
raise ValueError(
|
||||
"Managed model services v2 requires workflow run authorization before "
|
||||
f"the run starts. Missing correlation id for workflow_run_id={workflow_run_id}."
|
||||
)
|
||||
|
||||
|
||||
def _is_dograh_service(service: Any) -> bool:
|
||||
provider = getattr(service, "provider", None)
|
||||
return (
|
||||
provider == ServiceProviders.DOGRAH or provider == ServiceProviders.DOGRAH.value
|
||||
)
|
||||
|
||||
|
||||
def get_dograh_service_api_key(
|
||||
ai_model_config: EffectiveAIModelConfiguration,
|
||||
) -> str | None:
|
||||
for section_name in ("llm", "tts", "stt", "embeddings"):
|
||||
service = getattr(ai_model_config, section_name, None)
|
||||
if not _is_dograh_service(service):
|
||||
continue
|
||||
|
||||
if hasattr(service, "get_all_api_keys"):
|
||||
keys = service.get_all_api_keys()
|
||||
if keys:
|
||||
return keys[0]
|
||||
|
||||
api_key = getattr(service, "api_key", None)
|
||||
if isinstance(api_key, str) and api_key:
|
||||
return api_key
|
||||
|
||||
return None
|
||||
23
api/services/mps_billing.py
Normal file
23
api/services/mps_billing.py
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
from typing import Optional
|
||||
|
||||
from api.constants import DEPLOYMENT_MODE
|
||||
from api.services.mps_service_key_client import mps_service_key_client
|
||||
|
||||
|
||||
async def ensure_hosted_mps_billing_account_v2(
|
||||
organization_id: int,
|
||||
*,
|
||||
created_by: Optional[str] = None,
|
||||
) -> Optional[dict]:
|
||||
"""Ensure hosted orgs have an MPS billing v2 account.
|
||||
|
||||
OSS deployments use legacy per-key quota accounting and do not create MPS
|
||||
billing accounts.
|
||||
"""
|
||||
if DEPLOYMENT_MODE == "oss":
|
||||
return None
|
||||
|
||||
return await mps_service_key_client.ensure_billing_account_v2(
|
||||
organization_id=organization_id,
|
||||
created_by=created_by,
|
||||
)
|
||||
|
|
@ -4,6 +4,7 @@ This client communicates with the Model Proxy Service (MPS) for service key mana
|
|||
Service keys are stored and managed entirely in MPS, not in the local database.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import List, Optional
|
||||
|
||||
import httpx
|
||||
|
|
@ -353,6 +354,278 @@ class MPSServiceKeyClient:
|
|||
response=response,
|
||||
)
|
||||
|
||||
async def create_credit_purchase_url(
|
||||
self,
|
||||
organization_id: int,
|
||||
created_by: Optional[str] = None,
|
||||
return_url: Optional[str] = None,
|
||||
billing_details: Optional[dict] = None,
|
||||
) -> dict:
|
||||
"""Create a short-lived MPS checkout URL for adding organization credits."""
|
||||
payload = {
|
||||
"created_by": created_by,
|
||||
"return_url": return_url,
|
||||
"billing_details": billing_details or {},
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
response = await client.post(
|
||||
f"{self.base_url}/api/v1/billing/accounts/{organization_id}/checkout-sessions",
|
||||
json=payload,
|
||||
headers=self._get_headers(
|
||||
organization_id=organization_id,
|
||||
created_by=created_by,
|
||||
),
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
|
||||
logger.error(
|
||||
"Failed to create MPS credit purchase URL: "
|
||||
f"{response.status_code} - {response.text}"
|
||||
)
|
||||
raise httpx.HTTPStatusError(
|
||||
f"Failed to create MPS credit purchase URL: {response.text}",
|
||||
request=response.request,
|
||||
response=response,
|
||||
)
|
||||
|
||||
async def get_credit_ledger(
|
||||
self,
|
||||
organization_id: int,
|
||||
page: int = 1,
|
||||
limit: int = 50,
|
||||
created_by: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""Get the MPS v2 billing account balance and recent credit ledger."""
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
response = await client.get(
|
||||
f"{self.base_url}/api/v1/billing/accounts/{organization_id}/ledger",
|
||||
params={"page": page, "limit": limit},
|
||||
headers=self._get_headers(
|
||||
organization_id=organization_id,
|
||||
created_by=created_by,
|
||||
),
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
|
||||
logger.error(
|
||||
"Failed to get MPS credit ledger: "
|
||||
f"{response.status_code} - {response.text}"
|
||||
)
|
||||
raise httpx.HTTPStatusError(
|
||||
f"Failed to get MPS credit ledger: {response.text}",
|
||||
request=response.request,
|
||||
response=response,
|
||||
)
|
||||
|
||||
async def get_billing_account_status(
|
||||
self,
|
||||
organization_id: int,
|
||||
created_by: Optional[str] = None,
|
||||
) -> Optional[dict]:
|
||||
"""Get an existing MPS v2 billing account without creating one."""
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
response = await client.get(
|
||||
f"{self.base_url}/api/v1/billing/accounts/{organization_id}/status",
|
||||
headers=self._get_headers(
|
||||
organization_id=organization_id,
|
||||
created_by=created_by,
|
||||
),
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
|
||||
logger.error(
|
||||
"Failed to get MPS billing account status: "
|
||||
f"{response.status_code} - {response.text}"
|
||||
)
|
||||
raise httpx.HTTPStatusError(
|
||||
f"Failed to get MPS billing account status: {response.text}",
|
||||
request=response.request,
|
||||
response=response,
|
||||
)
|
||||
|
||||
async def ensure_billing_account_v2(
|
||||
self,
|
||||
organization_id: int,
|
||||
created_by: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""Create or return the MPS v2 billing account for an organization."""
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
response = await client.get(
|
||||
f"{self.base_url}/api/v1/billing/accounts/{organization_id}/balance",
|
||||
headers=self._get_headers(
|
||||
organization_id=organization_id,
|
||||
created_by=created_by,
|
||||
),
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
|
||||
logger.error(
|
||||
"Failed to ensure MPS billing account v2: "
|
||||
f"{response.status_code} - {response.text}"
|
||||
)
|
||||
raise httpx.HTTPStatusError(
|
||||
f"Failed to ensure MPS billing account v2: {response.text}",
|
||||
request=response.request,
|
||||
response=response,
|
||||
)
|
||||
|
||||
async def authorize_workflow_run_start(
|
||||
self,
|
||||
*,
|
||||
organization_id: int,
|
||||
workflow_run_id: int | None = None,
|
||||
service_key: Optional[str] = None,
|
||||
require_correlation_id: bool = False,
|
||||
minimum_credits: float | None = None,
|
||||
metadata: Optional[dict] = None,
|
||||
created_by: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""Authorize a hosted workflow run and optionally mint its MPS correlation."""
|
||||
payload = {
|
||||
"workflow_run_id": workflow_run_id,
|
||||
"service_key": service_key,
|
||||
"require_correlation_id": require_correlation_id,
|
||||
"metadata": metadata or {},
|
||||
}
|
||||
if minimum_credits is not None:
|
||||
payload["minimum_credits"] = minimum_credits
|
||||
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
response = await client.post(
|
||||
f"{self.base_url}/api/v1/billing/accounts/{organization_id}/run-authorization",
|
||||
json=payload,
|
||||
headers=self._get_headers(
|
||||
organization_id=organization_id,
|
||||
created_by=created_by,
|
||||
),
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
|
||||
logger.error(
|
||||
"Failed to authorize MPS workflow run start: "
|
||||
f"{response.status_code} - {response.text}"
|
||||
)
|
||||
raise httpx.HTTPStatusError(
|
||||
f"Failed to authorize MPS workflow run start: {response.text}",
|
||||
request=response.request,
|
||||
response=response,
|
||||
)
|
||||
|
||||
async def create_correlation_id(
|
||||
self,
|
||||
*,
|
||||
service_key: str,
|
||||
workflow_run_id: int | None = None,
|
||||
) -> dict:
|
||||
"""Mint a server-generated correlation ID for managed model services."""
|
||||
payload: dict[str, int] = {}
|
||||
if workflow_run_id is not None:
|
||||
payload["workflow_run_id"] = workflow_run_id
|
||||
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
response = await client.post(
|
||||
f"{self.base_url}/api/v1/service-keys/correlation-id/self",
|
||||
json=payload,
|
||||
headers={
|
||||
"Authorization": f"Bearer {service_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
|
||||
logger.error(
|
||||
"Failed to create correlation ID: "
|
||||
f"{response.status_code} - {response.text}"
|
||||
)
|
||||
raise httpx.HTTPStatusError(
|
||||
f"Failed to create correlation ID: {response.text}",
|
||||
request=response.request,
|
||||
response=response,
|
||||
)
|
||||
|
||||
async def report_platform_usage(
|
||||
self,
|
||||
*,
|
||||
organization_id: int,
|
||||
correlation_id: Optional[str] = None,
|
||||
duration_seconds: Optional[float] = None,
|
||||
workflow_run_id: int | None = None,
|
||||
metadata: Optional[dict] = None,
|
||||
max_attempts: int = 3,
|
||||
) -> dict:
|
||||
"""Report hosted Dograh platform usage for a completed workflow run."""
|
||||
if DEPLOYMENT_MODE == "oss":
|
||||
raise ValueError("OSS deployments must not report platform usage to MPS")
|
||||
if not correlation_id and duration_seconds is None:
|
||||
raise ValueError(
|
||||
"Platform usage reports require correlation_id or duration_seconds"
|
||||
)
|
||||
|
||||
payload: dict = {
|
||||
"metadata": metadata or {},
|
||||
}
|
||||
if correlation_id:
|
||||
payload["correlation_id"] = correlation_id
|
||||
if duration_seconds is not None:
|
||||
payload["duration_seconds"] = duration_seconds
|
||||
if workflow_run_id is not None:
|
||||
payload["workflow_run_id"] = workflow_run_id
|
||||
|
||||
max_attempts = max(1, max_attempts)
|
||||
last_response: httpx.Response | None = None
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
for attempt in range(1, max_attempts + 1):
|
||||
response = await client.post(
|
||||
(
|
||||
f"{self.base_url}/api/v1/billing/accounts/"
|
||||
f"{organization_id}/platform-usage"
|
||||
),
|
||||
json=payload,
|
||||
headers=self._get_headers(organization_id=organization_id),
|
||||
)
|
||||
last_response = response
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
|
||||
should_retry = (
|
||||
response.status_code == 409
|
||||
and "usage_not_ready" in response.text
|
||||
and attempt < max_attempts
|
||||
)
|
||||
if should_retry:
|
||||
await asyncio.sleep(attempt)
|
||||
continue
|
||||
|
||||
logger.error(
|
||||
"Failed to report platform usage: "
|
||||
f"{response.status_code} - {response.text}"
|
||||
)
|
||||
raise httpx.HTTPStatusError(
|
||||
f"Failed to report platform usage: {response.text}",
|
||||
request=response.request,
|
||||
response=response,
|
||||
)
|
||||
|
||||
raise httpx.HTTPStatusError(
|
||||
"Failed to report platform usage",
|
||||
request=last_response.request,
|
||||
response=last_response,
|
||||
)
|
||||
|
||||
async def transcribe_audio(
|
||||
self,
|
||||
audio_data: bytes,
|
||||
|
|
|
|||
50
api/services/organization_context.py
Normal file
50
api/services/organization_context.py
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from api.db import db_client
|
||||
from api.db.models import UserModel
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
get_resolved_ai_model_configuration,
|
||||
)
|
||||
|
||||
|
||||
class OrganizationModelServicesContext(BaseModel):
|
||||
config_source: Literal["organization_v2", "legacy_user_v1", "empty"]
|
||||
has_model_configuration_v2: bool
|
||||
managed_service_version: Optional[int] = None
|
||||
uses_managed_service_v2: bool
|
||||
|
||||
|
||||
class OrganizationContextResponse(BaseModel):
|
||||
organization_id: Optional[int] = None
|
||||
organization_provider_id: Optional[str] = None
|
||||
model_services: OrganizationModelServicesContext
|
||||
|
||||
|
||||
async def get_organization_context(user: UserModel) -> OrganizationContextResponse:
|
||||
organization_id = user.selected_organization_id
|
||||
organization = (
|
||||
await db_client.get_organization_by_id(organization_id)
|
||||
if organization_id
|
||||
else None
|
||||
)
|
||||
|
||||
resolved = await get_resolved_ai_model_configuration(
|
||||
user_id=user.id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
managed_service_version = resolved.effective.managed_service_version
|
||||
|
||||
return OrganizationContextResponse(
|
||||
organization_id=organization_id,
|
||||
organization_provider_id=organization.provider_id if organization else None,
|
||||
model_services=OrganizationModelServicesContext(
|
||||
config_source=resolved.source,
|
||||
has_model_configuration_v2=resolved.source == "organization_v2",
|
||||
managed_service_version=managed_service_version,
|
||||
uses_managed_service_v2=(
|
||||
resolved.source == "organization_v2" and managed_service_version == 2
|
||||
),
|
||||
),
|
||||
)
|
||||
62
api/services/organization_preferences.py
Normal file
62
api/services/organization_preferences.py
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
from inspect import isawaitable
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import ValidationError
|
||||
|
||||
from api.db import db_client
|
||||
from api.enums import OrganizationConfigurationKey
|
||||
from api.schemas.organization_preferences import OrganizationPreferences
|
||||
|
||||
|
||||
async def get_organization_preferences(
|
||||
organization_id: int | None,
|
||||
db=None,
|
||||
) -> OrganizationPreferences:
|
||||
if organization_id is None:
|
||||
return OrganizationPreferences()
|
||||
|
||||
db = db or db_client
|
||||
row = await _get_configuration(
|
||||
db,
|
||||
organization_id,
|
||||
OrganizationConfigurationKey.ORGANIZATION_PREFERENCES.value,
|
||||
)
|
||||
if row is None:
|
||||
row = await _get_configuration(
|
||||
db,
|
||||
organization_id,
|
||||
OrganizationConfigurationKey.MODEL_CONFIGURATION_PREFERENCES.value,
|
||||
)
|
||||
return _parse_preferences(row.value if row is not None else None, organization_id)
|
||||
|
||||
|
||||
async def upsert_organization_preferences(
|
||||
organization_id: int,
|
||||
preferences: OrganizationPreferences,
|
||||
) -> OrganizationPreferences:
|
||||
await db_client.upsert_configuration(
|
||||
organization_id,
|
||||
OrganizationConfigurationKey.ORGANIZATION_PREFERENCES.value,
|
||||
preferences.model_dump(mode="json", exclude_none=True),
|
||||
)
|
||||
return preferences
|
||||
|
||||
|
||||
async def _get_configuration(db, organization_id: int, key: str):
|
||||
row = db.get_configuration(organization_id, key)
|
||||
if isawaitable(row):
|
||||
row = await row
|
||||
return row
|
||||
|
||||
|
||||
def _parse_preferences(value, organization_id: int) -> OrganizationPreferences:
|
||||
if not value or not isinstance(value, dict):
|
||||
return OrganizationPreferences()
|
||||
try:
|
||||
return OrganizationPreferences.model_validate(value)
|
||||
except ValidationError as exc:
|
||||
logger.warning(
|
||||
"Invalid organization preferences for organization "
|
||||
f"{organization_id}: {exc}. Returning defaults."
|
||||
)
|
||||
return OrganizationPreferences()
|
||||
|
|
@ -15,6 +15,29 @@ from api.utils.credential_auth import build_auth_header
|
|||
PRE_CALL_FETCH_TIMEOUT_SECONDS = 10
|
||||
|
||||
|
||||
def _extract_initial_context(response_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Pull the context variables out of a pre-call fetch response.
|
||||
|
||||
The canonical key is ``initial_context``. The legacy ``dynamic_variables``
|
||||
key is still accepted for backward compatibility, so existing endpoints
|
||||
keep working; ``initial_context`` takes precedence when both are present.
|
||||
|
||||
Either key may appear at the top level or nested under ``call_inbound``:
|
||||
{"call_inbound": {"initial_context": {...}}} | {"initial_context": {...}}
|
||||
{"call_inbound": {"dynamic_variables": {...}}} | {"dynamic_variables": {...}}
|
||||
"""
|
||||
container = response_data.get("call_inbound")
|
||||
if not isinstance(container, dict):
|
||||
container = response_data
|
||||
|
||||
for key in ("initial_context", "dynamic_variables"):
|
||||
value = container.get(key)
|
||||
if isinstance(value, dict):
|
||||
return value
|
||||
|
||||
return {}
|
||||
|
||||
|
||||
async def execute_pre_call_fetch(
|
||||
*,
|
||||
url: str,
|
||||
|
|
@ -77,24 +100,16 @@ async def execute_pre_call_fetch(
|
|||
)
|
||||
return {}
|
||||
|
||||
# Extract dynamic_variables from Retell-compatible response
|
||||
# Supports: {call_inbound: {dynamic_variables: {...}}}
|
||||
# or: {dynamic_variables: {...}}
|
||||
dynamic_vars = {}
|
||||
call_inbound = response_data.get("call_inbound")
|
||||
if isinstance(call_inbound, dict):
|
||||
dynamic_vars = call_inbound.get("dynamic_variables", {})
|
||||
elif "dynamic_variables" in response_data:
|
||||
dynamic_vars = response_data["dynamic_variables"]
|
||||
|
||||
if not isinstance(dynamic_vars, dict):
|
||||
dynamic_vars = {}
|
||||
# Extract the variables to merge into initial_context. Prefers
|
||||
# the canonical `initial_context` key, falling back to the
|
||||
# legacy `dynamic_variables` key for backward compatibility.
|
||||
initial_context_vars = _extract_initial_context(response_data)
|
||||
|
||||
logger.info(
|
||||
f"Pre-call fetch: success ({response.status_code}), "
|
||||
f"dynamic_variables keys: {list(dynamic_vars.keys())}"
|
||||
f"initial_context keys: {list(initial_context_vars.keys())}"
|
||||
)
|
||||
return dynamic_vars
|
||||
return initial_context_vars
|
||||
else:
|
||||
logger.warning(
|
||||
f"Pre-call fetch: HTTP {response.status_code} - "
|
||||
|
|
|
|||
|
|
@ -162,15 +162,13 @@ async def run_pipeline_telephony(
|
|||
workflow_id: Workflow being executed.
|
||||
workflow_run_id: Workflow run row.
|
||||
user_id: Owner of the workflow.
|
||||
call_id: Provider call identifier (stored in cost_info for billing).
|
||||
call_id: Provider call identifier.
|
||||
transport_kwargs: Provider-specific kwargs forwarded to the transport
|
||||
factory (e.g. stream_sid + call_sid for Twilio).
|
||||
"""
|
||||
logger.debug(f"Running {provider_name} pipeline for workflow_run {workflow_run_id}")
|
||||
set_current_run_id(workflow_run_id)
|
||||
|
||||
await db_client.update_workflow_run(workflow_run_id, cost_info={"call_id": call_id})
|
||||
|
||||
workflow = await db_client.get_workflow(workflow_id, user_id)
|
||||
if workflow:
|
||||
set_current_org_id(workflow.organization_id)
|
||||
|
|
@ -195,14 +193,17 @@ async def run_pipeline_telephony(
|
|||
# Resolve effective user config here so the transport can tune its
|
||||
# bot-stopped-speaking fallback based on is_realtime; pass the resolved
|
||||
# values into _run_pipeline so it doesn't fetch them again.
|
||||
from api.services.configuration.resolve import resolve_effective_config
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
get_effective_ai_model_configuration_for_workflow,
|
||||
)
|
||||
|
||||
user_config = await db_client.get_user_configurations(user_id)
|
||||
run_configs = (
|
||||
(workflow_run.definition.workflow_configurations or {}) if workflow_run else {}
|
||||
)
|
||||
user_config = resolve_effective_config(
|
||||
user_config, run_configs.get("model_overrides")
|
||||
user_config = await get_effective_ai_model_configuration_for_workflow(
|
||||
user_id=user_id,
|
||||
organization_id=workflow.organization_id if workflow else None,
|
||||
workflow_configurations=run_configs,
|
||||
)
|
||||
is_realtime = bool(user_config.is_realtime and user_config.realtime is not None)
|
||||
|
||||
|
|
@ -272,15 +273,18 @@ async def run_pipeline_smallwebrtc(
|
|||
# Resolve workflow_run + effective user_config here so the transport can
|
||||
# tune its bot-stopped-speaking fallback based on is_realtime. _run_pipeline
|
||||
# reuses these via kwargs so we don't fetch twice.
|
||||
from api.services.configuration.resolve import resolve_effective_config
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
get_effective_ai_model_configuration_for_workflow,
|
||||
)
|
||||
|
||||
workflow_run = await db_client.get_workflow_run(workflow_run_id, user_id)
|
||||
user_config = await db_client.get_user_configurations(user_id)
|
||||
run_configs = (
|
||||
(workflow_run.definition.workflow_configurations or {}) if workflow_run else {}
|
||||
)
|
||||
user_config = resolve_effective_config(
|
||||
user_config, run_configs.get("model_overrides")
|
||||
user_config = await get_effective_ai_model_configuration_for_workflow(
|
||||
user_id=user_id,
|
||||
organization_id=workflow.organization_id if workflow else None,
|
||||
workflow_configurations=run_configs,
|
||||
)
|
||||
is_realtime = bool(user_config.is_realtime and user_config.realtime is not None)
|
||||
|
||||
|
|
@ -334,7 +338,7 @@ async def _run_pipeline(
|
|||
if workflow_run.is_completed:
|
||||
raise HTTPException(status_code=400, detail="Workflow run already completed")
|
||||
|
||||
merged_call_context_vars = workflow_run.initial_context
|
||||
merged_call_context_vars = dict(workflow_run.initial_context or {})
|
||||
# If there is some extra call_context_vars, fold them in. Persistence
|
||||
# happens once below, after runtime_configuration is also resolved.
|
||||
if call_context_vars:
|
||||
|
|
@ -380,15 +384,31 @@ async def _run_pipeline(
|
|||
# Resolve model overrides from the version onto global user config (skip
|
||||
# when the caller already resolved it).
|
||||
if resolved_user_config is None:
|
||||
from api.services.configuration.resolve import resolve_effective_config
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
get_effective_ai_model_configuration_for_workflow,
|
||||
)
|
||||
|
||||
user_config = await db_client.get_user_configurations(user_id)
|
||||
user_config = resolve_effective_config(
|
||||
user_config, run_configs.get("model_overrides")
|
||||
user_config = await get_effective_ai_model_configuration_for_workflow(
|
||||
user_id=user_id,
|
||||
organization_id=workflow.organization_id,
|
||||
workflow_configurations=run_configs,
|
||||
)
|
||||
else:
|
||||
user_config = resolved_user_config
|
||||
|
||||
from api.services.managed_model_services import (
|
||||
MPS_CORRELATION_ID_CONTEXT_KEY,
|
||||
ensure_mps_correlation_id,
|
||||
)
|
||||
|
||||
mps_correlation_id = await ensure_mps_correlation_id(
|
||||
ai_model_config=user_config,
|
||||
workflow_run_id=workflow_run_id,
|
||||
initial_context=merged_call_context_vars,
|
||||
)
|
||||
if mps_correlation_id:
|
||||
merged_call_context_vars[MPS_CORRELATION_ID_CONTEXT_KEY] = mps_correlation_id
|
||||
|
||||
# Detect realtime mode (speech-to-speech services like OpenAI Realtime, Gemini Live)
|
||||
is_realtime = user_config.is_realtime and user_config.realtime is not None
|
||||
|
||||
|
|
@ -400,11 +420,23 @@ async def _run_pipeline(
|
|||
# Realtime services don't implement run_inference, so create a
|
||||
# separate text LLM for variable extraction and other out-of-band
|
||||
# inference calls.
|
||||
inference_llm = create_llm_service(user_config)
|
||||
inference_llm = create_llm_service(
|
||||
user_config,
|
||||
correlation_id=mps_correlation_id,
|
||||
)
|
||||
else:
|
||||
stt = create_stt_service(user_config, audio_config, keyterms=keyterms)
|
||||
tts = create_tts_service(user_config, audio_config)
|
||||
llm = create_llm_service(user_config)
|
||||
stt = create_stt_service(
|
||||
user_config,
|
||||
audio_config,
|
||||
keyterms=keyterms,
|
||||
correlation_id=mps_correlation_id,
|
||||
)
|
||||
tts = create_tts_service(
|
||||
user_config,
|
||||
audio_config,
|
||||
correlation_id=mps_correlation_id,
|
||||
)
|
||||
llm = create_llm_service(user_config, correlation_id=mps_correlation_id)
|
||||
inference_llm = None
|
||||
|
||||
# Stamp the providers/models actually resolved for this run onto
|
||||
|
|
@ -508,10 +540,17 @@ async def _run_pipeline(
|
|||
embeddings_endpoint = None
|
||||
embeddings_api_version = None
|
||||
if user_config and user_config.embeddings:
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
apply_managed_embeddings_base_url,
|
||||
)
|
||||
|
||||
embeddings_api_key = user_config.embeddings.api_key
|
||||
embeddings_model = user_config.embeddings.model
|
||||
embeddings_provider = getattr(user_config.embeddings, "provider", None)
|
||||
embeddings_base_url = getattr(user_config.embeddings, "base_url", None)
|
||||
embeddings_base_url = apply_managed_embeddings_base_url(
|
||||
provider=embeddings_provider,
|
||||
base_url=getattr(user_config.embeddings, "base_url", None),
|
||||
)
|
||||
embeddings_endpoint = getattr(user_config.embeddings, "endpoint", None)
|
||||
embeddings_api_version = getattr(user_config.embeddings, "api_version", None)
|
||||
|
||||
|
|
@ -679,7 +718,10 @@ async def _run_pipeline(
|
|||
# Create a separate LLM instance for the voicemail sub-pipeline
|
||||
# (can't share with main pipeline as it would mess up frame linking)
|
||||
if voicemail_config.get("use_workflow_llm", True):
|
||||
voicemail_llm = create_llm_service(user_config)
|
||||
voicemail_llm = create_llm_service(
|
||||
user_config,
|
||||
correlation_id=mps_correlation_id,
|
||||
)
|
||||
else:
|
||||
voicemail_llm = create_llm_service_from_provider(
|
||||
provider=voicemail_config.get("provider", "openai"),
|
||||
|
|
|
|||
|
|
@ -78,7 +78,10 @@ def _validate_runtime_service_url(url: str, field_name: str) -> None:
|
|||
|
||||
|
||||
def create_stt_service(
|
||||
user_config, audio_config: "AudioConfig", keyterms: list[str] | None = None
|
||||
user_config,
|
||||
audio_config: "AudioConfig",
|
||||
keyterms: list[str] | None = None,
|
||||
correlation_id: str | None = None,
|
||||
):
|
||||
"""Create and return appropriate STT service based on user configuration
|
||||
|
||||
|
|
@ -160,6 +163,7 @@ def create_stt_service(
|
|||
return DograhSTTService(
|
||||
base_url=base_url,
|
||||
api_key=user_config.stt.api_key,
|
||||
correlation_id=correlation_id,
|
||||
settings=DograhSTTSettings(
|
||||
model=user_config.stt.model,
|
||||
language=language,
|
||||
|
|
@ -286,7 +290,9 @@ def create_stt_service(
|
|||
)
|
||||
|
||||
|
||||
def create_tts_service(user_config, audio_config: "AudioConfig"):
|
||||
def create_tts_service(
|
||||
user_config, audio_config: "AudioConfig", correlation_id: str | None = None
|
||||
):
|
||||
"""Create and return appropriate TTS service based on user configuration
|
||||
|
||||
Args:
|
||||
|
|
@ -404,6 +410,7 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
|
|||
return DograhTTSService(
|
||||
base_url=base_url,
|
||||
api_key=user_config.tts.api_key,
|
||||
correlation_id=correlation_id,
|
||||
settings=DograhTTSSettings(
|
||||
model=user_config.tts.model,
|
||||
voice=user_config.tts.voice,
|
||||
|
|
@ -564,6 +571,7 @@ def create_llm_service_from_provider(
|
|||
model: str,
|
||||
api_key: str | None,
|
||||
*,
|
||||
correlation_id: str | None = None,
|
||||
base_url: str | None = None,
|
||||
endpoint: str | None = None,
|
||||
aws_access_key: str | None = None,
|
||||
|
|
@ -637,6 +645,7 @@ def create_llm_service_from_provider(
|
|||
return DograhLLMService(
|
||||
base_url=f"{MPS_API_URL}/api/v1/llm",
|
||||
api_key=api_key,
|
||||
correlation_id=correlation_id,
|
||||
settings=OpenAILLMSettings(model=model),
|
||||
)
|
||||
elif provider == ServiceProviders.AWS_BEDROCK.value:
|
||||
|
|
@ -851,7 +860,7 @@ def create_realtime_llm_service(user_config, audio_config: "AudioConfig"):
|
|||
)
|
||||
|
||||
|
||||
def create_llm_service(user_config):
|
||||
def create_llm_service(user_config, correlation_id: str | None = None):
|
||||
"""Create and return appropriate LLM service based on user configuration."""
|
||||
provider = user_config.llm.provider
|
||||
model = user_config.llm.model
|
||||
|
|
@ -880,4 +889,10 @@ def create_llm_service(user_config):
|
|||
elif provider == ServiceProviders.SARVAM.value:
|
||||
kwargs["temperature"] = user_config.llm.temperature
|
||||
|
||||
return create_llm_service_from_provider(provider, model, api_key, **kwargs)
|
||||
return create_llm_service_from_provider(
|
||||
provider,
|
||||
model,
|
||||
api_key,
|
||||
correlation_id=correlation_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,76 +0,0 @@
|
|||
# Pricing Module
|
||||
|
||||
This module contains pricing models and registries for different AI services used in workflow cost calculations.
|
||||
|
||||
## Structure
|
||||
|
||||
```
|
||||
pricing/
|
||||
├── __init__.py # Main module exports
|
||||
├── models.py # Base pricing model classes
|
||||
├── llm.py # LLM pricing configurations
|
||||
├── tts.py # TTS pricing configurations
|
||||
├── stt.py # STT pricing configurations
|
||||
├── registry.py # Combined pricing registry
|
||||
└── README.md # This file
|
||||
```
|
||||
|
||||
## Pricing Models
|
||||
|
||||
### TokenPricingModel
|
||||
Used for LLM services that charge based on tokens:
|
||||
- `prompt_token_price`: Cost per prompt token
|
||||
- `completion_token_price`: Cost per completion token
|
||||
- `cache_read_discount`: Discount for cache read tokens (default 50%)
|
||||
- `cache_creation_multiplier`: Premium for cache creation tokens (default 25%)
|
||||
|
||||
### CharacterPricingModel
|
||||
Used for TTS services that charge based on character count:
|
||||
- `character_price`: Cost per character
|
||||
|
||||
### TimePricingModel
|
||||
Used for STT services that charge based on time:
|
||||
- `second_price`: Cost per second
|
||||
|
||||
## Adding New Pricing
|
||||
|
||||
### Adding a New LLM Model
|
||||
Edit `llm.py` and add the model to the appropriate provider:
|
||||
|
||||
```python
|
||||
ServiceProviders.OPENAI: {
|
||||
"new-model": TokenPricingModel(
|
||||
prompt_token_price=Decimal("2.00") / 1000000,
|
||||
completion_token_price=Decimal("8.00") / 1000000,
|
||||
),
|
||||
# ... existing models
|
||||
}
|
||||
```
|
||||
|
||||
### Adding a New Provider
|
||||
1. Add pricing configurations to the appropriate service file (llm.py, tts.py, stt.py)
|
||||
2. The registry will automatically include them
|
||||
|
||||
### Adding a New Service Type
|
||||
1. Create a new pricing file (e.g., `image.py`)
|
||||
2. Define the pricing models
|
||||
3. Import and add to `registry.py`
|
||||
|
||||
## Usage
|
||||
|
||||
The pricing registry is automatically imported and used by the cost calculator:
|
||||
|
||||
```python
|
||||
from api.services.pricing import PRICING_REGISTRY
|
||||
from api.services.workflow.cost_calculator import cost_calculator
|
||||
|
||||
# The cost calculator uses the pricing registry automatically
|
||||
result = cost_calculator.calculate_total_cost(usage_info)
|
||||
```
|
||||
|
||||
## Maintenance
|
||||
|
||||
- Update pricing when providers change their rates
|
||||
- All prices should use `Decimal` for precision
|
||||
- Include comments with current pricing from provider documentation
|
||||
- Test changes with existing test suite
|
||||
|
|
@ -1,9 +0,0 @@
|
|||
"""
|
||||
Pricing module for workflow cost calculation.
|
||||
|
||||
This module contains pricing models and registries for different AI services.
|
||||
"""
|
||||
|
||||
from .registry import PRICING_REGISTRY
|
||||
|
||||
__all__ = ["PRICING_REGISTRY"]
|
||||
|
|
@ -1,228 +0,0 @@
|
|||
"""
|
||||
Cost Calculator for Workflow Runs
|
||||
|
||||
This module provides a comprehensive cost calculation system for workflow runs based on usage metrics
|
||||
from different AI service providers (OpenAI, Groq, Deepgram, etc.).
|
||||
|
||||
Features:
|
||||
- Token-based pricing for LLM services with cache optimization support
|
||||
- Character-based pricing for TTS services
|
||||
- Time-based pricing for STT services
|
||||
- Configurable pricing models that can be updated
|
||||
- Support for multiple providers and models
|
||||
- Automatic provider inference from model names
|
||||
- JSON serialization support for database storage
|
||||
|
||||
Usage:
|
||||
from api.tasks.cost_calculator import cost_calculator
|
||||
|
||||
usage_info = {
|
||||
"llm": {
|
||||
"processor_name|||gpt-4o": {
|
||||
"prompt_tokens": 1000,
|
||||
"completion_tokens": 500,
|
||||
"total_tokens": 1500,
|
||||
"cache_read_input_tokens": 0,
|
||||
"cache_creation_input_tokens": 0
|
||||
}
|
||||
},
|
||||
"tts": {
|
||||
"processor_name|||aura-2-helena-en": 2000 # character count
|
||||
}
|
||||
}
|
||||
|
||||
cost_breakdown = cost_calculator.calculate_total_cost(usage_info)
|
||||
print(f"Total cost: ${cost_breakdown['total']:.6f}")
|
||||
"""
|
||||
|
||||
from decimal import Decimal
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from api.services.pricing import PRICING_REGISTRY
|
||||
from api.services.pricing.models import (
|
||||
PricingModel,
|
||||
)
|
||||
|
||||
|
||||
class CostCalculator:
|
||||
"""Main cost calculator class"""
|
||||
|
||||
def __init__(self, pricing_registry: Dict = None):
|
||||
self.pricing_registry = pricing_registry or PRICING_REGISTRY
|
||||
|
||||
def get_pricing_model(
|
||||
self, service_type: str, provider: str, model: str
|
||||
) -> Optional[PricingModel]:
|
||||
"""Get pricing model for a specific service, provider, and model"""
|
||||
try:
|
||||
service_pricing = self.pricing_registry.get(service_type, {})
|
||||
|
||||
# Try to get pricing for the specific provider
|
||||
provider_pricing = service_pricing.get(provider, {})
|
||||
pricing_model = provider_pricing.get(model) or provider_pricing.get(
|
||||
"default"
|
||||
)
|
||||
|
||||
if pricing_model:
|
||||
return pricing_model
|
||||
|
||||
# If not found, try the "default" provider for this service type
|
||||
default_provider_pricing = service_pricing.get("default", {})
|
||||
return default_provider_pricing.get(model) or default_provider_pricing.get(
|
||||
"default"
|
||||
)
|
||||
|
||||
except (KeyError, AttributeError):
|
||||
return None
|
||||
|
||||
def calculate_llm_cost(
|
||||
self, provider: str, model: str, usage: Dict[str, int]
|
||||
) -> Decimal:
|
||||
"""Calculate cost for LLM usage"""
|
||||
pricing_model = self.get_pricing_model("llm", provider, model)
|
||||
if not pricing_model:
|
||||
return Decimal("0")
|
||||
return pricing_model.calculate_cost(usage)
|
||||
|
||||
def calculate_tts_cost(
|
||||
self, provider: str, model: str, character_count: int
|
||||
) -> Decimal:
|
||||
"""Calculate cost for TTS usage"""
|
||||
pricing_model = self.get_pricing_model("tts", provider, model)
|
||||
if not pricing_model:
|
||||
return Decimal("0")
|
||||
return pricing_model.calculate_cost(character_count)
|
||||
|
||||
def calculate_stt_cost(self, provider: str, model: str, seconds: float) -> Decimal:
|
||||
"""Calculate cost for STT usage"""
|
||||
pricing_model = self.get_pricing_model("stt", provider, model)
|
||||
if not pricing_model:
|
||||
return Decimal("0")
|
||||
return pricing_model.calculate_cost(seconds)
|
||||
|
||||
def calculate_total_cost(self, usage_info: Dict) -> Dict[str, Any]:
|
||||
llm_cost_total = Decimal("0")
|
||||
tts_cost_total = Decimal("0")
|
||||
stt_cost_total = Decimal("0")
|
||||
|
||||
# Calculate LLM costs
|
||||
llm_usage = usage_info.get("llm", {})
|
||||
for key, usage in llm_usage.items():
|
||||
processor, model = self._parse_key(key)
|
||||
# Try to determine provider from processor name or model
|
||||
provider = self._infer_provider_from_model(model, "llm")
|
||||
cost = self.calculate_llm_cost(provider, model, usage)
|
||||
llm_cost_total += cost
|
||||
|
||||
# Calculate TTS costs
|
||||
tts_usage = usage_info.get("tts", {})
|
||||
for key, character_count in tts_usage.items():
|
||||
processor, model = self._parse_key(key)
|
||||
# Handle the case where model is "None" - infer from processor
|
||||
if model.lower() in ["none", "null", ""]:
|
||||
provider = self._infer_provider_from_processor(processor, "tts")
|
||||
model = "default" # Use default model for the provider
|
||||
else:
|
||||
provider = self._infer_provider_from_model(model, "tts")
|
||||
cost = self.calculate_tts_cost(provider, model, character_count)
|
||||
tts_cost_total += cost
|
||||
|
||||
# Calculate STT costs from explicit stt usage
|
||||
stt_usage = usage_info.get("stt", {})
|
||||
for key, seconds in stt_usage.items():
|
||||
processor, model = self._parse_key(key)
|
||||
provider = self._infer_provider_from_model(model, "stt")
|
||||
cost = self.calculate_stt_cost(provider, model, seconds)
|
||||
stt_cost_total += cost
|
||||
|
||||
total_cost = llm_cost_total + tts_cost_total + stt_cost_total
|
||||
|
||||
return {
|
||||
"llm_cost": float(llm_cost_total),
|
||||
"tts_cost": float(tts_cost_total),
|
||||
"stt_cost": float(stt_cost_total),
|
||||
"total": float(total_cost),
|
||||
}
|
||||
|
||||
def _parse_key(self, key) -> Tuple[str, str]:
|
||||
"""Parse key which is in format 'processor|||model'"""
|
||||
if isinstance(key, str) and "|||" in key:
|
||||
parts = key.split("|||", 1)
|
||||
return parts[0], parts[1]
|
||||
else:
|
||||
# Fallback for backwards compatibility or malformed keys
|
||||
return str(key), "unknown"
|
||||
|
||||
def _infer_provider_from_model(self, model: str, service_type: str) -> str:
|
||||
"""Infer provider from model name"""
|
||||
if not model:
|
||||
return "unknown"
|
||||
|
||||
model_lower = model.lower()
|
||||
|
||||
# OpenAI models
|
||||
if any(keyword in model_lower for keyword in ["gpt", "whisper", "openai"]):
|
||||
return ServiceProviders.OPENAI
|
||||
|
||||
# Groq models
|
||||
if any(keyword in model_lower for keyword in ["groq"]):
|
||||
return ServiceProviders.GROQ
|
||||
|
||||
# Elevenlabs models
|
||||
if any(keyword in model_lower for keyword in ["eleven"]):
|
||||
return ServiceProviders.ELEVENLABS
|
||||
|
||||
# Deepgram models
|
||||
if any(
|
||||
keyword in model_lower
|
||||
for keyword in ["deepgram", "nova", "phonecall", "general"]
|
||||
):
|
||||
return ServiceProviders.DEEPGRAM
|
||||
|
||||
# Default to first available provider for the service type
|
||||
service_providers = self.pricing_registry.get(service_type, {})
|
||||
if service_providers:
|
||||
return list(service_providers.keys())[0]
|
||||
|
||||
return "unknown"
|
||||
|
||||
def _infer_provider_from_processor(self, processor: str, service_type: str) -> str:
|
||||
"""Infer provider from processor name"""
|
||||
if not processor:
|
||||
return "unknown"
|
||||
|
||||
processor_lower = processor.lower()
|
||||
|
||||
# OpenAI processors
|
||||
if any(keyword in processor_lower for keyword in ["openai", "gpt"]):
|
||||
return ServiceProviders.OPENAI
|
||||
|
||||
# Groq processors
|
||||
if any(keyword in processor_lower for keyword in ["groq"]):
|
||||
return ServiceProviders.GROQ
|
||||
|
||||
# Deepgram processors
|
||||
if any(keyword in processor_lower for keyword in ["deepgram"]):
|
||||
return ServiceProviders.DEEPGRAM
|
||||
|
||||
# Default to first available provider for the service type
|
||||
service_providers = self.pricing_registry.get(service_type, {})
|
||||
if service_providers:
|
||||
return list(service_providers.keys())[0]
|
||||
|
||||
return "unknown"
|
||||
|
||||
def update_pricing(
|
||||
self, service_type: str, provider: str, model: str, pricing_model: PricingModel
|
||||
):
|
||||
"""Update pricing for a specific service/provider/model combination"""
|
||||
if service_type not in self.pricing_registry:
|
||||
self.pricing_registry[service_type] = {}
|
||||
if provider not in self.pricing_registry[service_type]:
|
||||
self.pricing_registry[service_type][provider] = {}
|
||||
self.pricing_registry[service_type][provider][model] = pricing_model
|
||||
|
||||
|
||||
# Global cost calculator instance
|
||||
cost_calculator = CostCalculator()
|
||||
|
|
@ -1,44 +0,0 @@
|
|||
"""
|
||||
Embeddings pricing models for different providers.
|
||||
|
||||
Prices are per token for embedding models.
|
||||
"""
|
||||
|
||||
from decimal import Decimal
|
||||
from typing import Dict
|
||||
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
|
||||
from .models import PricingModel
|
||||
|
||||
|
||||
class EmbeddingPricingModel(PricingModel):
|
||||
"""Pricing model for token-based embedding services."""
|
||||
|
||||
def __init__(self, token_price: Decimal):
|
||||
"""Initialize with price per token.
|
||||
|
||||
Args:
|
||||
token_price: Cost per token for embedding
|
||||
"""
|
||||
self.token_price = token_price
|
||||
|
||||
def calculate_cost(self, token_count: int) -> Decimal:
|
||||
"""Calculate cost for embedding token usage."""
|
||||
return Decimal(token_count) * self.token_price
|
||||
|
||||
|
||||
# Embeddings pricing registry
|
||||
EMBEDDINGS_PRICING: Dict[str, Dict[str, EmbeddingPricingModel]] = {
|
||||
ServiceProviders.OPENAI: {
|
||||
"text-embedding-3-small": EmbeddingPricingModel(
|
||||
token_price=Decimal("0.02") / 1_000_000, # $0.02 per 1M tokens
|
||||
),
|
||||
"text-embedding-3-large": EmbeddingPricingModel(
|
||||
token_price=Decimal("0.13") / 1_000_000, # $0.13 per 1M tokens
|
||||
),
|
||||
"text-embedding-ada-002": EmbeddingPricingModel(
|
||||
token_price=Decimal("0.10") / 1_000_000, # $0.10 per 1M tokens (legacy)
|
||||
),
|
||||
},
|
||||
}
|
||||
|
|
@ -1,143 +0,0 @@
|
|||
"""
|
||||
LLM pricing models for different providers.
|
||||
|
||||
Prices are per 1000 tokens for most models, with some newer models priced per million tokens.
|
||||
"""
|
||||
|
||||
from decimal import Decimal
|
||||
from typing import Dict
|
||||
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
|
||||
from .models import TokenPricingModel
|
||||
|
||||
# LLM pricing registry
|
||||
LLM_PRICING: Dict[str, Dict[str, TokenPricingModel]] = {
|
||||
ServiceProviders.OPENAI: {
|
||||
"gpt-3.5-turbo": TokenPricingModel(
|
||||
prompt_token_price=Decimal("0.0015") / 1000, # $0.0015 per 1K tokens
|
||||
completion_token_price=Decimal("0.002") / 1000, # $0.002 per 1K tokens
|
||||
),
|
||||
"gpt-4": TokenPricingModel(
|
||||
prompt_token_price=Decimal("0.03") / 1000, # $0.03 per 1K tokens
|
||||
completion_token_price=Decimal("0.06") / 1000, # $0.06 per 1K tokens
|
||||
),
|
||||
"gpt-4.1": TokenPricingModel(
|
||||
prompt_token_price=Decimal("2.00") / 1000000, # $2.00 per 1M tokens
|
||||
completion_token_price=Decimal("8.00") / 1000000, # $8.00 per 1M tokens
|
||||
),
|
||||
"gpt-4.1-mini": TokenPricingModel(
|
||||
prompt_token_price=Decimal("0.40") / 1000000, # $0.40 per 1M tokens
|
||||
completion_token_price=Decimal("1.60") / 1000000, # $1.60 per 1M tokens
|
||||
),
|
||||
"gpt-4.1-nano": TokenPricingModel(
|
||||
prompt_token_price=Decimal("0.10") / 1000000, # $0.10 per 1M tokens
|
||||
completion_token_price=Decimal("0.40") / 1000000, # $0.40 per 1M tokens
|
||||
),
|
||||
"gpt-4.5-preview": TokenPricingModel(
|
||||
prompt_token_price=Decimal("75.00") / 1000000, # $75.00 per 1M tokens
|
||||
completion_token_price=Decimal("150.00") / 1000000, # $150.00 per 1M tokens
|
||||
),
|
||||
"gpt-4o": TokenPricingModel(
|
||||
prompt_token_price=Decimal("2.50") / 1000000, # $2.50 per 1M tokens - FIXED
|
||||
completion_token_price=Decimal("10.00")
|
||||
/ 1000000, # $10.00 per 1M tokens - FIXED
|
||||
),
|
||||
"gpt-4o-audio-preview": TokenPricingModel(
|
||||
prompt_token_price=Decimal("2.50") / 1000000, # $2.50 per 1M tokens
|
||||
completion_token_price=Decimal("10.00") / 1000000, # $10.00 per 1M tokens
|
||||
),
|
||||
"gpt-4o-realtime-preview": TokenPricingModel(
|
||||
prompt_token_price=Decimal("5.00") / 1000000, # $5.00 per 1M tokens
|
||||
completion_token_price=Decimal("20.00") / 1000000, # $20.00 per 1M tokens
|
||||
),
|
||||
"gpt-4o-mini": TokenPricingModel(
|
||||
prompt_token_price=Decimal("0.15") / 1000000, # $0.15 per 1M tokens
|
||||
completion_token_price=Decimal("0.60") / 1000000, # $0.60 per 1M tokens
|
||||
),
|
||||
"gpt-4o-mini-audio-preview": TokenPricingModel(
|
||||
prompt_token_price=Decimal("0.15") / 1000000, # $0.15 per 1M tokens
|
||||
completion_token_price=Decimal("0.60") / 1000000, # $0.60 per 1M tokens
|
||||
),
|
||||
"gpt-4o-mini-realtime-preview": TokenPricingModel(
|
||||
prompt_token_price=Decimal("0.60") / 1000000, # $0.60 per 1M tokens
|
||||
completion_token_price=Decimal("2.40") / 1000000, # $2.40 per 1M tokens
|
||||
),
|
||||
"gpt-4o-search-preview": TokenPricingModel(
|
||||
prompt_token_price=Decimal("2.50") / 1000000, # $2.50 per 1M tokens
|
||||
completion_token_price=Decimal("10.00") / 1000000, # $10.00 per 1M tokens
|
||||
),
|
||||
"gpt-4o-mini-search-preview": TokenPricingModel(
|
||||
prompt_token_price=Decimal("0.15") / 1000000, # $0.15 per 1M tokens
|
||||
completion_token_price=Decimal("0.60") / 1000000, # $0.60 per 1M tokens
|
||||
),
|
||||
"o1": TokenPricingModel(
|
||||
prompt_token_price=Decimal("15.00") / 1000000, # $15.00 per 1M tokens
|
||||
completion_token_price=Decimal("60.00") / 1000000, # $60.00 per 1M tokens
|
||||
),
|
||||
"o1-pro": TokenPricingModel(
|
||||
prompt_token_price=Decimal("150.00") / 1000000, # $150.00 per 1M tokens
|
||||
completion_token_price=Decimal("600.00") / 1000000, # $600.00 per 1M tokens
|
||||
),
|
||||
"o1-mini": TokenPricingModel(
|
||||
prompt_token_price=Decimal("1.10") / 1000000, # $1.10 per 1M tokens
|
||||
completion_token_price=Decimal("4.40") / 1000000, # $4.40 per 1M tokens
|
||||
),
|
||||
"o3": TokenPricingModel(
|
||||
prompt_token_price=Decimal("10.00") / 1000000, # $10.00 per 1M tokens
|
||||
completion_token_price=Decimal("40.00") / 1000000, # $40.00 per 1M tokens
|
||||
),
|
||||
"o3-mini": TokenPricingModel(
|
||||
prompt_token_price=Decimal("1.10") / 1000000, # $1.10 per 1M tokens
|
||||
completion_token_price=Decimal("4.40") / 1000000, # $4.40 per 1M tokens
|
||||
),
|
||||
"o4-mini": TokenPricingModel(
|
||||
prompt_token_price=Decimal("1.10") / 1000000, # $1.10 per 1M tokens
|
||||
completion_token_price=Decimal("4.40") / 1000000, # $4.40 per 1M tokens
|
||||
),
|
||||
"computer-use-preview": TokenPricingModel(
|
||||
prompt_token_price=Decimal("3.00") / 1000000, # $3.00 per 1M tokens
|
||||
completion_token_price=Decimal("12.00") / 1000000, # $12.00 per 1M tokens
|
||||
),
|
||||
"gpt-image-1": TokenPricingModel(
|
||||
prompt_token_price=Decimal("5.00") / 1000000, # $5.00 per 1M tokens
|
||||
completion_token_price=Decimal("0") / 1000000, # No output pricing shown
|
||||
),
|
||||
"codex-mini-latest": TokenPricingModel(
|
||||
prompt_token_price=Decimal("1.50") / 1000000, # $1.50 per 1M tokens
|
||||
completion_token_price=Decimal("6.00") / 1000000, # $6.00 per 1M tokens
|
||||
),
|
||||
# Transcription models
|
||||
"gpt-4o-transcribe": TokenPricingModel(
|
||||
prompt_token_price=Decimal("2.50") / 1000000, # $2.50 per 1M tokens
|
||||
completion_token_price=Decimal("10.00") / 1000000, # $10.00 per 1M tokens
|
||||
),
|
||||
"gpt-4o-mini-transcribe": TokenPricingModel(
|
||||
prompt_token_price=Decimal("1.25") / 1000000, # $1.25 per 1M tokens
|
||||
completion_token_price=Decimal("5.00") / 1000000, # $5.00 per 1M tokens
|
||||
),
|
||||
# TTS models with token-based pricing
|
||||
"gpt-4o-mini-tts": TokenPricingModel(
|
||||
prompt_token_price=Decimal("0.60") / 1000000, # $0.60 per 1M tokens
|
||||
completion_token_price=Decimal("0")
|
||||
/ 1000000, # No completion tokens for TTS
|
||||
),
|
||||
},
|
||||
ServiceProviders.GROQ: {
|
||||
"llama-3.3-70b-versatile": TokenPricingModel(
|
||||
prompt_token_price=Decimal("0.00059") / 1000, # $0.00059 per 1K tokens
|
||||
completion_token_price=Decimal("0.00079") / 1000, # $0.00079 per 1K tokens
|
||||
),
|
||||
"deepseek-r1-distill-llama-70b": TokenPricingModel(
|
||||
prompt_token_price=Decimal("0.00059") / 1000, # Assuming similar pricing
|
||||
completion_token_price=Decimal("0.00079") / 1000,
|
||||
),
|
||||
},
|
||||
ServiceProviders.AZURE: {
|
||||
"gpt-4.1-mini": TokenPricingModel(
|
||||
prompt_token_price=Decimal("0.44") / 1000000, # $0.40 per 1M tokens
|
||||
completion_token_price=Decimal("8.80")
|
||||
/ 1000000, # $1.60 per 1M tokens if using data zone
|
||||
)
|
||||
},
|
||||
}
|
||||
|
|
@ -1,89 +0,0 @@
|
|||
"""
|
||||
Base pricing models for different service types.
|
||||
"""
|
||||
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
class CostType(Enum):
|
||||
LLM_TOKENS = "llm_tokens"
|
||||
TTS_CHARACTERS = "tts_characters"
|
||||
STT_SECONDS = "stt_seconds"
|
||||
|
||||
|
||||
class PricingModel:
|
||||
"""Base class for pricing models"""
|
||||
|
||||
def calculate_cost(self, usage: Any) -> Decimal:
|
||||
"""Calculate cost based on usage"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class TokenPricingModel(PricingModel):
|
||||
"""Pricing model for token-based services (LLM)"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prompt_token_price: Decimal,
|
||||
completion_token_price: Decimal,
|
||||
cache_read_discount: Decimal = Decimal("0.5"), # 50% discount for cache reads
|
||||
cache_creation_multiplier: Decimal = Decimal(
|
||||
"1.25"
|
||||
), # 25% premium for cache creation
|
||||
):
|
||||
self.prompt_token_price = prompt_token_price
|
||||
self.completion_token_price = completion_token_price
|
||||
self.cache_read_discount = cache_read_discount
|
||||
self.cache_creation_multiplier = cache_creation_multiplier
|
||||
|
||||
def calculate_cost(self, usage: Dict[str, int]) -> Decimal:
|
||||
"""Calculate cost for LLM token usage"""
|
||||
prompt_tokens = usage.get("prompt_tokens", 0)
|
||||
completion_tokens = usage.get("completion_tokens", 0)
|
||||
cache_read_tokens = usage.get("cache_read_input_tokens") or 0
|
||||
cache_creation_tokens = usage.get("cache_creation_input_tokens") or 0
|
||||
|
||||
# Base cost
|
||||
prompt_cost = Decimal(prompt_tokens) * self.prompt_token_price
|
||||
completion_cost = Decimal(completion_tokens) * self.completion_token_price
|
||||
|
||||
# Cache adjustments
|
||||
cache_read_savings = (
|
||||
Decimal(cache_read_tokens)
|
||||
* self.prompt_token_price
|
||||
* self.cache_read_discount
|
||||
)
|
||||
cache_creation_premium = (
|
||||
Decimal(cache_creation_tokens)
|
||||
* self.prompt_token_price
|
||||
* (self.cache_creation_multiplier - 1)
|
||||
)
|
||||
|
||||
total_cost = (
|
||||
prompt_cost + completion_cost - cache_read_savings + cache_creation_premium
|
||||
)
|
||||
return max(total_cost, Decimal("0")) # Ensure non-negative
|
||||
|
||||
|
||||
class CharacterPricingModel(PricingModel):
|
||||
"""Pricing model for character-based services (TTS)"""
|
||||
|
||||
def __init__(self, character_price: Decimal):
|
||||
self.character_price = character_price
|
||||
|
||||
def calculate_cost(self, character_count: int) -> Decimal:
|
||||
"""Calculate cost for TTS character usage"""
|
||||
return Decimal(character_count) * self.character_price
|
||||
|
||||
|
||||
class TimePricingModel(PricingModel):
|
||||
"""Pricing model for time-based services (STT)"""
|
||||
|
||||
def __init__(self, second_price: Decimal):
|
||||
self.second_price = second_price
|
||||
|
||||
def calculate_cost(self, seconds: float) -> Decimal:
|
||||
"""Calculate cost for STT time usage"""
|
||||
return Decimal(str(seconds)) * self.second_price
|
||||
|
|
@ -1,18 +0,0 @@
|
|||
"""
|
||||
Main pricing registry that combines all service type pricing models.
|
||||
"""
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from .embeddings import EMBEDDINGS_PRICING
|
||||
from .llm import LLM_PRICING
|
||||
from .stt import STT_PRICING
|
||||
from .tts import TTS_PRICING
|
||||
|
||||
# Combined pricing registry for all service types
|
||||
PRICING_REGISTRY: Dict = {
|
||||
"llm": LLM_PRICING,
|
||||
"tts": TTS_PRICING,
|
||||
"stt": STT_PRICING,
|
||||
"embeddings": EMBEDDINGS_PRICING,
|
||||
}
|
||||
|
|
@ -1,13 +0,0 @@
|
|||
"""Format workflow run usage for public API responses."""
|
||||
|
||||
|
||||
def format_public_usage_info(usage_info: dict | None) -> dict | None:
|
||||
if not usage_info:
|
||||
return None
|
||||
|
||||
return {
|
||||
"llm": usage_info.get("llm") or {},
|
||||
"tts": usage_info.get("tts") or {},
|
||||
"stt": usage_info.get("stt") or {},
|
||||
"call_duration_seconds": usage_info.get("call_duration_seconds"),
|
||||
}
|
||||
|
|
@ -1,26 +0,0 @@
|
|||
"""
|
||||
STT (Speech-to-Text) pricing models for different providers.
|
||||
|
||||
Prices are per second for STT services.
|
||||
"""
|
||||
|
||||
from decimal import Decimal
|
||||
from typing import Dict
|
||||
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
|
||||
from .models import TimePricingModel
|
||||
|
||||
# STT pricing registry
|
||||
STT_PRICING: Dict[str, Dict[str, TimePricingModel]] = {
|
||||
ServiceProviders.DEEPGRAM: {
|
||||
"nova-3-general": TimePricingModel(Decimal("0.0077") / 60),
|
||||
"nova-2": TimePricingModel(Decimal("0.0058") / 60),
|
||||
"default": TimePricingModel(Decimal("0.0077") / 60),
|
||||
},
|
||||
ServiceProviders.OPENAI: {
|
||||
"gpt-4o-transcribe": TimePricingModel(Decimal("0.015") / 60),
|
||||
"default": TimePricingModel(Decimal("0.015") / 60),
|
||||
},
|
||||
"default": {"default": TimePricingModel(Decimal("0.0077") / 60)},
|
||||
}
|
||||
|
|
@ -1,30 +0,0 @@
|
|||
"""
|
||||
TTS (Text-to-Speech) pricing models for different providers.
|
||||
|
||||
Prices are per character for TTS services.
|
||||
"""
|
||||
|
||||
from decimal import Decimal
|
||||
from typing import Dict
|
||||
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
|
||||
from .models import CharacterPricingModel
|
||||
|
||||
# TTS pricing registry
|
||||
TTS_PRICING: Dict[str, Dict[str, CharacterPricingModel]] = {
|
||||
ServiceProviders.OPENAI: {
|
||||
"gpt-4o-mini-tts": CharacterPricingModel(Decimal("0.6") / 1_00_00_000),
|
||||
"default": CharacterPricingModel(Decimal("0.6") / 1_00_00_000),
|
||||
},
|
||||
ServiceProviders.DEEPGRAM: {
|
||||
"aura-2": CharacterPricingModel(Decimal("0.030") / 1_000),
|
||||
"aura-1": CharacterPricingModel(Decimal("0.015") / 1_000),
|
||||
"default": CharacterPricingModel(Decimal("0.030") / 1_000),
|
||||
},
|
||||
ServiceProviders.ELEVENLABS: {
|
||||
# 6400 usd per 250*1e6 characters
|
||||
"default": CharacterPricingModel(Decimal("0.0256") / 1_000)
|
||||
},
|
||||
"default": {"default": CharacterPricingModel(Decimal("0.030") / 1_000)},
|
||||
}
|
||||
|
|
@ -1,230 +0,0 @@
|
|||
from decimal import Decimal
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from api.db import db_client
|
||||
from api.enums import WorkflowRunMode
|
||||
from api.services.pricing.cost_calculator import cost_calculator
|
||||
from api.services.telephony.factory import get_telephony_provider_for_run
|
||||
|
||||
|
||||
async def _fetch_telephony_cost(workflow_run) -> dict | None:
|
||||
"""Fetch telephony call cost. Returns a dict with cost_usd and provider_name, or None."""
|
||||
if (
|
||||
workflow_run.mode
|
||||
not in [WorkflowRunMode.TWILIO.value, WorkflowRunMode.VONAGE.value]
|
||||
or not workflow_run.cost_info
|
||||
):
|
||||
return None
|
||||
|
||||
call_id = workflow_run.cost_info.get("call_id")
|
||||
if not call_id:
|
||||
logger.warning(f"call_id not found in cost_info")
|
||||
return None
|
||||
|
||||
provider_name = workflow_run.mode.lower() if workflow_run.mode else ""
|
||||
|
||||
workflow = await db_client.get_workflow_by_id(workflow_run.workflow_id)
|
||||
if not workflow:
|
||||
logger.warning("Workflow not found for workflow run")
|
||||
raise Exception("Workflow not found")
|
||||
|
||||
provider = await get_telephony_provider_for_run(
|
||||
workflow_run, workflow.organization_id
|
||||
)
|
||||
call_cost_info = await provider.get_call_cost(call_id)
|
||||
|
||||
if call_cost_info.get("status") == "error":
|
||||
logger.error(
|
||||
f"Failed to fetch {provider_name} call cost: {call_cost_info.get('error')}"
|
||||
)
|
||||
return None
|
||||
|
||||
cost_usd = call_cost_info.get("cost_usd", 0.0)
|
||||
logger.info(
|
||||
f"{provider_name.title()} call cost: ${cost_usd:.6f} USD for call {call_id}"
|
||||
)
|
||||
return {"cost_usd": cost_usd, "provider_name": provider_name}
|
||||
|
||||
|
||||
async def _update_organization_usage(
|
||||
org, dograh_tokens: float, duration_seconds: float, charge_usd: float | None
|
||||
) -> None:
|
||||
"""Update organization usage after a workflow run."""
|
||||
org_id = org.id
|
||||
await db_client.update_usage_after_run(
|
||||
org_id, dograh_tokens, duration_seconds, charge_usd
|
||||
)
|
||||
if charge_usd is not None:
|
||||
logger.info(
|
||||
f"Updated organization usage with ${charge_usd:.2f} USD ({dograh_tokens} Dograh Tokens) and {duration_seconds}s duration for org {org_id}"
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"Updated organization usage with {dograh_tokens} Dograh Tokens and {duration_seconds}s duration for org {org_id}"
|
||||
)
|
||||
|
||||
|
||||
async def _get_pricing_organization(workflow_run):
|
||||
workflow = getattr(workflow_run, "workflow", None)
|
||||
organization_id = getattr(workflow, "organization_id", None)
|
||||
if organization_id is None and workflow and workflow.user:
|
||||
organization_id = workflow.user.selected_organization_id
|
||||
if organization_id is None:
|
||||
return None
|
||||
return await db_client.get_organization_by_id(organization_id)
|
||||
|
||||
|
||||
async def _build_usage_cost_snapshot(
|
||||
usage_info: dict | None,
|
||||
*,
|
||||
workflow_run=None,
|
||||
include_telephony_cost: bool = False,
|
||||
organization=None,
|
||||
calculated_at: str | None = None,
|
||||
) -> dict | None:
|
||||
if not usage_info:
|
||||
logger.warning("No usage info available for workflow run")
|
||||
return None
|
||||
|
||||
cost_breakdown = cost_calculator.calculate_total_cost(usage_info)
|
||||
|
||||
if include_telephony_cost and workflow_run is not None:
|
||||
try:
|
||||
telephony_cost = await _fetch_telephony_cost(workflow_run)
|
||||
if telephony_cost:
|
||||
telephony_cost_usd = telephony_cost["cost_usd"]
|
||||
provider_name = telephony_cost["provider_name"]
|
||||
cost_breakdown["telephony_call"] = telephony_cost_usd
|
||||
cost_breakdown[f"{provider_name}_call"] = telephony_cost_usd
|
||||
cost_breakdown["total"] = (
|
||||
float(cost_breakdown["total"]) + telephony_cost_usd
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to fetch telephony call cost: {e}")
|
||||
# Don't fail the whole cost calculation if telephony API fails
|
||||
|
||||
total_cost_usd = Decimal(str(cost_breakdown["total"]))
|
||||
dograh_tokens = float(total_cost_usd * Decimal("100"))
|
||||
|
||||
if organization is None and workflow_run is not None:
|
||||
organization = await _get_pricing_organization(workflow_run)
|
||||
|
||||
charge_usd = None
|
||||
if organization and organization.price_per_second_usd:
|
||||
duration_seconds = usage_info.get("call_duration_seconds", 0)
|
||||
charge_usd = float(
|
||||
Decimal(str(duration_seconds))
|
||||
* Decimal(str(organization.price_per_second_usd))
|
||||
)
|
||||
|
||||
cost_info = {
|
||||
"cost_breakdown": cost_breakdown,
|
||||
"total_cost_usd": float(total_cost_usd),
|
||||
"dograh_token_usage": dograh_tokens,
|
||||
"calculated_at": calculated_at
|
||||
or (workflow_run.created_at.isoformat() if workflow_run is not None else None),
|
||||
"call_duration_seconds": usage_info.get("call_duration_seconds", 0),
|
||||
}
|
||||
|
||||
if charge_usd is not None:
|
||||
cost_info["charge_usd"] = charge_usd
|
||||
cost_info["price_per_second_usd"] = organization.price_per_second_usd
|
||||
|
||||
return cost_info
|
||||
|
||||
|
||||
async def build_workflow_run_cost_info(workflow_run) -> dict | None:
|
||||
cost_info = await _build_usage_cost_snapshot(
|
||||
workflow_run.usage_info,
|
||||
workflow_run=workflow_run,
|
||||
include_telephony_cost=True,
|
||||
calculated_at=workflow_run.created_at.isoformat(),
|
||||
)
|
||||
if cost_info is None:
|
||||
return None
|
||||
return {
|
||||
**(workflow_run.cost_info or {}),
|
||||
**cost_info,
|
||||
}
|
||||
|
||||
|
||||
async def save_workflow_run_cost_info(
|
||||
workflow_run_id: int, cost_info: dict | None
|
||||
) -> None:
|
||||
if cost_info is None:
|
||||
return
|
||||
await db_client.update_workflow_run(run_id=workflow_run_id, cost_info=cost_info)
|
||||
|
||||
|
||||
async def apply_workflow_run_usage_to_organization(
|
||||
workflow_run, cost_info: dict | None
|
||||
) -> None:
|
||||
if cost_info is None:
|
||||
return
|
||||
|
||||
org = await _get_pricing_organization(workflow_run)
|
||||
if not org:
|
||||
return
|
||||
|
||||
await _update_organization_usage(
|
||||
org,
|
||||
float(cost_info.get("dograh_token_usage") or 0),
|
||||
float(cost_info.get("call_duration_seconds") or 0),
|
||||
cost_info.get("charge_usd"),
|
||||
)
|
||||
|
||||
|
||||
async def apply_usage_delta_to_organization(
|
||||
workflow_run, usage_info: dict | None
|
||||
) -> dict | None:
|
||||
org = await _get_pricing_organization(workflow_run)
|
||||
if not org:
|
||||
return None
|
||||
|
||||
cost_info = await _build_usage_cost_snapshot(usage_info, organization=org)
|
||||
if cost_info is None:
|
||||
return None
|
||||
|
||||
await _update_organization_usage(
|
||||
org,
|
||||
float(cost_info.get("dograh_token_usage") or 0),
|
||||
float(cost_info.get("call_duration_seconds") or 0),
|
||||
cost_info.get("charge_usd"),
|
||||
)
|
||||
return cost_info
|
||||
|
||||
|
||||
async def calculate_workflow_run_cost(workflow_run_id: int):
|
||||
logger.debug("Calculating cost for workflow run")
|
||||
|
||||
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
|
||||
if not workflow_run:
|
||||
logger.warning("Workflow run not found")
|
||||
return
|
||||
|
||||
try:
|
||||
cost_info = await build_workflow_run_cost_info(workflow_run)
|
||||
if cost_info is None:
|
||||
return
|
||||
|
||||
await save_workflow_run_cost_info(workflow_run_id, cost_info)
|
||||
|
||||
try:
|
||||
await apply_workflow_run_usage_to_organization(workflow_run, cost_info)
|
||||
except Exception as e:
|
||||
org = await _get_pricing_organization(workflow_run)
|
||||
if org:
|
||||
logger.error(
|
||||
f"Failed to update organization usage for org {org.id}: {e}"
|
||||
)
|
||||
else:
|
||||
logger.error(f"Failed to update organization usage: {e}")
|
||||
# Don't fail the whole cost calculation if usage update fails
|
||||
|
||||
logger.info(
|
||||
f"Calculated cost for workflow run: ${cost_info['total_cost_usd']:.6f} USD ({cost_info['dograh_token_usage']} Dograh Tokens)"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating cost for workflow run: {e}")
|
||||
raise
|
||||
|
|
@ -5,15 +5,38 @@ across different endpoints (WebRTC signaling, telephony, public API triggers).
|
|||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from api.constants import DEPLOYMENT_MODE
|
||||
from api.db import db_client
|
||||
from api.db.models import UserModel
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
get_effective_ai_model_configuration_for_workflow,
|
||||
)
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from api.services.configuration.resolve import resolve_effective_config
|
||||
from api.services.managed_model_services import (
|
||||
MPS_CORRELATION_ID_CONTEXT_KEY,
|
||||
get_dograh_service_api_key,
|
||||
uses_managed_model_services_v2,
|
||||
)
|
||||
from api.services.mps_service_key_client import mps_service_key_client
|
||||
|
||||
MINIMUM_DOGRAH_CREDITS_FOR_CALL = 0.10
|
||||
|
||||
LEGACY_QUOTA_EXCEEDED_MESSAGE = (
|
||||
"You have exhausted your trial credits. "
|
||||
"Please email founders@dograh.com for additional Dograh credits "
|
||||
"or change providers in Models configurations."
|
||||
)
|
||||
|
||||
BILLING_V2_QUOTA_EXCEEDED_MESSAGE = (
|
||||
"You have exhausted your Dograh credits. "
|
||||
"Please purchase more credits from /billing "
|
||||
"or change providers in Models configurations."
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuotaCheckResult:
|
||||
|
|
@ -24,104 +47,359 @@ class QuotaCheckResult:
|
|||
error_code: str = ""
|
||||
|
||||
|
||||
async def check_dograh_quota(
|
||||
user: UserModel, workflow_id: int | None = None
|
||||
) -> QuotaCheckResult:
|
||||
"""Check if user has sufficient Dograh quota for making a call.
|
||||
|
||||
This function checks if the user is using any Dograh services (LLM, STT, TTS)
|
||||
and validates that they have sufficient credits remaining.
|
||||
|
||||
When ``workflow_id`` is provided, the workflow's per-workflow
|
||||
``model_overrides`` are merged onto the user's global config so the quota
|
||||
check runs against the credentials that will actually be used for the call
|
||||
(rather than always falling back to the user's defaults).
|
||||
|
||||
Args:
|
||||
user: The user to check quota for
|
||||
workflow_id: Optional workflow whose ``model_overrides`` should be
|
||||
applied when resolving the effective service config.
|
||||
|
||||
Returns:
|
||||
QuotaCheckResult with has_quota=True if user has sufficient quota or
|
||||
is not using Dograh services, or has_quota=False with error_message
|
||||
if quota is insufficient.
|
||||
"""
|
||||
def _safe_float(value: Any, default: float = 0.0) -> float:
|
||||
try:
|
||||
# Get user configurations
|
||||
user_config = await db_client.get_user_configurations(user.id)
|
||||
return float(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
if workflow_id is not None:
|
||||
workflow = await db_client.get_workflow_by_id(workflow_id)
|
||||
if workflow:
|
||||
model_overrides = (workflow.workflow_configurations or {}).get(
|
||||
"model_overrides"
|
||||
|
||||
def _insufficient_billing_v2_quota_result() -> QuotaCheckResult:
|
||||
return QuotaCheckResult(
|
||||
has_quota=False,
|
||||
error_code="insufficient_credits",
|
||||
error_message=BILLING_V2_QUOTA_EXCEEDED_MESSAGE,
|
||||
)
|
||||
|
||||
|
||||
def _insufficient_legacy_quota_result() -> QuotaCheckResult:
|
||||
return QuotaCheckResult(
|
||||
has_quota=False,
|
||||
error_code="quota_exceeded",
|
||||
error_message=LEGACY_QUOTA_EXCEEDED_MESSAGE,
|
||||
)
|
||||
|
||||
|
||||
def _service_uses_dograh(service: Any) -> bool:
|
||||
provider = getattr(service, "provider", None)
|
||||
return (
|
||||
provider == ServiceProviders.DOGRAH or provider == ServiceProviders.DOGRAH.value
|
||||
)
|
||||
|
||||
|
||||
def _dograh_api_keys(user_config: Any) -> set[str]:
|
||||
api_keys: set[str] = set()
|
||||
for section_name in ("llm", "stt", "tts", "embeddings"):
|
||||
service = getattr(user_config, section_name, None)
|
||||
if not _service_uses_dograh(service):
|
||||
continue
|
||||
if hasattr(service, "get_all_api_keys"):
|
||||
all_api_keys = [
|
||||
api_key
|
||||
for api_key in service.get_all_api_keys()
|
||||
if isinstance(api_key, str) and api_key
|
||||
]
|
||||
if all_api_keys:
|
||||
api_keys.update(all_api_keys)
|
||||
continue
|
||||
api_key = getattr(service, "api_key", None)
|
||||
if api_key:
|
||||
api_keys.add(api_key)
|
||||
return api_keys
|
||||
|
||||
|
||||
async def _store_run_correlation_id(
|
||||
workflow_run_id: int | None,
|
||||
correlation_id: str | None,
|
||||
) -> None:
|
||||
if not workflow_run_id or not correlation_id:
|
||||
return
|
||||
|
||||
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
|
||||
if not workflow_run:
|
||||
logger.warning(
|
||||
"Could not store MPS correlation id for missing workflow run {}",
|
||||
workflow_run_id,
|
||||
)
|
||||
return
|
||||
|
||||
initial_context = dict(workflow_run.initial_context or {})
|
||||
if initial_context.get(MPS_CORRELATION_ID_CONTEXT_KEY) == correlation_id:
|
||||
return
|
||||
|
||||
initial_context[MPS_CORRELATION_ID_CONTEXT_KEY] = correlation_id
|
||||
await db_client.update_workflow_run(
|
||||
workflow_run_id,
|
||||
initial_context=initial_context,
|
||||
)
|
||||
|
||||
|
||||
async def _authorize_hosted_workflow_run_start(
|
||||
*,
|
||||
workflow_owner: UserModel,
|
||||
organization_id: int | None,
|
||||
workflow_id: int | None,
|
||||
workflow_run_id: int | None,
|
||||
user_config: Any,
|
||||
) -> tuple[QuotaCheckResult, bool]:
|
||||
"""Authorize hosted v2 billing and return whether MPS handled enforcement."""
|
||||
if DEPLOYMENT_MODE == "oss" or organization_id is None:
|
||||
return QuotaCheckResult(has_quota=True), False
|
||||
|
||||
requires_correlation = bool(
|
||||
workflow_run_id and uses_managed_model_services_v2(user_config)
|
||||
)
|
||||
service_key = (
|
||||
get_dograh_service_api_key(user_config) if requires_correlation else None
|
||||
)
|
||||
if requires_correlation and not service_key:
|
||||
return (
|
||||
QuotaCheckResult(
|
||||
has_quota=False,
|
||||
error_code="invalid_service_key",
|
||||
error_message=(
|
||||
"You have invalid keys in your model configuration. "
|
||||
"Please validate the service keys."
|
||||
),
|
||||
),
|
||||
True,
|
||||
)
|
||||
|
||||
try:
|
||||
authorization = await mps_service_key_client.authorize_workflow_run_start(
|
||||
organization_id=organization_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
service_key=service_key,
|
||||
require_correlation_id=requires_correlation,
|
||||
minimum_credits=MINIMUM_DOGRAH_CREDITS_FOR_CALL,
|
||||
created_by=(
|
||||
str(workflow_owner.provider_id)
|
||||
if workflow_owner.provider_id is not None
|
||||
else None
|
||||
),
|
||||
metadata={
|
||||
"dograh_user_id": str(workflow_owner.id),
|
||||
"workflow_id": workflow_id,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to authorize workflow start with MPS for org {}: {}",
|
||||
organization_id,
|
||||
e,
|
||||
)
|
||||
return (
|
||||
QuotaCheckResult(
|
||||
has_quota=False,
|
||||
error_code="quota_check_failed",
|
||||
error_message="Could not verify Dograh credits. Please try again.",
|
||||
),
|
||||
True,
|
||||
)
|
||||
|
||||
billing_mode = authorization.get("billing_mode")
|
||||
if billing_mode != "v2":
|
||||
return QuotaCheckResult(has_quota=True), False
|
||||
|
||||
remaining = _safe_float(authorization.get("remaining_credits"))
|
||||
if (
|
||||
not authorization.get("allowed", False)
|
||||
or remaining < MINIMUM_DOGRAH_CREDITS_FOR_CALL
|
||||
):
|
||||
logger.warning(
|
||||
"Insufficient Dograh billing v2 credits for org {}: {:.2f} credits remaining",
|
||||
organization_id,
|
||||
remaining,
|
||||
)
|
||||
return _insufficient_billing_v2_quota_result(), True
|
||||
|
||||
try:
|
||||
await _store_run_correlation_id(
|
||||
workflow_run_id,
|
||||
authorization.get("correlation_id"),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to store MPS correlation id for workflow_run_id {}: {}",
|
||||
workflow_run_id,
|
||||
e,
|
||||
)
|
||||
return (
|
||||
QuotaCheckResult(
|
||||
has_quota=False,
|
||||
error_code="quota_check_failed",
|
||||
error_message="Could not verify Dograh credits. Please try again.",
|
||||
),
|
||||
True,
|
||||
)
|
||||
logger.info(
|
||||
"Dograh billing v2 run authorization passed for org {}: {:.2f} credits remaining",
|
||||
organization_id,
|
||||
remaining,
|
||||
)
|
||||
return QuotaCheckResult(has_quota=True), True
|
||||
|
||||
|
||||
async def _authorize_legacy_dograh_keys(
|
||||
*,
|
||||
dograh_api_keys: set[str],
|
||||
organization_id: int | None,
|
||||
workflow_owner: UserModel,
|
||||
) -> QuotaCheckResult:
|
||||
for api_key in dograh_api_keys:
|
||||
try:
|
||||
usage = await mps_service_key_client.check_service_key_usage(
|
||||
api_key,
|
||||
organization_id=organization_id,
|
||||
created_by=workflow_owner.provider_id,
|
||||
)
|
||||
remaining = usage.get("remaining_credits", 0.0)
|
||||
|
||||
# Require at least $0.10 for a short call
|
||||
if remaining < MINIMUM_DOGRAH_CREDITS_FOR_CALL:
|
||||
logger.warning(
|
||||
f"Insufficient Dograh credits for key ...{api_key[-8:]}: "
|
||||
f"${remaining:.2f} remaining"
|
||||
)
|
||||
if model_overrides:
|
||||
user_config = resolve_effective_config(user_config, model_overrides)
|
||||
return _insufficient_legacy_quota_result()
|
||||
|
||||
# Check if user is using any Dograh service
|
||||
using_dograh = False
|
||||
dograh_api_keys = set()
|
||||
|
||||
if user_config.llm and user_config.llm.provider == ServiceProviders.DOGRAH:
|
||||
using_dograh = True
|
||||
dograh_api_keys.add(user_config.llm.api_key)
|
||||
|
||||
if user_config.stt and user_config.stt.provider == ServiceProviders.DOGRAH:
|
||||
using_dograh = True
|
||||
dograh_api_keys.add(user_config.stt.api_key)
|
||||
|
||||
if user_config.tts and user_config.tts.provider == ServiceProviders.DOGRAH:
|
||||
using_dograh = True
|
||||
dograh_api_keys.add(user_config.tts.api_key)
|
||||
|
||||
# If not using Dograh, quota check passes
|
||||
if not using_dograh:
|
||||
return QuotaCheckResult(has_quota=True)
|
||||
|
||||
# Check quota for ALL Dograh keys
|
||||
for api_key in dograh_api_keys:
|
||||
try:
|
||||
usage = await mps_service_key_client.check_service_key_usage(
|
||||
api_key, created_by=user.provider_id
|
||||
)
|
||||
remaining = usage.get("remaining_credits", 0.0)
|
||||
|
||||
# Require at least $0.10 for a short call
|
||||
if remaining < 0.10:
|
||||
logger.warning(
|
||||
f"Insufficient Dograh credits for key ...{api_key[-8:]}: "
|
||||
f"${remaining:.2f} remaining"
|
||||
)
|
||||
return QuotaCheckResult(
|
||||
has_quota=False,
|
||||
error_code="quota_exceeded",
|
||||
error_message=(
|
||||
"You have exhausted your trial credits. "
|
||||
"Please email founders@dograh.com for additional Dograh credits "
|
||||
"or change providers in Models configurations."
|
||||
),
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Dograh quota check passed for key ...{api_key[-8:]}: "
|
||||
f"{remaining:.2f} credits remaining"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check quota for Dograh key: {str(e)}")
|
||||
error_str = str(e)
|
||||
if "404" in error_str or "not found" in error_str.lower():
|
||||
return QuotaCheckResult(
|
||||
has_quota=False,
|
||||
error_code="invalid_service_key",
|
||||
error_message="You have invalid keys in your model configuration. Please validate the service keys.",
|
||||
)
|
||||
logger.info(
|
||||
f"Dograh quota check passed for key ...{api_key[-8:]}: "
|
||||
f"{remaining:.2f} credits remaining"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check quota for Dograh key: {str(e)}")
|
||||
error_str = str(e)
|
||||
if "404" in error_str or "not found" in error_str.lower():
|
||||
return QuotaCheckResult(
|
||||
has_quota=False,
|
||||
error_code="quota_check_failed",
|
||||
error_message="Could not verify Dograh credits. Please try again.",
|
||||
error_code="invalid_service_key",
|
||||
error_message="You have invalid keys in your model configuration. Please validate the service keys.",
|
||||
)
|
||||
return QuotaCheckResult(
|
||||
has_quota=False,
|
||||
error_code="quota_check_failed",
|
||||
error_message="Could not verify Dograh credits. Please try again.",
|
||||
)
|
||||
|
||||
return QuotaCheckResult(has_quota=True)
|
||||
|
||||
|
||||
async def _authorize_oss_managed_v2_correlation(
|
||||
*,
|
||||
workflow_id: int,
|
||||
workflow_run_id: int | None,
|
||||
user_config: Any,
|
||||
) -> QuotaCheckResult:
|
||||
if not workflow_run_id or not uses_managed_model_services_v2(user_config):
|
||||
return QuotaCheckResult(has_quota=True)
|
||||
|
||||
service_key = get_dograh_service_api_key(user_config)
|
||||
if not service_key:
|
||||
return QuotaCheckResult(
|
||||
has_quota=False,
|
||||
error_code="invalid_service_key",
|
||||
error_message=(
|
||||
"You have invalid keys in your model configuration. "
|
||||
"Please validate the service keys."
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
response = await mps_service_key_client.create_correlation_id(
|
||||
service_key=service_key,
|
||||
workflow_run_id=workflow_run_id,
|
||||
)
|
||||
await _store_run_correlation_id(
|
||||
workflow_run_id,
|
||||
response.get("correlation_id"),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to authorize OSS managed v2 workflow start for workflow {} run {}: {}",
|
||||
workflow_id,
|
||||
workflow_run_id,
|
||||
e,
|
||||
)
|
||||
return QuotaCheckResult(
|
||||
has_quota=False,
|
||||
error_code="quota_check_failed",
|
||||
error_message="Could not verify Dograh credits. Please try again.",
|
||||
)
|
||||
|
||||
return QuotaCheckResult(has_quota=True)
|
||||
|
||||
|
||||
async def authorize_workflow_run_start(
|
||||
*,
|
||||
workflow_id: int,
|
||||
workflow_run_id: int | None = None,
|
||||
actor_user: UserModel | None = None,
|
||||
) -> QuotaCheckResult:
|
||||
"""Authorize a workflow run before any billable call/text runtime starts.
|
||||
|
||||
The workflow organization is the billing subject for hosted v2. The workflow
|
||||
owner is used only to resolve the effective model configuration and legacy
|
||||
service-key metadata.
|
||||
"""
|
||||
try:
|
||||
workflow = await db_client.get_workflow_by_id(workflow_id)
|
||||
if not workflow:
|
||||
return QuotaCheckResult(
|
||||
has_quota=False,
|
||||
error_code="workflow_not_found",
|
||||
error_message="Workflow not found",
|
||||
)
|
||||
|
||||
actor_org_id = getattr(actor_user, "selected_organization_id", None)
|
||||
if actor_org_id is not None and actor_org_id != workflow.organization_id:
|
||||
logger.warning(
|
||||
"Workflow start authorization denied: actor org {} does not match workflow {} org {}",
|
||||
actor_org_id,
|
||||
workflow_id,
|
||||
workflow.organization_id,
|
||||
)
|
||||
return QuotaCheckResult(
|
||||
has_quota=False,
|
||||
error_code="workflow_not_found",
|
||||
error_message="Workflow not found",
|
||||
)
|
||||
|
||||
workflow_owner = await db_client.get_user_by_id(workflow.user_id)
|
||||
if not workflow_owner:
|
||||
return QuotaCheckResult(
|
||||
has_quota=False,
|
||||
error_code="user_not_found",
|
||||
error_message="User not found",
|
||||
)
|
||||
|
||||
user_config = await get_effective_ai_model_configuration_for_workflow(
|
||||
user_id=workflow_owner.id,
|
||||
organization_id=workflow.organization_id,
|
||||
workflow_configurations=workflow.workflow_configurations,
|
||||
)
|
||||
|
||||
if DEPLOYMENT_MODE != "oss":
|
||||
hosted_result, hosted_enforced = await _authorize_hosted_workflow_run_start(
|
||||
workflow_owner=workflow_owner,
|
||||
organization_id=workflow.organization_id,
|
||||
workflow_id=workflow.id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
user_config=user_config,
|
||||
)
|
||||
if hosted_enforced or not hosted_result.has_quota:
|
||||
return hosted_result
|
||||
|
||||
dograh_api_keys = _dograh_api_keys(user_config)
|
||||
if not dograh_api_keys:
|
||||
return QuotaCheckResult(has_quota=True)
|
||||
|
||||
legacy_result = await _authorize_legacy_dograh_keys(
|
||||
dograh_api_keys=dograh_api_keys,
|
||||
organization_id=(
|
||||
None if DEPLOYMENT_MODE == "oss" else workflow.organization_id
|
||||
),
|
||||
workflow_owner=workflow_owner,
|
||||
)
|
||||
if not legacy_result.has_quota:
|
||||
return legacy_result
|
||||
|
||||
if DEPLOYMENT_MODE == "oss":
|
||||
return await _authorize_oss_managed_v2_correlation(
|
||||
workflow_id=workflow.id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
user_config=user_config,
|
||||
)
|
||||
|
||||
return QuotaCheckResult(has_quota=True)
|
||||
|
||||
|
|
@ -129,30 +407,3 @@ async def check_dograh_quota(
|
|||
logger.error(f"Error during quota check: {str(e)}")
|
||||
# On unexpected error, allow the call to proceed
|
||||
return QuotaCheckResult(has_quota=True)
|
||||
|
||||
|
||||
async def check_dograh_quota_by_user_id(
|
||||
user_id: int, workflow_id: int | None = None
|
||||
) -> QuotaCheckResult:
|
||||
"""Check Dograh quota by user ID.
|
||||
|
||||
Convenience function that fetches the user and then checks quota. When
|
||||
``workflow_id`` is provided, the workflow's ``model_overrides`` are
|
||||
applied so the quota check evaluates the credentials that will actually
|
||||
be used for the call.
|
||||
|
||||
Args:
|
||||
user_id: The ID of the user to check quota for
|
||||
workflow_id: Optional workflow whose per-workflow overrides should
|
||||
be applied to the user's config before checking quota.
|
||||
|
||||
Returns:
|
||||
QuotaCheckResult with quota status
|
||||
"""
|
||||
user = await db_client.get_user_by_id(user_id)
|
||||
if not user:
|
||||
return QuotaCheckResult(
|
||||
has_quota=False,
|
||||
error_message="User not found",
|
||||
)
|
||||
return await check_dograh_quota(user, workflow_id=workflow_id)
|
||||
|
|
|
|||
|
|
@ -53,7 +53,7 @@ def build_run_report_csv(runs: List[Any]) -> io.StringIO:
|
|||
for run in runs:
|
||||
initial = run.initial_context or {}
|
||||
gathered = run.gathered_context or {}
|
||||
cost = run.cost_info or {}
|
||||
usage = run.usage_info or {}
|
||||
|
||||
call_tags = gathered.get("call_tags", [])
|
||||
if isinstance(call_tags, list):
|
||||
|
|
@ -67,7 +67,7 @@ def build_run_report_csv(runs: List[Any]) -> io.StringIO:
|
|||
run.created_at.isoformat() if run.created_at else "",
|
||||
initial.get("phone_number", ""),
|
||||
gathered.get("mapped_call_disposition", ""),
|
||||
cost.get("call_duration_seconds", ""),
|
||||
usage.get("call_duration_seconds", ""),
|
||||
]
|
||||
|
||||
extracted = gathered.get("extracted_variables", {})
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ from loguru import logger
|
|||
from api.constants import REDIS_URL
|
||||
from api.db import db_client
|
||||
from api.enums import CallType, WorkflowRunMode
|
||||
from api.services.quota_service import check_dograh_quota_by_user_id
|
||||
from api.services.quota_service import authorize_workflow_run_start
|
||||
from api.services.telephony.call_transfer_manager import get_call_transfer_manager
|
||||
from api.services.telephony.transfer_event_protocol import (
|
||||
TransferEvent,
|
||||
|
|
@ -564,19 +564,7 @@ class ARIConnection:
|
|||
|
||||
user_id = workflow.user_id
|
||||
|
||||
# 3. Check quota (apply per-workflow model_overrides).
|
||||
quota_result = await check_dograh_quota_by_user_id(
|
||||
user_id, workflow_id=inbound_workflow_id
|
||||
)
|
||||
if not quota_result.has_quota:
|
||||
logger.warning(
|
||||
f"[ARI org={self.organization_id}] Quota exceeded for user {user_id} "
|
||||
f"— hanging up inbound call {channel_id}"
|
||||
)
|
||||
await self._delete_channel(channel_id)
|
||||
return
|
||||
|
||||
# 4. Create workflow run
|
||||
# 3. Create workflow run
|
||||
call_id = channel_id
|
||||
workflow_run = await db_client.create_workflow_run(
|
||||
name=f"ARI Inbound {caller_number}",
|
||||
|
|
@ -602,6 +590,20 @@ class ARIConnection:
|
|||
f"(caller={caller_number}, called={called_number})"
|
||||
)
|
||||
|
||||
# 4. Check quota after the run exists so hosted v2 can mint and
|
||||
# store the MPS correlation id before the pipeline starts.
|
||||
quota_result = await authorize_workflow_run_start(
|
||||
workflow_id=inbound_workflow_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
)
|
||||
if not quota_result.has_quota:
|
||||
logger.warning(
|
||||
f"[ARI org={self.organization_id}] Quota exceeded for user {user_id} "
|
||||
f"— hanging up inbound call {channel_id}"
|
||||
)
|
||||
await self._delete_channel(channel_id)
|
||||
return
|
||||
|
||||
# 5. Answer the inbound channel
|
||||
await self._answer_channel(channel_id)
|
||||
|
||||
|
|
|
|||
|
|
@ -103,7 +103,8 @@ async def handle_cloudonix_cdr(request: Request):
|
|||
return {"status": "error", "message": "Missing domain field"}
|
||||
|
||||
# Extract call_id to find workflow run
|
||||
call_id = cdr_data.get("session").get("token")
|
||||
session = cdr_data.get("session")
|
||||
call_id = session.get("token") if isinstance(session, dict) else None
|
||||
logger.info(f"Cloudonix CDR data for call id {call_id} - {cdr_data}")
|
||||
if not call_id:
|
||||
logger.warning("Cloudonix CDR missing call_id field")
|
||||
|
|
|
|||
|
|
@ -6,9 +6,8 @@ provider registry — see ProviderSpec.router.
|
|||
|
||||
import json
|
||||
from datetime import UTC, datetime
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Header, Request
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from loguru import logger
|
||||
from pipecat.utils.run_context import set_current_run_id
|
||||
from starlette.responses import HTMLResponse
|
||||
|
|
@ -29,6 +28,30 @@ from api.utils.telephony_helper import (
|
|||
router = APIRouter()
|
||||
|
||||
|
||||
async def _verify_vobiz_callback(
|
||||
provider,
|
||||
webhook_url: str,
|
||||
callback_data: dict,
|
||||
headers: dict,
|
||||
raw_body: str,
|
||||
*,
|
||||
log_prefix: str,
|
||||
) -> None:
|
||||
"""Verify a Vobiz callback signature, failing closed.
|
||||
|
||||
Vobiz signs every callback, so a missing signature header is an invalid
|
||||
request — ``provider.verify_inbound_signature`` returns ``False`` for both
|
||||
missing and forged signatures. Reject with HTTP 403 (per Vobiz's
|
||||
callback-validation docs) so the caller never reaches status processing.
|
||||
"""
|
||||
is_valid = await provider.verify_inbound_signature(
|
||||
webhook_url, callback_data, headers, raw_body
|
||||
)
|
||||
if not is_valid:
|
||||
logger.warning(f"{log_prefix} Invalid or missing Vobiz callback signature")
|
||||
raise HTTPException(status_code=403, detail="Invalid webhook signature")
|
||||
|
||||
|
||||
@router.post("/vobiz-xml", include_in_schema=False)
|
||||
async def handle_vobiz_xml_webhook(
|
||||
workflow_id: int, user_id: int, workflow_run_id: int, organization_id: int
|
||||
|
|
@ -65,8 +88,6 @@ async def handle_vobiz_xml_webhook(
|
|||
async def handle_vobiz_hangup_callback(
|
||||
workflow_run_id: int,
|
||||
request: Request,
|
||||
x_vobiz_signature: Optional[str] = Header(None),
|
||||
x_vobiz_timestamp: Optional[str] = Header(None),
|
||||
):
|
||||
"""Handle Vobiz hangup callback (sent when call ends).
|
||||
|
||||
|
|
@ -75,82 +96,23 @@ async def handle_vobiz_hangup_callback(
|
|||
"""
|
||||
set_current_run_id(workflow_run_id)
|
||||
|
||||
# Logging all headers and body to understand what Vobiz actually sends
|
||||
all_headers = dict(request.headers)
|
||||
logger.info(
|
||||
f"[run {workflow_run_id}] Vobiz hangup callback - Headers: {json.dumps(all_headers)}"
|
||||
)
|
||||
|
||||
# Parse the callback data from the raw body so signed webhooks can verify
|
||||
# the exact bytes Vobiz sent without draining the request stream first.
|
||||
callback_data, raw_body = await parse_webhook_request(request)
|
||||
|
||||
# TODO: Remove this debug logging after Vobiz team clarifies webhook authentication
|
||||
logger.info(
|
||||
f"[run {workflow_run_id}] Vobiz hangup callback - Body: {json.dumps(callback_data)}"
|
||||
)
|
||||
logger.info(
|
||||
f"[run {workflow_run_id}] Received Vobiz hangup callback {json.dumps(callback_data)}"
|
||||
)
|
||||
|
||||
# Verify signature if Vobiz provided any supported signature header.
|
||||
has_vobiz_signature = any(
|
||||
header in all_headers
|
||||
for header in (
|
||||
"x-vobiz-signature-v3",
|
||||
"x-vobiz-signature-ma-v3",
|
||||
"x-vobiz-signature-v2",
|
||||
"x-vobiz-signature-ma-v2",
|
||||
)
|
||||
)
|
||||
if has_vobiz_signature:
|
||||
# We need the workflow run to get organization for provider credentials
|
||||
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
|
||||
if not workflow_run:
|
||||
logger.warning(
|
||||
f"[run {workflow_run_id}] Workflow run not found for signature verification"
|
||||
)
|
||||
return {"status": "error", "reason": "workflow_run_not_found"}
|
||||
|
||||
workflow = await db_client.get_workflow_by_id(workflow_run.workflow_id)
|
||||
if not workflow:
|
||||
logger.warning(
|
||||
f"[run {workflow_run_id}] Workflow not found for signature verification"
|
||||
)
|
||||
return {"status": "error", "reason": "workflow_not_found"}
|
||||
|
||||
provider = await get_telephony_provider_for_run(
|
||||
workflow_run, workflow.organization_id
|
||||
)
|
||||
|
||||
# Verify signature
|
||||
backend_endpoint, _ = await get_backend_endpoints()
|
||||
webhook_url = f"{backend_endpoint}/api/v1/telephony/vobiz/hangup-callback/{workflow_run_id}"
|
||||
|
||||
is_valid = await provider.verify_inbound_signature(
|
||||
webhook_url,
|
||||
callback_data,
|
||||
all_headers,
|
||||
raw_body,
|
||||
)
|
||||
|
||||
if not is_valid:
|
||||
logger.warning(
|
||||
f"[run {workflow_run_id}] Invalid Vobiz hangup callback signature"
|
||||
)
|
||||
return {"status": "error", "reason": "invalid_signature"}
|
||||
|
||||
logger.info(f"[run {workflow_run_id}] Vobiz hangup callback signature verified")
|
||||
else:
|
||||
# Get workflow run for processing (signature verification already got it if needed)
|
||||
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
|
||||
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
|
||||
if not workflow_run:
|
||||
logger.warning(
|
||||
f"[run {workflow_run_id}] Workflow run not found for Vobiz hangup callback"
|
||||
)
|
||||
return {"status": "ignored", "reason": "workflow_run_not_found"}
|
||||
|
||||
# Get workflow and provider
|
||||
workflow = await db_client.get_workflow_by_id(workflow_run.workflow_id)
|
||||
if not workflow:
|
||||
logger.warning(f"[run {workflow_run_id}] Workflow not found")
|
||||
|
|
@ -160,6 +122,21 @@ async def handle_vobiz_hangup_callback(
|
|||
workflow_run, workflow.organization_id
|
||||
)
|
||||
|
||||
# Fail closed: Vobiz signs every callback, so reject unsigned/forged ones
|
||||
# before they can mutate call state.
|
||||
backend_endpoint, _ = await get_backend_endpoints()
|
||||
webhook_url = (
|
||||
f"{backend_endpoint}/api/v1/telephony/vobiz/hangup-callback/{workflow_run_id}"
|
||||
)
|
||||
await _verify_vobiz_callback(
|
||||
provider,
|
||||
webhook_url,
|
||||
callback_data,
|
||||
all_headers,
|
||||
raw_body,
|
||||
log_prefix=f"[run {workflow_run_id}]",
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"[run {workflow_run_id}] Processing Vobiz hangup with provider: {provider.PROVIDER_NAME}"
|
||||
)
|
||||
|
|
@ -167,10 +144,6 @@ async def handle_vobiz_hangup_callback(
|
|||
# Parse the callback data into generic format
|
||||
parsed_data = provider.parse_status_callback(callback_data)
|
||||
|
||||
logger.debug(
|
||||
f"[run {workflow_run_id}] Parsed Vobiz callback data: {json.dumps(parsed_data)}"
|
||||
)
|
||||
|
||||
# Create StatusCallbackRequest from parsed data
|
||||
status_update = StatusCallbackRequest(
|
||||
call_id=parsed_data["call_id"],
|
||||
|
|
@ -194,8 +167,6 @@ async def handle_vobiz_hangup_callback(
|
|||
async def handle_vobiz_ring_callback(
|
||||
workflow_run_id: int,
|
||||
request: Request,
|
||||
x_vobiz_signature: Optional[str] = Header(None),
|
||||
x_vobiz_timestamp: Optional[str] = Header(None),
|
||||
):
|
||||
"""Handle Vobiz ring callback (sent when call starts ringing).
|
||||
|
||||
|
|
@ -204,84 +175,46 @@ async def handle_vobiz_ring_callback(
|
|||
"""
|
||||
set_current_run_id(workflow_run_id)
|
||||
|
||||
# Logging all headers and body to understand what Vobiz actually sends
|
||||
all_headers = dict(request.headers)
|
||||
logger.info(
|
||||
f"[run {workflow_run_id}] Vobiz ring callback - Headers: {json.dumps(all_headers)}"
|
||||
)
|
||||
|
||||
# Parse the callback data from the raw body so signed webhooks can verify
|
||||
# the exact bytes Vobiz sent without draining the request stream first.
|
||||
callback_data, raw_body = await parse_webhook_request(request)
|
||||
|
||||
# TODO: Remove this debug logging after Vobiz team clarifies webhook authentication
|
||||
logger.info(
|
||||
f"[run {workflow_run_id}] Vobiz ring callback - Body: {json.dumps(callback_data)}"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[run {workflow_run_id}] Received Vobiz ring callback {json.dumps(callback_data)}"
|
||||
)
|
||||
|
||||
# Verify signature if Vobiz provided any supported signature header.
|
||||
has_vobiz_signature = any(
|
||||
header in all_headers
|
||||
for header in (
|
||||
"x-vobiz-signature-v3",
|
||||
"x-vobiz-signature-ma-v3",
|
||||
"x-vobiz-signature-v2",
|
||||
"x-vobiz-signature-ma-v2",
|
||||
)
|
||||
)
|
||||
if has_vobiz_signature:
|
||||
# We need the workflow run to get organization for provider credentials
|
||||
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
|
||||
if not workflow_run:
|
||||
logger.warning(
|
||||
f"[run {workflow_run_id}] Workflow run not found for signature verification"
|
||||
)
|
||||
return {"status": "error", "reason": "workflow_run_not_found"}
|
||||
|
||||
workflow = await db_client.get_workflow_by_id(workflow_run.workflow_id)
|
||||
if not workflow:
|
||||
logger.warning(
|
||||
f"[run {workflow_run_id}] Workflow not found for signature verification"
|
||||
)
|
||||
return {"status": "error", "reason": "workflow_not_found"}
|
||||
|
||||
provider = await get_telephony_provider_for_run(
|
||||
workflow_run, workflow.organization_id
|
||||
)
|
||||
|
||||
# Verify signature
|
||||
backend_endpoint, _ = await get_backend_endpoints()
|
||||
webhook_url = (
|
||||
f"{backend_endpoint}/api/v1/telephony/vobiz/ring-callback/{workflow_run_id}"
|
||||
)
|
||||
|
||||
is_valid = await provider.verify_inbound_signature(
|
||||
webhook_url,
|
||||
callback_data,
|
||||
all_headers,
|
||||
raw_body,
|
||||
)
|
||||
|
||||
if not is_valid:
|
||||
logger.warning(
|
||||
f"[run {workflow_run_id}] Invalid Vobiz ring callback signature"
|
||||
)
|
||||
return {"status": "error", "reason": "invalid_signature"}
|
||||
|
||||
logger.info(f"[run {workflow_run_id}] Vobiz ring callback signature verified")
|
||||
else:
|
||||
# Get workflow run for processing (signature verification already got it if needed)
|
||||
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
|
||||
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
|
||||
if not workflow_run:
|
||||
logger.warning(
|
||||
f"[run {workflow_run_id}] Workflow run not found for Vobiz ring callback"
|
||||
)
|
||||
return {"status": "ignored", "reason": "workflow_run_not_found"}
|
||||
|
||||
workflow = await db_client.get_workflow_by_id(workflow_run.workflow_id)
|
||||
if not workflow:
|
||||
logger.warning(f"[run {workflow_run_id}] Workflow not found")
|
||||
return {"status": "ignored", "reason": "workflow_not_found"}
|
||||
|
||||
provider = await get_telephony_provider_for_run(
|
||||
workflow_run, workflow.organization_id
|
||||
)
|
||||
|
||||
# Fail closed: reject unsigned/forged ring callbacks before logging them.
|
||||
backend_endpoint, _ = await get_backend_endpoints()
|
||||
webhook_url = (
|
||||
f"{backend_endpoint}/api/v1/telephony/vobiz/ring-callback/{workflow_run_id}"
|
||||
)
|
||||
await _verify_vobiz_callback(
|
||||
provider,
|
||||
webhook_url,
|
||||
callback_data,
|
||||
all_headers,
|
||||
raw_body,
|
||||
log_prefix=f"[run {workflow_run_id}]",
|
||||
)
|
||||
|
||||
# Log the ringing event
|
||||
telephony_callback_logs = workflow_run.logs.get("telephony_status_callbacks", [])
|
||||
ring_log = {
|
||||
|
|
@ -308,15 +241,10 @@ async def handle_vobiz_ring_callback(
|
|||
async def handle_vobiz_hangup_callback_by_workflow(
|
||||
workflow_id: int,
|
||||
request: Request,
|
||||
x_vobiz_signature: Optional[str] = Header(None),
|
||||
x_vobiz_timestamp: Optional[str] = Header(None),
|
||||
):
|
||||
"""Handle Vobiz hangup callback with workflow_id - finds workflow run by call_id."""
|
||||
|
||||
all_headers = dict(request.headers)
|
||||
logger.info(
|
||||
f"[workflow {workflow_id}] Vobiz hangup callback - Headers: {json.dumps(all_headers)}"
|
||||
)
|
||||
|
||||
try:
|
||||
callback_data, raw_body = await parse_webhook_request(request)
|
||||
|
|
@ -364,35 +292,18 @@ async def handle_vobiz_hangup_callback_by_workflow(
|
|||
workflow_run, workflow.organization_id
|
||||
)
|
||||
|
||||
has_vobiz_signature = any(
|
||||
header in all_headers
|
||||
for header in (
|
||||
"x-vobiz-signature-v3",
|
||||
"x-vobiz-signature-ma-v3",
|
||||
"x-vobiz-signature-v2",
|
||||
"x-vobiz-signature-ma-v2",
|
||||
)
|
||||
# Fail closed: Vobiz signs every callback, so reject unsigned/forged ones
|
||||
# before they can mutate call state.
|
||||
backend_endpoint, _ = await get_backend_endpoints()
|
||||
webhook_url = f"{backend_endpoint}/api/v1/telephony/vobiz/hangup-callback/workflow/{workflow_id}"
|
||||
await _verify_vobiz_callback(
|
||||
provider,
|
||||
webhook_url,
|
||||
callback_data,
|
||||
all_headers,
|
||||
raw_body,
|
||||
log_prefix=f"[workflow {workflow_id}]",
|
||||
)
|
||||
if has_vobiz_signature:
|
||||
backend_endpoint, _ = await get_backend_endpoints()
|
||||
webhook_url = f"{backend_endpoint}/api/v1/telephony/vobiz/hangup-callback/workflow/{workflow_id}"
|
||||
|
||||
is_valid = await provider.verify_inbound_signature(
|
||||
webhook_url,
|
||||
callback_data,
|
||||
all_headers,
|
||||
raw_body,
|
||||
)
|
||||
|
||||
if not is_valid:
|
||||
logger.warning(
|
||||
f"[workflow {workflow_id}] Invalid Vobiz hangup callback signature"
|
||||
)
|
||||
return {"status": "error", "message": "invalid_signature"}
|
||||
|
||||
logger.info(
|
||||
f"[workflow {workflow_id}] Vobiz hangup callback signature verified"
|
||||
)
|
||||
|
||||
try:
|
||||
parsed_data = provider.parse_status_callback(callback_data)
|
||||
|
|
|
|||
|
|
@ -66,34 +66,6 @@ async def handle_vonage_events(
|
|||
logger.error(f"[run {workflow_run_id}] Workflow run not found")
|
||||
return {"status": "error", "message": "Workflow run not found"}
|
||||
|
||||
# For a completed call that includes cost info, capture it immediately
|
||||
if event_data.get("status") == "completed":
|
||||
# Vonage sometimes includes price info in the webhook
|
||||
if "price" in event_data or "rate" in event_data:
|
||||
try:
|
||||
if workflow_run.cost_info:
|
||||
# Store immediate cost info if available
|
||||
cost_info = workflow_run.cost_info.copy()
|
||||
if "price" in event_data:
|
||||
cost_info["vonage_webhook_price"] = float(event_data["price"])
|
||||
if "rate" in event_data:
|
||||
cost_info["vonage_webhook_rate"] = float(event_data["rate"])
|
||||
if "duration" in event_data:
|
||||
cost_info["vonage_webhook_duration"] = int(
|
||||
event_data["duration"]
|
||||
)
|
||||
|
||||
await db_client.update_workflow_run(
|
||||
run_id=workflow_run_id, cost_info=cost_info
|
||||
)
|
||||
logger.info(
|
||||
f"[run {workflow_run_id}] Captured Vonage cost info from webhook"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[run {workflow_run_id}] Failed to capture Vonage cost from webhook: {e}"
|
||||
)
|
||||
|
||||
# Get workflow and provider
|
||||
workflow = await db_client.get_workflow_by_id(workflow_run.workflow_id)
|
||||
if not workflow:
|
||||
|
|
|
|||
|
|
@ -114,11 +114,13 @@ class StatusCallbackRequest(BaseModel):
|
|||
"NOANSWER": "no-answer",
|
||||
}
|
||||
|
||||
disposition = data.get("disposition", "")
|
||||
disposition = data.get("disposition") or ""
|
||||
status = disposition_map.get(disposition.upper(), disposition.lower())
|
||||
session = data.get("session")
|
||||
call_id = session.get("token") if isinstance(session, dict) else ""
|
||||
|
||||
return cls(
|
||||
call_id=data.get("session").get("token"),
|
||||
call_id=call_id or "",
|
||||
status=status,
|
||||
from_number=data.get("from"),
|
||||
to_number=data.get("to"),
|
||||
|
|
|
|||
|
|
@ -35,6 +35,7 @@ import asyncio
|
|||
|
||||
from loguru import logger
|
||||
|
||||
from api.services.managed_model_services import MPS_CORRELATION_ID_CONTEXT_KEY
|
||||
from api.services.workflow import pipecat_engine_callbacks as engine_callbacks
|
||||
from api.services.workflow.mcp_tool_session import McpToolSession
|
||||
from api.services.workflow.pipecat_engine_context_composer import (
|
||||
|
|
@ -382,6 +383,9 @@ class PipecatEngine:
|
|||
embeddings_provider=self._embeddings_provider,
|
||||
embeddings_endpoint=self._embeddings_endpoint,
|
||||
embeddings_api_version=self._embeddings_api_version,
|
||||
correlation_id=self._call_context_vars.get(
|
||||
MPS_CORRELATION_ID_CONTEXT_KEY
|
||||
),
|
||||
tracing_context=self._get_otel_context(),
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@
|
|||
|
||||
import random
|
||||
|
||||
from api.db import db_client
|
||||
from api.db.models import WorkflowRunModel
|
||||
from api.services.workflow.dto import QANodeData
|
||||
|
||||
|
|
@ -43,7 +42,7 @@ async def resolve_llm_config(
|
|||
async def resolve_user_llm_config(
|
||||
workflow_run: WorkflowRunModel,
|
||||
) -> tuple[str, str, str, dict]:
|
||||
"""Resolve the user's configured LLM (from UserConfiguration).
|
||||
"""Resolve the user's configured LLM (from EffectiveAIModelConfiguration).
|
||||
|
||||
Returns:
|
||||
(provider, model, api_key, service_kwargs) tuple
|
||||
|
|
@ -54,7 +53,27 @@ async def resolve_user_llm_config(
|
|||
|
||||
llm_config: dict = {}
|
||||
if user_id:
|
||||
user_configuration = await db_client.get_user_configurations(user_id)
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
get_effective_ai_model_configuration_for_workflow,
|
||||
)
|
||||
|
||||
workflow_configurations = {}
|
||||
if workflow_run.definition:
|
||||
workflow_configurations = (
|
||||
workflow_run.definition.workflow_configurations or {}
|
||||
)
|
||||
elif workflow_run.workflow:
|
||||
workflow_configurations = (
|
||||
workflow_run.workflow.workflow_configurations or {}
|
||||
)
|
||||
|
||||
user_configuration = await get_effective_ai_model_configuration_for_workflow(
|
||||
user_id=user_id,
|
||||
organization_id=workflow_run.workflow.organization_id
|
||||
if workflow_run.workflow
|
||||
else None,
|
||||
workflow_configurations=workflow_configurations,
|
||||
)
|
||||
llm_config = user_configuration.model_dump(exclude_none=True).get("llm", {})
|
||||
|
||||
provider = llm_config.get("provider", "openai")
|
||||
|
|
|
|||
41
api/services/workflow/run_usage_response.py
Normal file
41
api/services/workflow/run_usage_response.py
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
"""Format workflow run usage for public API responses."""
|
||||
|
||||
|
||||
def format_public_usage_info(usage_info: dict | None) -> dict | None:
|
||||
if not usage_info:
|
||||
return None
|
||||
|
||||
return {
|
||||
"llm": usage_info.get("llm") or {},
|
||||
"tts": usage_info.get("tts") or {},
|
||||
"stt": usage_info.get("stt") or {},
|
||||
"call_duration_seconds": usage_info.get("call_duration_seconds"),
|
||||
}
|
||||
|
||||
|
||||
def format_public_cost_info(
|
||||
cost_info: dict | None, usage_info: dict | None
|
||||
) -> dict | None:
|
||||
"""Return the legacy response shape without doing local cost accounting."""
|
||||
duration = None
|
||||
if usage_info and usage_info.get("call_duration_seconds") is not None:
|
||||
duration = int(round(usage_info.get("call_duration_seconds") or 0))
|
||||
elif cost_info and cost_info.get("call_duration_seconds") is not None:
|
||||
duration = int(round(cost_info.get("call_duration_seconds") or 0))
|
||||
|
||||
dograh_token_usage = 0
|
||||
if cost_info:
|
||||
if "dograh_token_usage" in cost_info:
|
||||
dograh_token_usage = cost_info.get("dograh_token_usage") or 0
|
||||
elif "total_cost_usd" in cost_info:
|
||||
dograh_token_usage = round(
|
||||
float(cost_info.get("total_cost_usd", 0)) * 100, 2
|
||||
)
|
||||
|
||||
if duration is None and dograh_token_usage == 0:
|
||||
return None
|
||||
|
||||
return {
|
||||
"dograh_token_usage": dograh_token_usage,
|
||||
"call_duration_seconds": duration,
|
||||
}
|
||||
|
|
@ -32,7 +32,6 @@ from pipecat.utils.run_context import set_current_org_id
|
|||
|
||||
from api.db import db_client
|
||||
from api.enums import WorkflowRunMode, WorkflowRunState
|
||||
from api.services.configuration.resolve import resolve_effective_config
|
||||
from api.services.pipecat.audio_config import create_audio_config
|
||||
from api.services.pipecat.pipeline_builder import create_pipeline_task
|
||||
from api.services.pipecat.pipeline_metrics_aggregator import (
|
||||
|
|
@ -410,14 +409,31 @@ async def execute_text_chat_pending_turn(
|
|||
run_definition = workflow_run.definition
|
||||
run_configs = run_definition.workflow_configurations or {}
|
||||
|
||||
user_config = await db_client.get_user_configurations(workflow_run.workflow.user.id)
|
||||
user_config = resolve_effective_config(
|
||||
user_config, run_configs.get("model_overrides")
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
get_effective_ai_model_configuration_for_workflow,
|
||||
)
|
||||
|
||||
user_config = await get_effective_ai_model_configuration_for_workflow(
|
||||
user_id=workflow_run.workflow.user.id,
|
||||
organization_id=workflow.organization_id,
|
||||
workflow_configurations=run_configs,
|
||||
)
|
||||
if user_config.llm is None:
|
||||
raise ValueError("Text chat requires an LLM configuration")
|
||||
|
||||
llm = create_llm_service(user_config)
|
||||
from api.services.managed_model_services import (
|
||||
MPS_CORRELATION_ID_CONTEXT_KEY,
|
||||
ensure_mps_correlation_id,
|
||||
)
|
||||
|
||||
base_initial_context = dict(workflow_run.initial_context or {})
|
||||
mps_correlation_id = await ensure_mps_correlation_id(
|
||||
ai_model_config=user_config,
|
||||
workflow_run_id=workflow_run_id,
|
||||
initial_context=base_initial_context,
|
||||
)
|
||||
|
||||
llm = create_llm_service(user_config, correlation_id=mps_correlation_id)
|
||||
inference_llm = llm
|
||||
|
||||
runtime_configuration = {
|
||||
|
|
@ -425,9 +441,15 @@ async def execute_text_chat_pending_turn(
|
|||
"llm_model": user_config.llm.model,
|
||||
}
|
||||
initial_context = {
|
||||
**(workflow_run.initial_context or {}),
|
||||
**base_initial_context,
|
||||
"runtime_configuration": runtime_configuration,
|
||||
}
|
||||
if mps_correlation_id:
|
||||
initial_context[MPS_CORRELATION_ID_CONTEXT_KEY] = mps_correlation_id
|
||||
await db_client.update_workflow_run(
|
||||
workflow_run_id,
|
||||
initial_context=initial_context,
|
||||
)
|
||||
|
||||
workflow_graph = WorkflowGraph(
|
||||
ReactFlowDTO.model_validate(run_definition.workflow_json)
|
||||
|
|
@ -466,9 +488,17 @@ async def execute_text_chat_pending_turn(
|
|||
embeddings_model = None
|
||||
embeddings_base_url = None
|
||||
if user_config.embeddings:
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
apply_managed_embeddings_base_url,
|
||||
)
|
||||
|
||||
embeddings_api_key = user_config.embeddings.api_key
|
||||
embeddings_model = user_config.embeddings.model
|
||||
embeddings_base_url = getattr(user_config.embeddings, "base_url", None)
|
||||
embeddings_provider = getattr(user_config.embeddings, "provider", None)
|
||||
embeddings_base_url = apply_managed_embeddings_base_url(
|
||||
provider=embeddings_provider,
|
||||
base_url=getattr(user_config.embeddings, "base_url", None),
|
||||
)
|
||||
|
||||
has_recordings = await db_client.has_active_recordings(workflow.organization_id)
|
||||
context_compaction_enabled = (workflow.workflow_configurations or {}).get(
|
||||
|
|
@ -606,8 +636,10 @@ async def execute_text_chat_pending_turn(
|
|||
"Transportless text chat pipeline failed while closing run {}",
|
||||
workflow_run_id,
|
||||
)
|
||||
await engine.close_mcp_sessions()
|
||||
await engine.cleanup()
|
||||
raise
|
||||
await engine.close_mcp_sessions()
|
||||
await engine.cleanup()
|
||||
|
||||
gathered_context = await engine.get_gathered_context()
|
||||
|
|
|
|||
|
|
@ -4,17 +4,11 @@ from datetime import UTC, datetime
|
|||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from api.db import db_client
|
||||
from api.db.models import WorkflowRunTextSessionModel
|
||||
from api.db.workflow_run_text_session_client import (
|
||||
WorkflowRunTextSessionRevisionConflictError,
|
||||
)
|
||||
from api.services.pricing.workflow_run_cost import (
|
||||
apply_usage_delta_to_organization,
|
||||
build_workflow_run_cost_info,
|
||||
)
|
||||
from api.services.workflow.text_chat_logs import (
|
||||
build_text_chat_realtime_feedback_events,
|
||||
)
|
||||
|
|
@ -261,20 +255,6 @@ async def execute_pending_text_chat_turn(
|
|||
state=execution.state,
|
||||
is_completed=execution.is_completed,
|
||||
)
|
||||
workflow_run = await db_client.get_workflow_run_by_id(run_id)
|
||||
if workflow_run:
|
||||
try:
|
||||
# Apply the per-turn delta so org usage tracks cumulative run cost
|
||||
# without replaying the full session totals on every turn.
|
||||
await apply_usage_delta_to_organization(workflow_run, execution.usage)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to update organization usage for text chat run {run_id}: {e}"
|
||||
)
|
||||
|
||||
cost_info = await build_workflow_run_cost_info(workflow_run)
|
||||
if cost_info is not None:
|
||||
await db_client.update_workflow_run(run_id, cost_info=cost_info)
|
||||
|
||||
return await _reload_text_chat_session(run_id)
|
||||
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ async def retrieve_from_knowledge_base(
|
|||
embeddings_provider: Optional[str] = None,
|
||||
embeddings_endpoint: Optional[str] = None,
|
||||
embeddings_api_version: Optional[str] = None,
|
||||
correlation_id: Optional[str] = None,
|
||||
tracing_context=None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Retrieve relevant information from the knowledge base using vector similarity search.
|
||||
|
|
@ -75,6 +76,7 @@ async def retrieve_from_knowledge_base(
|
|||
embeddings_provider,
|
||||
embeddings_endpoint,
|
||||
embeddings_api_version,
|
||||
correlation_id,
|
||||
)
|
||||
|
||||
# Create span with parent context
|
||||
|
|
@ -115,6 +117,7 @@ async def retrieve_from_knowledge_base(
|
|||
embeddings_provider,
|
||||
embeddings_endpoint,
|
||||
embeddings_api_version,
|
||||
correlation_id,
|
||||
)
|
||||
|
||||
# Add result metadata to span
|
||||
|
|
@ -192,6 +195,7 @@ async def retrieve_from_knowledge_base(
|
|||
embeddings_provider,
|
||||
embeddings_endpoint,
|
||||
embeddings_api_version,
|
||||
correlation_id,
|
||||
)
|
||||
else:
|
||||
# Tracing is disabled - perform retrieval without tracing
|
||||
|
|
@ -206,6 +210,7 @@ async def retrieve_from_knowledge_base(
|
|||
embeddings_provider,
|
||||
embeddings_endpoint,
|
||||
embeddings_api_version,
|
||||
correlation_id,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -220,6 +225,7 @@ async def _perform_retrieval(
|
|||
embeddings_provider: Optional[str] = None,
|
||||
embeddings_endpoint: Optional[str] = None,
|
||||
embeddings_api_version: Optional[str] = None,
|
||||
correlation_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Internal function to perform the actual retrieval operation.
|
||||
|
||||
|
|
@ -272,11 +278,20 @@ async def _perform_retrieval(
|
|||
api_version=embeddings_api_version or "2024-02-15-preview",
|
||||
)
|
||||
else:
|
||||
default_headers = None
|
||||
if (
|
||||
embeddings_provider == ServiceProviders.DOGRAH.value
|
||||
and correlation_id
|
||||
):
|
||||
default_headers = {
|
||||
"X-Dograh-Correlation-Id": correlation_id,
|
||||
}
|
||||
embedding_service = OpenAIEmbeddingService(
|
||||
db_client=db_client,
|
||||
api_key=embeddings_api_key,
|
||||
model_id=embeddings_model or "text-embedding-3-small",
|
||||
base_url=embeddings_base_url,
|
||||
default_headers=default_headers,
|
||||
)
|
||||
|
||||
results = await embedding_service.search_similar_chunks(
|
||||
|
|
|
|||
111
api/services/workflow_run_billing.py
Normal file
111
api/services/workflow_run_billing.py
Normal file
|
|
@ -0,0 +1,111 @@
|
|||
"""Workflow-run billing hooks.
|
||||
|
||||
Dograh does not rate or deduct credits locally. MPS owns credit accounting.
|
||||
For hosted deployments, Dograh reports completed platform usage to MPS.
|
||||
When a server-minted MPS correlation id exists, MPS uses model-service usage
|
||||
as the canonical duration. Otherwise Dograh reports the completed run duration.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from api.constants import DEPLOYMENT_MODE
|
||||
from api.db import db_client
|
||||
from api.services.managed_model_services import get_mps_correlation_id
|
||||
from api.services.mps_service_key_client import mps_service_key_client
|
||||
|
||||
|
||||
def _workflow_run_organization_id(workflow_run) -> int | None:
|
||||
workflow = getattr(workflow_run, "workflow", None)
|
||||
return getattr(workflow, "organization_id", None)
|
||||
|
||||
|
||||
def _duration_seconds_from_usage_info(workflow_run) -> float | None:
|
||||
usage_info: dict[str, Any] = getattr(workflow_run, "usage_info", None) or {}
|
||||
duration = usage_info.get("call_duration_seconds")
|
||||
try:
|
||||
duration_seconds = float(duration)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
return duration_seconds if duration_seconds > 0 else None
|
||||
|
||||
|
||||
async def _organization_uses_mps_billing_v2(organization_id: int) -> bool:
|
||||
account = await mps_service_key_client.get_billing_account_status(
|
||||
organization_id=organization_id
|
||||
)
|
||||
return bool(account and account.get("billing_mode") == "v2")
|
||||
|
||||
|
||||
async def report_workflow_run_platform_usage(workflow_run) -> None:
|
||||
"""Report hosted platform usage for a completed workflow run to MPS."""
|
||||
if DEPLOYMENT_MODE == "oss":
|
||||
return
|
||||
|
||||
if not getattr(workflow_run, "is_completed", False):
|
||||
return
|
||||
|
||||
organization_id = _workflow_run_organization_id(workflow_run)
|
||||
if organization_id is None:
|
||||
logger.warning(
|
||||
"Skipping platform usage report for workflow run {}: no organization_id",
|
||||
workflow_run.id,
|
||||
)
|
||||
return
|
||||
|
||||
correlation_id = get_mps_correlation_id(
|
||||
getattr(workflow_run, "initial_context", None)
|
||||
)
|
||||
duration_seconds = (
|
||||
None if correlation_id else _duration_seconds_from_usage_info(workflow_run)
|
||||
)
|
||||
if not correlation_id and duration_seconds is None:
|
||||
logger.warning(
|
||||
"Skipping platform usage report for workflow run {}: no billable duration",
|
||||
workflow_run.id,
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
if not await _organization_uses_mps_billing_v2(organization_id):
|
||||
return
|
||||
|
||||
result = await mps_service_key_client.report_platform_usage(
|
||||
organization_id=organization_id,
|
||||
correlation_id=correlation_id,
|
||||
duration_seconds=duration_seconds,
|
||||
workflow_run_id=workflow_run.id,
|
||||
metadata={
|
||||
"source": "workflow_run_completion",
|
||||
"workflow_id": getattr(workflow_run, "workflow_id", None),
|
||||
"duration_source": (
|
||||
"mps_correlation" if correlation_id else "dograh_usage_info"
|
||||
),
|
||||
},
|
||||
)
|
||||
logger.info(
|
||||
"Reported platform usage for workflow run {} to MPS: {}",
|
||||
workflow_run.id,
|
||||
result,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to report platform usage for workflow run {}: {}",
|
||||
workflow_run.id,
|
||||
e,
|
||||
)
|
||||
|
||||
|
||||
async def report_completed_workflow_run_platform_usage(workflow_run_id: int) -> None:
|
||||
"""Load a completed workflow run and report platform usage to MPS."""
|
||||
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
|
||||
if not workflow_run:
|
||||
logger.warning(
|
||||
"Skipping platform usage report: workflow run {} not found",
|
||||
workflow_run_id,
|
||||
)
|
||||
return
|
||||
|
||||
await report_workflow_run_platform_usage(workflow_run)
|
||||
Loading…
Add table
Add a link
Reference in a new issue