feat: centralise workflow run authorization

This commit is contained in:
Abhishek Kumar 2026-06-12 18:16:30 +05:30
parent 5bf7518829
commit 281656b960
21 changed files with 1036 additions and 252 deletions

View file

@ -15,6 +15,7 @@ from api.services.campaign.errors import (
PhoneNumberPoolExhaustedError,
)
from api.services.campaign.rate_limiter import rate_limiter
from api.services.quota_service import authorize_workflow_run_start
from api.utils.common import get_backend_endpoints
if TYPE_CHECKING:
@ -339,6 +340,41 @@ class CampaignCallDispatcher:
},
)
quota_result = await authorize_workflow_run_start(
workflow_id=campaign.workflow_id,
workflow_run_id=workflow_run.id,
)
if not quota_result.has_quota:
error_message = quota_result.error_message or "Quota exceeded"
logger.warning(
f"Campaign {campaign.id} quota check failed for workflow run "
f"{workflow_run.id}: {error_message}"
)
await db_client.update_workflow_run(
run_id=workflow_run.id,
is_completed=True,
state=WorkflowRunState.COMPLETED.value,
gathered_context={"error": error_message},
)
mapping = await rate_limiter.get_workflow_slot_mapping(workflow_run.id)
if mapping:
org_id, mapped_slot_id = mapping
await rate_limiter.release_concurrent_slot(org_id, mapped_slot_id)
await rate_limiter.delete_workflow_slot_mapping(workflow_run.id)
from_number_mapping = await rate_limiter.get_workflow_from_number_mapping(
workflow_run.id
)
if from_number_mapping:
fn_org_id, fn_number, fn_tcid = from_number_mapping
await rate_limiter.release_from_number(
fn_org_id, fn_number, telephony_configuration_id=fn_tcid
)
await rate_limiter.delete_workflow_from_number_mapping(workflow_run.id)
raise ValueError(error_message)
# Initiate call via telephony provider
try:
# Construct webhook URL with parameters

View file

@ -2,11 +2,8 @@ from __future__ import annotations
from typing import Any
from loguru import logger
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
from api.services.configuration.registry import ServiceProviders
from api.services.mps_service_key_client import mps_service_key_client
MPS_CORRELATION_ID_CONTEXT_KEY = "mps_correlation_id"
@ -48,27 +45,10 @@ async def ensure_mps_correlation_id(
if not uses_managed_model_services_v2(ai_model_config):
return None
service_key = _get_dograh_service_api_key(ai_model_config)
if not service_key:
raise ValueError(
"Managed model services v2 requires a Dograh service key before the run starts."
)
response = await mps_service_key_client.create_correlation_id(
service_key=service_key,
workflow_run_id=workflow_run_id,
raise ValueError(
"Managed model services v2 requires workflow run authorization before "
f"the run starts. Missing correlation id for workflow_run_id={workflow_run_id}."
)
correlation_id = response.get("correlation_id")
if not correlation_id:
raise ValueError("MPS correlation-id response did not include correlation_id")
correlation_id = str(correlation_id)
logger.info(
"Minted MPS correlation id {} for workflow run {}",
correlation_id,
workflow_run_id,
)
return correlation_id
def _is_dograh_service(service: Any) -> bool:
@ -78,7 +58,7 @@ def _is_dograh_service(service: Any) -> bool:
)
def _get_dograh_service_api_key(
def get_dograh_service_api_key(
ai_model_config: EffectiveAIModelConfiguration,
) -> str | None:
for section_name in ("llm", "tts", "stt", "embeddings"):

View file

@ -478,6 +478,50 @@ class MPSServiceKeyClient:
response=response,
)
async def authorize_workflow_run_start(
self,
*,
organization_id: int,
workflow_run_id: int | None = None,
service_key: Optional[str] = None,
require_correlation_id: bool = False,
minimum_credits: float | None = None,
metadata: Optional[dict] = None,
created_by: Optional[str] = None,
) -> dict:
"""Authorize a hosted workflow run and optionally mint its MPS correlation."""
payload = {
"workflow_run_id": workflow_run_id,
"service_key": service_key,
"require_correlation_id": require_correlation_id,
"metadata": metadata or {},
}
if minimum_credits is not None:
payload["minimum_credits"] = minimum_credits
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.post(
f"{self.base_url}/api/v1/billing/accounts/{organization_id}/run-authorization",
json=payload,
headers=self._get_headers(
organization_id=organization_id,
created_by=created_by,
),
)
if response.status_code == 200:
return response.json()
logger.error(
"Failed to authorize MPS workflow run start: "
f"{response.status_code} - {response.text}"
)
raise httpx.HTTPStatusError(
f"Failed to authorize MPS workflow run start: {response.text}",
request=response.request,
response=response,
)
async def create_correlation_id(
self,
*,

View file

@ -5,17 +5,38 @@ across different endpoints (WebRTC signaling, telephony, public API triggers).
"""
from dataclasses import dataclass
from typing import Any
from loguru import logger
from api.constants import DEPLOYMENT_MODE
from api.db import db_client
from api.db.models import UserModel
from api.services.configuration.ai_model_configuration import (
get_effective_ai_model_configuration_for_workflow,
)
from api.services.configuration.registry import ServiceProviders
from api.services.managed_model_services import (
MPS_CORRELATION_ID_CONTEXT_KEY,
get_dograh_service_api_key,
uses_managed_model_services_v2,
)
from api.services.mps_service_key_client import mps_service_key_client
MINIMUM_DOGRAH_CREDITS_FOR_CALL = 0.10
LEGACY_QUOTA_EXCEEDED_MESSAGE = (
"You have exhausted your trial credits. "
"Please email founders@dograh.com for additional Dograh credits "
"or change providers in Models configurations."
)
BILLING_V2_QUOTA_EXCEEDED_MESSAGE = (
"You have exhausted your Dograh credits. "
"Please purchase more credits from /billing "
"or change providers in Models configurations."
)
@dataclass
class QuotaCheckResult:
@ -26,116 +47,359 @@ class QuotaCheckResult:
error_code: str = ""
async def check_dograh_quota(
user: UserModel, workflow_id: int | None = None
) -> QuotaCheckResult:
"""Check if user has sufficient Dograh quota for making a call.
This function checks if the user is using any Dograh services (LLM, STT, TTS)
and validates that they have sufficient credits remaining.
When ``workflow_id`` is provided, the workflow's per-workflow
``model_overrides`` are merged onto the user's global config so the quota
check runs against the credentials that will actually be used for the call
(rather than always falling back to the user's defaults).
Args:
user: The user to check quota for
workflow_id: Optional workflow whose ``model_overrides`` should be
applied when resolving the effective service config.
Returns:
QuotaCheckResult with has_quota=True if user has sufficient quota or
is not using Dograh services, or has_quota=False with error_message
if quota is insufficient.
"""
def _safe_float(value: Any, default: float = 0.0) -> float:
try:
organization_id = user.selected_organization_id
workflow_configurations = None
return float(value)
except (TypeError, ValueError):
return default
if workflow_id is not None:
workflow = await db_client.get_workflow_by_id(workflow_id)
if workflow:
organization_id = workflow.organization_id
workflow_configurations = workflow.workflow_configurations
user_config = await get_effective_ai_model_configuration_for_workflow(
user_id=user.id,
organization_id=organization_id,
workflow_configurations=workflow_configurations,
def _insufficient_billing_v2_quota_result() -> QuotaCheckResult:
return QuotaCheckResult(
has_quota=False,
error_code="insufficient_credits",
error_message=BILLING_V2_QUOTA_EXCEEDED_MESSAGE,
)
def _insufficient_legacy_quota_result() -> QuotaCheckResult:
return QuotaCheckResult(
has_quota=False,
error_code="quota_exceeded",
error_message=LEGACY_QUOTA_EXCEEDED_MESSAGE,
)
def _service_uses_dograh(service: Any) -> bool:
provider = getattr(service, "provider", None)
return (
provider == ServiceProviders.DOGRAH or provider == ServiceProviders.DOGRAH.value
)
def _dograh_api_keys(user_config: Any) -> set[str]:
api_keys: set[str] = set()
for section_name in ("llm", "stt", "tts", "embeddings"):
service = getattr(user_config, section_name, None)
if not _service_uses_dograh(service):
continue
if hasattr(service, "get_all_api_keys"):
all_api_keys = [
api_key
for api_key in service.get_all_api_keys()
if isinstance(api_key, str) and api_key
]
if all_api_keys:
api_keys.update(all_api_keys)
continue
api_key = getattr(service, "api_key", None)
if api_key:
api_keys.add(api_key)
return api_keys
async def _store_run_correlation_id(
workflow_run_id: int | None,
correlation_id: str | None,
) -> None:
if not workflow_run_id or not correlation_id:
return
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
if not workflow_run:
logger.warning(
"Could not store MPS correlation id for missing workflow run {}",
workflow_run_id,
)
return
initial_context = dict(workflow_run.initial_context or {})
if initial_context.get(MPS_CORRELATION_ID_CONTEXT_KEY) == correlation_id:
return
initial_context[MPS_CORRELATION_ID_CONTEXT_KEY] = correlation_id
await db_client.update_workflow_run(
workflow_run_id,
initial_context=initial_context,
)
async def _authorize_hosted_workflow_run_start(
*,
workflow_owner: UserModel,
organization_id: int | None,
workflow_id: int | None,
workflow_run_id: int | None,
user_config: Any,
) -> tuple[QuotaCheckResult, bool]:
"""Authorize hosted v2 billing and return whether MPS handled enforcement."""
if DEPLOYMENT_MODE == "oss" or organization_id is None:
return QuotaCheckResult(has_quota=True), False
requires_correlation = bool(
workflow_run_id and uses_managed_model_services_v2(user_config)
)
service_key = (
get_dograh_service_api_key(user_config) if requires_correlation else None
)
if requires_correlation and not service_key:
return (
QuotaCheckResult(
has_quota=False,
error_code="invalid_service_key",
error_message=(
"You have invalid keys in your model configuration. "
"Please validate the service keys."
),
),
True,
)
# Check if user is using any Dograh service
using_dograh = False
dograh_api_keys = set()
try:
authorization = await mps_service_key_client.authorize_workflow_run_start(
organization_id=organization_id,
workflow_run_id=workflow_run_id,
service_key=service_key,
require_correlation_id=requires_correlation,
minimum_credits=MINIMUM_DOGRAH_CREDITS_FOR_CALL,
created_by=(
str(workflow_owner.provider_id)
if workflow_owner.provider_id is not None
else None
),
metadata={
"dograh_user_id": str(workflow_owner.id),
"workflow_id": workflow_id,
},
)
except Exception as e:
logger.error(
"Failed to authorize workflow start with MPS for org {}: {}",
organization_id,
e,
)
return (
QuotaCheckResult(
has_quota=False,
error_code="quota_check_failed",
error_message="Could not verify Dograh credits. Please try again.",
),
True,
)
if user_config.llm and user_config.llm.provider == ServiceProviders.DOGRAH:
using_dograh = True
dograh_api_keys.add(user_config.llm.api_key)
billing_mode = authorization.get("billing_mode")
if billing_mode != "v2":
return QuotaCheckResult(has_quota=True), False
if user_config.stt and user_config.stt.provider == ServiceProviders.DOGRAH:
using_dograh = True
dograh_api_keys.add(user_config.stt.api_key)
remaining = _safe_float(authorization.get("remaining_credits"))
if (
not authorization.get("allowed", False)
or remaining < MINIMUM_DOGRAH_CREDITS_FOR_CALL
):
logger.warning(
"Insufficient Dograh billing v2 credits for org {}: {:.2f} credits remaining",
organization_id,
remaining,
)
return _insufficient_billing_v2_quota_result(), True
if user_config.tts and user_config.tts.provider == ServiceProviders.DOGRAH:
using_dograh = True
dograh_api_keys.add(user_config.tts.api_key)
try:
await _store_run_correlation_id(
workflow_run_id,
authorization.get("correlation_id"),
)
except Exception as e:
logger.error(
"Failed to store MPS correlation id for workflow_run_id {}: {}",
workflow_run_id,
e,
)
return (
QuotaCheckResult(
has_quota=False,
error_code="quota_check_failed",
error_message="Could not verify Dograh credits. Please try again.",
),
True,
)
logger.info(
"Dograh billing v2 run authorization passed for org {}: {:.2f} credits remaining",
organization_id,
remaining,
)
return QuotaCheckResult(has_quota=True), True
if (
user_config.embeddings
and user_config.embeddings.provider == ServiceProviders.DOGRAH
):
using_dograh = True
dograh_api_keys.add(user_config.embeddings.api_key)
# If not using Dograh, quota check passes
if not using_dograh:
return QuotaCheckResult(has_quota=True)
async def _authorize_legacy_dograh_keys(
*,
dograh_api_keys: set[str],
organization_id: int | None,
workflow_owner: UserModel,
) -> QuotaCheckResult:
for api_key in dograh_api_keys:
try:
usage = await mps_service_key_client.check_service_key_usage(
api_key,
organization_id=organization_id,
created_by=workflow_owner.provider_id,
)
remaining = usage.get("remaining_credits", 0.0)
# Check quota for ALL Dograh keys
for api_key in dograh_api_keys:
try:
usage = await mps_service_key_client.check_service_key_usage(
api_key,
organization_id=organization_id,
created_by=user.provider_id,
# Require at least $0.10 for a short call
if remaining < MINIMUM_DOGRAH_CREDITS_FOR_CALL:
logger.warning(
f"Insufficient Dograh credits for key ...{api_key[-8:]}: "
f"${remaining:.2f} remaining"
)
remaining = usage.get("remaining_credits", 0.0)
return _insufficient_legacy_quota_result()
# Require at least $0.10 for a short call
if remaining < 0.10:
logger.warning(
f"Insufficient Dograh credits for key ...{api_key[-8:]}: "
f"${remaining:.2f} remaining"
)
return QuotaCheckResult(
has_quota=False,
error_code="quota_exceeded",
error_message=(
"You have exhausted your trial credits. "
"Please email founders@dograh.com for additional Dograh credits "
"or change providers in Models configurations."
),
)
logger.info(
f"Dograh quota check passed for key ...{api_key[-8:]}: "
f"{remaining:.2f} credits remaining"
)
except Exception as e:
logger.error(f"Failed to check quota for Dograh key: {str(e)}")
error_str = str(e)
if "404" in error_str or "not found" in error_str.lower():
return QuotaCheckResult(
has_quota=False,
error_code="invalid_service_key",
error_message="You have invalid keys in your model configuration. Please validate the service keys.",
)
logger.info(
f"Dograh quota check passed for key ...{api_key[-8:]}: "
f"{remaining:.2f} credits remaining"
)
except Exception as e:
logger.error(f"Failed to check quota for Dograh key: {str(e)}")
error_str = str(e)
if "404" in error_str or "not found" in error_str.lower():
return QuotaCheckResult(
has_quota=False,
error_code="quota_check_failed",
error_message="Could not verify Dograh credits. Please try again.",
error_code="invalid_service_key",
error_message="You have invalid keys in your model configuration. Please validate the service keys.",
)
return QuotaCheckResult(
has_quota=False,
error_code="quota_check_failed",
error_message="Could not verify Dograh credits. Please try again.",
)
return QuotaCheckResult(has_quota=True)
async def _authorize_oss_managed_v2_correlation(
*,
workflow_id: int,
workflow_run_id: int | None,
user_config: Any,
) -> QuotaCheckResult:
if not workflow_run_id or not uses_managed_model_services_v2(user_config):
return QuotaCheckResult(has_quota=True)
service_key = get_dograh_service_api_key(user_config)
if not service_key:
return QuotaCheckResult(
has_quota=False,
error_code="invalid_service_key",
error_message=(
"You have invalid keys in your model configuration. "
"Please validate the service keys."
),
)
try:
response = await mps_service_key_client.create_correlation_id(
service_key=service_key,
workflow_run_id=workflow_run_id,
)
await _store_run_correlation_id(
workflow_run_id,
response.get("correlation_id"),
)
except Exception as e:
logger.error(
"Failed to authorize OSS managed v2 workflow start for workflow {} run {}: {}",
workflow_id,
workflow_run_id,
e,
)
return QuotaCheckResult(
has_quota=False,
error_code="quota_check_failed",
error_message="Could not verify Dograh credits. Please try again.",
)
return QuotaCheckResult(has_quota=True)
async def authorize_workflow_run_start(
*,
workflow_id: int,
workflow_run_id: int | None = None,
actor_user: UserModel | None = None,
) -> QuotaCheckResult:
"""Authorize a workflow run before any billable call/text runtime starts.
The workflow organization is the billing subject for hosted v2. The workflow
owner is used only to resolve the effective model configuration and legacy
service-key metadata.
"""
try:
workflow = await db_client.get_workflow_by_id(workflow_id)
if not workflow:
return QuotaCheckResult(
has_quota=False,
error_code="workflow_not_found",
error_message="Workflow not found",
)
actor_org_id = getattr(actor_user, "selected_organization_id", None)
if actor_org_id is not None and actor_org_id != workflow.organization_id:
logger.warning(
"Workflow start authorization denied: actor org {} does not match workflow {} org {}",
actor_org_id,
workflow_id,
workflow.organization_id,
)
return QuotaCheckResult(
has_quota=False,
error_code="workflow_not_found",
error_message="Workflow not found",
)
workflow_owner = await db_client.get_user_by_id(workflow.user_id)
if not workflow_owner:
return QuotaCheckResult(
has_quota=False,
error_code="user_not_found",
error_message="User not found",
)
user_config = await get_effective_ai_model_configuration_for_workflow(
user_id=workflow_owner.id,
organization_id=workflow.organization_id,
workflow_configurations=workflow.workflow_configurations,
)
if DEPLOYMENT_MODE != "oss":
hosted_result, hosted_enforced = await _authorize_hosted_workflow_run_start(
workflow_owner=workflow_owner,
organization_id=workflow.organization_id,
workflow_id=workflow.id,
workflow_run_id=workflow_run_id,
user_config=user_config,
)
if hosted_enforced or not hosted_result.has_quota:
return hosted_result
dograh_api_keys = _dograh_api_keys(user_config)
if not dograh_api_keys:
return QuotaCheckResult(has_quota=True)
legacy_result = await _authorize_legacy_dograh_keys(
dograh_api_keys=dograh_api_keys,
organization_id=(
None if DEPLOYMENT_MODE == "oss" else workflow.organization_id
),
workflow_owner=workflow_owner,
)
if not legacy_result.has_quota:
return legacy_result
if DEPLOYMENT_MODE == "oss":
return await _authorize_oss_managed_v2_correlation(
workflow_id=workflow.id,
workflow_run_id=workflow_run_id,
user_config=user_config,
)
return QuotaCheckResult(has_quota=True)
@ -143,30 +407,3 @@ async def check_dograh_quota(
logger.error(f"Error during quota check: {str(e)}")
# On unexpected error, allow the call to proceed
return QuotaCheckResult(has_quota=True)
async def check_dograh_quota_by_user_id(
user_id: int, workflow_id: int | None = None
) -> QuotaCheckResult:
"""Check Dograh quota by user ID.
Convenience function that fetches the user and then checks quota. When
``workflow_id`` is provided, the workflow's ``model_overrides`` are
applied so the quota check evaluates the credentials that will actually
be used for the call.
Args:
user_id: The ID of the user to check quota for
workflow_id: Optional workflow whose per-workflow overrides should
be applied to the user's config before checking quota.
Returns:
QuotaCheckResult with quota status
"""
user = await db_client.get_user_by_id(user_id)
if not user:
return QuotaCheckResult(
has_quota=False,
error_message="User not found",
)
return await check_dograh_quota(user, workflow_id=workflow_id)

View file

@ -26,7 +26,7 @@ from loguru import logger
from api.constants import REDIS_URL
from api.db import db_client
from api.enums import CallType, WorkflowRunMode
from api.services.quota_service import check_dograh_quota_by_user_id
from api.services.quota_service import authorize_workflow_run_start
from api.services.telephony.call_transfer_manager import get_call_transfer_manager
from api.services.telephony.transfer_event_protocol import (
TransferEvent,
@ -564,19 +564,7 @@ class ARIConnection:
user_id = workflow.user_id
# 3. Check quota (apply per-workflow model_overrides).
quota_result = await check_dograh_quota_by_user_id(
user_id, workflow_id=inbound_workflow_id
)
if not quota_result.has_quota:
logger.warning(
f"[ARI org={self.organization_id}] Quota exceeded for user {user_id} "
f"— hanging up inbound call {channel_id}"
)
await self._delete_channel(channel_id)
return
# 4. Create workflow run
# 3. Create workflow run
call_id = channel_id
workflow_run = await db_client.create_workflow_run(
name=f"ARI Inbound {caller_number}",
@ -602,6 +590,20 @@ class ARIConnection:
f"(caller={caller_number}, called={called_number})"
)
# 4. Check quota after the run exists so hosted v2 can mint and
# store the MPS correlation id before the pipeline starts.
quota_result = await authorize_workflow_run_start(
workflow_id=inbound_workflow_id,
workflow_run_id=workflow_run.id,
)
if not quota_result.has_quota:
logger.warning(
f"[ARI org={self.organization_id}] Quota exceeded for user {user_id} "
f"— hanging up inbound call {channel_id}"
)
await self._delete_channel(channel_id)
return
# 5. Answer the inbound channel
await self._answer_channel(channel_id)