feat: centralise workflow run authorization

This commit is contained in:
Abhishek Kumar 2026-06-12 18:16:30 +05:30
parent 5bf7518829
commit 281656b960
21 changed files with 1036 additions and 252 deletions

View file

@ -22,7 +22,7 @@ from starlette.websockets import WebSocketDisconnect
from api.db import db_client
from api.enums import CallType, WorkflowRunState
from api.services.quota_service import check_dograh_quota_by_user_id
from api.services.quota_service import authorize_workflow_run_start
from api.services.telephony import registry as telephony_registry
router = APIRouter(prefix="/agent-stream")
@ -67,19 +67,6 @@ async def agent_stream_websocket(
await websocket.close(code=1008, reason="Workflow not found")
return
quota_result = await check_dograh_quota_by_user_id(
workflow.user_id, workflow_id=workflow.id
)
if not quota_result.has_quota:
logger.warning(
f"agent-stream quota exceeded for user {workflow.user_id}: "
f"{quota_result.error_message}"
)
await websocket.close(
code=1008, reason=quota_result.error_message or "Quota exceeded"
)
return
numeric_suffix = int(str(uuid.uuid4()).replace("-", "")[:8], 16) % 100000000
workflow_run_name = f"WR-AGS-{numeric_suffix:08d}"
call_id = params.get("callId") or params.get("CallSid")
@ -108,6 +95,20 @@ async def agent_stream_websocket(
set_current_run_id(workflow_run.id)
set_current_org_id(workflow.organization_id)
quota_result = await authorize_workflow_run_start(
workflow_id=workflow.id,
workflow_run_id=workflow_run.id,
)
if not quota_result.has_quota:
logger.warning(
f"agent-stream quota exceeded for user {workflow.user_id}: "
f"{quota_result.error_message}"
)
await websocket.close(
code=1008, reason=quota_result.error_message or "Quota exceeded"
)
return
await db_client.update_workflow_run(
run_id=workflow_run.id, state=WorkflowRunState.RUNNING.value
)

View file

@ -18,7 +18,7 @@ from api.services.auth.depends import get_user
from api.services.campaign.runner import campaign_runner_service
from api.services.campaign.source_sync import CampaignSourceSyncService
from api.services.campaign.source_sync_factory import get_sync_service
from api.services.quota_service import check_dograh_quota
from api.services.quota_service import authorize_workflow_run_start
from api.services.reports import generate_campaign_report_csv
from api.services.storage import storage_fs
@ -550,7 +550,10 @@ async def start_campaign(
# Check Dograh quota before starting campaign (apply per-workflow
# model_overrides so we evaluate the keys this campaign will use).
quota_result = await check_dograh_quota(user, workflow_id=campaign.workflow_id)
quota_result = await authorize_workflow_run_start(
workflow_id=campaign.workflow_id,
actor_user=user,
)
if not quota_result.has_quota:
raise HTTPException(status_code=402, detail=quota_result.error_message)
@ -872,7 +875,10 @@ async def resume_campaign(
# Check Dograh quota before resuming campaign (apply per-workflow
# model_overrides so we evaluate the keys this campaign will use).
quota_result = await check_dograh_quota(user, workflow_id=campaign.workflow_id)
quota_result = await authorize_workflow_run_start(
workflow_id=campaign.workflow_id,
actor_user=user,
)
if not quota_result.has_quota:
raise HTTPException(status_code=402, detail=quota_result.error_message)

View file

@ -14,7 +14,7 @@ from pydantic import BaseModel
from api.db import db_client
from api.enums import TriggerState, WorkflowStatus
from api.services.quota_service import check_dograh_quota_by_user_id
from api.services.quota_service import authorize_workflow_run_start
from api.services.telephony.factory import (
get_default_telephony_provider,
get_telephony_provider_by_id,
@ -179,14 +179,6 @@ async def _execute_resolved_target(
"""Shared execution path once the target workflow has been resolved."""
execution_user_id = _get_execution_user_id(target.workflow)
# Check Dograh quota using the workflow owner's config and model overrides.
quota_result = await check_dograh_quota_by_user_id(
execution_user_id,
workflow_id=target.workflow.id,
)
if not quota_result.has_quota:
raise HTTPException(status_code=402, detail=quota_result.error_message)
# Get telephony provider — either the caller-specified config (validated
# against the workflow's org) or the org's default config.
if request.telephony_configuration_id is not None:
@ -268,6 +260,15 @@ async def _execute_resolved_target(
f"to phone number {request.phone_number}"
)
# Check Dograh quota after the run exists so hosted v2 can mint and store
# the MPS correlation id before the provider starts the call.
quota_result = await authorize_workflow_run_start(
workflow_id=target.workflow.id,
workflow_run_id=workflow_run.id,
)
if not quota_result.has_quota:
raise HTTPException(status_code=402, detail=quota_result.error_message)
# 9. Construct webhook URL for telephony provider callback
backend_endpoint, _ = await get_backend_endpoints()
webhook_endpoint = provider.WEBHOOK_ENDPOINT

View file

@ -25,7 +25,7 @@ from api.enums import CallType, WorkflowRunState
from api.errors.telephony_errors import TelephonyError
from api.sdk_expose import sdk_expose
from api.services.auth.depends import get_user
from api.services.quota_service import check_dograh_quota_by_user_id
from api.services.quota_service import authorize_workflow_run_start
from api.services.telephony.call_transfer_manager import get_call_transfer_manager
from api.services.telephony.factory import (
get_all_telephony_providers,
@ -136,14 +136,6 @@ async def initiate_call(
raise HTTPException(status_code=404, detail="Workflow not found")
execution_user_id = _get_execution_user_id(workflow)
# Check Dograh quota before initiating the call (apply per-workflow
# model_overrides so the keys we will actually use are the ones checked).
quota_result = await check_dograh_quota_by_user_id(
execution_user_id, workflow_id=workflow.id
)
if not quota_result.has_quota:
raise HTTPException(status_code=402, detail=quota_result.error_message)
# Determine the workflow run mode based on provider type
workflow_run_mode = provider.PROVIDER_NAME
@ -186,6 +178,16 @@ async def initiate_call(
)
workflow_run_name = workflow_run.name
# Check Dograh quota after the run exists so hosted v2 can mint and store
# the MPS correlation id before initiating the call.
quota_result = await authorize_workflow_run_start(
workflow_id=workflow.id,
workflow_run_id=workflow_run_id,
actor_user=user,
)
if not quota_result.has_quota:
raise HTTPException(status_code=402, detail=quota_result.error_message)
# Construct webhook URL based on provider type
backend_endpoint, _ = await get_backend_endpoints()
@ -739,19 +741,8 @@ async def handle_inbound_run(request: Request):
TelephonyError.SIGNATURE_VALIDATION_FAILED
)
# 4. Quota check (use the workflow's model_overrides if set).
quota_result = await check_dograh_quota_by_user_id(
user_id, workflow_id=workflow_id
)
if not quota_result.has_quota:
logger.warning(
f"User {user_id} has exceeded quota: {quota_result.error_message}"
)
return provider_class.generate_validation_error_response(
TelephonyError.QUOTA_EXCEEDED
)
# 5. Create workflow run + return provider-shaped response.
# 5. Create workflow run + authorize quota before returning provider
# stream instructions.
workflow_run_id = await _create_inbound_workflow_run(
workflow_id,
user_id,
@ -760,6 +751,17 @@ async def handle_inbound_run(request: Request):
telephony_configuration_id=telephony_configuration_id,
from_phone_number_id=phone_row.id,
)
quota_result = await authorize_workflow_run_start(
workflow_id=workflow_id,
workflow_run_id=workflow_run_id,
)
if not quota_result.has_quota:
logger.warning(
f"User {user_id} has exceeded quota: {quota_result.error_message}"
)
return provider_class.generate_validation_error_response(
TelephonyError.QUOTA_EXCEEDED
)
backend_endpoint, wss_backend_endpoint = await get_backend_endpoints()
websocket_url = (
@ -874,20 +876,8 @@ async def handle_inbound_telephony(
logger.error(f"Request validation failed: {error_type}")
return provider_class.generate_validation_error_response(error_type)
# Check quota before processing (apply per-workflow model_overrides).
# Create workflow run.
user_id = workflow_context["user_id"]
quota_result = await check_dograh_quota_by_user_id(
user_id, workflow_id=workflow_id
)
if not quota_result.has_quota:
logger.warning(
f"User {user_id} has exceeded quota for inbound calls: {quota_result.error_message}"
)
return provider_class.generate_validation_error_response(
TelephonyError.QUOTA_EXCEEDED
)
# Create workflow run
workflow_run_id = await _create_inbound_workflow_run(
workflow_id,
workflow_context["user_id"],
@ -896,6 +886,17 @@ async def handle_inbound_telephony(
telephony_configuration_id=workflow_context["telephony_configuration_id"],
from_phone_number_id=workflow_context.get("from_phone_number_id"),
)
quota_result = await authorize_workflow_run_start(
workflow_id=workflow_id,
workflow_run_id=workflow_run_id,
)
if not quota_result.has_quota:
logger.warning(
f"User {user_id} has exceeded quota for inbound calls: {quota_result.error_message}"
)
return provider_class.generate_validation_error_response(
TelephonyError.QUOTA_EXCEEDED
)
# Generate response URLs
backend_endpoint, wss_backend_endpoint = await get_backend_endpoints()

View file

@ -45,7 +45,7 @@ from api.services.pipecat.ws_sender_registry import (
register_ws_sender,
unregister_ws_sender,
)
from api.services.quota_service import check_dograh_quota
from api.services.quota_service import authorize_workflow_run_start
router = APIRouter(prefix="/ws")
@ -329,7 +329,11 @@ class SignalingManager:
# Check Dograh quota before initiating the call (apply per-workflow
# model_overrides so we evaluate the keys this workflow will use).
quota_result = await check_dograh_quota(user, workflow_id=workflow_id)
quota_result = await authorize_workflow_run_start(
workflow_id=workflow_id,
workflow_run_id=workflow_run_id,
actor_user=user,
)
if not quota_result.has_quota:
# Send error response for quota issues
await ws.send_json(

View file

@ -10,7 +10,7 @@ from api.db import db_client
from api.db.models import UserModel, WorkflowRunTextSessionModel
from api.enums import WorkflowRunMode
from api.services.auth.depends import get_user_with_selected_organization
from api.services.quota_service import check_dograh_quota
from api.services.quota_service import authorize_workflow_run_start
from api.services.workflow.text_chat_session_service import (
TextChatPendingTurnLostError,
TextChatSessionExecutionError,
@ -96,8 +96,16 @@ def _revision_conflict_detail(e: Any) -> dict[str, Any]:
}
async def _ensure_text_chat_quota(user: UserModel, workflow_id: int) -> None:
quota_result = await check_dograh_quota(user, workflow_id=workflow_id)
async def _ensure_text_chat_quota(
user: UserModel,
workflow_id: int,
workflow_run_id: int,
) -> None:
quota_result = await authorize_workflow_run_start(
workflow_id=workflow_id,
workflow_run_id=workflow_run_id,
actor_user=user,
)
if not quota_result.has_quota:
raise HTTPException(status_code=402, detail=quota_result.error_message)
@ -153,8 +161,6 @@ async def create_text_chat_session(
request: CreateTextChatSessionRequest,
user: UserModel = Depends(get_user_with_selected_organization),
) -> WorkflowRunTextSessionResponse:
await _ensure_text_chat_quota(user, workflow_id)
session_name = request.name or f"WR-TEXT-{uuid4().hex[:6].upper()}"
try:
workflow_run = await db_client.create_workflow_run(
@ -170,6 +176,7 @@ async def create_text_chat_session(
raise HTTPException(status_code=404, detail=str(e))
set_current_run_id(workflow_run.id)
await _ensure_text_chat_quota(user, workflow_id, workflow_run.id)
annotations = {
"tester": {
@ -229,7 +236,7 @@ async def append_text_chat_message(
user: UserModel = Depends(get_user_with_selected_organization),
) -> WorkflowRunTextSessionResponse:
text_session = await _load_text_session_or_404(workflow_id, run_id, user)
await _ensure_text_chat_quota(user, workflow_id)
await _ensure_text_chat_quota(user, workflow_id, run_id)
try:
text_session = await append_text_chat_user_message(

View file

@ -15,6 +15,7 @@ from api.services.campaign.errors import (
PhoneNumberPoolExhaustedError,
)
from api.services.campaign.rate_limiter import rate_limiter
from api.services.quota_service import authorize_workflow_run_start
from api.utils.common import get_backend_endpoints
if TYPE_CHECKING:
@ -339,6 +340,41 @@ class CampaignCallDispatcher:
},
)
quota_result = await authorize_workflow_run_start(
workflow_id=campaign.workflow_id,
workflow_run_id=workflow_run.id,
)
if not quota_result.has_quota:
error_message = quota_result.error_message or "Quota exceeded"
logger.warning(
f"Campaign {campaign.id} quota check failed for workflow run "
f"{workflow_run.id}: {error_message}"
)
await db_client.update_workflow_run(
run_id=workflow_run.id,
is_completed=True,
state=WorkflowRunState.COMPLETED.value,
gathered_context={"error": error_message},
)
mapping = await rate_limiter.get_workflow_slot_mapping(workflow_run.id)
if mapping:
org_id, mapped_slot_id = mapping
await rate_limiter.release_concurrent_slot(org_id, mapped_slot_id)
await rate_limiter.delete_workflow_slot_mapping(workflow_run.id)
from_number_mapping = await rate_limiter.get_workflow_from_number_mapping(
workflow_run.id
)
if from_number_mapping:
fn_org_id, fn_number, fn_tcid = from_number_mapping
await rate_limiter.release_from_number(
fn_org_id, fn_number, telephony_configuration_id=fn_tcid
)
await rate_limiter.delete_workflow_from_number_mapping(workflow_run.id)
raise ValueError(error_message)
# Initiate call via telephony provider
try:
# Construct webhook URL with parameters

View file

@ -2,11 +2,8 @@ from __future__ import annotations
from typing import Any
from loguru import logger
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
from api.services.configuration.registry import ServiceProviders
from api.services.mps_service_key_client import mps_service_key_client
MPS_CORRELATION_ID_CONTEXT_KEY = "mps_correlation_id"
@ -48,27 +45,10 @@ async def ensure_mps_correlation_id(
if not uses_managed_model_services_v2(ai_model_config):
return None
service_key = _get_dograh_service_api_key(ai_model_config)
if not service_key:
raise ValueError(
"Managed model services v2 requires a Dograh service key before the run starts."
)
response = await mps_service_key_client.create_correlation_id(
service_key=service_key,
workflow_run_id=workflow_run_id,
raise ValueError(
"Managed model services v2 requires workflow run authorization before "
f"the run starts. Missing correlation id for workflow_run_id={workflow_run_id}."
)
correlation_id = response.get("correlation_id")
if not correlation_id:
raise ValueError("MPS correlation-id response did not include correlation_id")
correlation_id = str(correlation_id)
logger.info(
"Minted MPS correlation id {} for workflow run {}",
correlation_id,
workflow_run_id,
)
return correlation_id
def _is_dograh_service(service: Any) -> bool:
@ -78,7 +58,7 @@ def _is_dograh_service(service: Any) -> bool:
)
def _get_dograh_service_api_key(
def get_dograh_service_api_key(
ai_model_config: EffectiveAIModelConfiguration,
) -> str | None:
for section_name in ("llm", "tts", "stt", "embeddings"):

View file

@ -478,6 +478,50 @@ class MPSServiceKeyClient:
response=response,
)
async def authorize_workflow_run_start(
self,
*,
organization_id: int,
workflow_run_id: int | None = None,
service_key: Optional[str] = None,
require_correlation_id: bool = False,
minimum_credits: float | None = None,
metadata: Optional[dict] = None,
created_by: Optional[str] = None,
) -> dict:
"""Authorize a hosted workflow run and optionally mint its MPS correlation."""
payload = {
"workflow_run_id": workflow_run_id,
"service_key": service_key,
"require_correlation_id": require_correlation_id,
"metadata": metadata or {},
}
if minimum_credits is not None:
payload["minimum_credits"] = minimum_credits
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.post(
f"{self.base_url}/api/v1/billing/accounts/{organization_id}/run-authorization",
json=payload,
headers=self._get_headers(
organization_id=organization_id,
created_by=created_by,
),
)
if response.status_code == 200:
return response.json()
logger.error(
"Failed to authorize MPS workflow run start: "
f"{response.status_code} - {response.text}"
)
raise httpx.HTTPStatusError(
f"Failed to authorize MPS workflow run start: {response.text}",
request=response.request,
response=response,
)
async def create_correlation_id(
self,
*,

View file

@ -5,17 +5,38 @@ across different endpoints (WebRTC signaling, telephony, public API triggers).
"""
from dataclasses import dataclass
from typing import Any
from loguru import logger
from api.constants import DEPLOYMENT_MODE
from api.db import db_client
from api.db.models import UserModel
from api.services.configuration.ai_model_configuration import (
get_effective_ai_model_configuration_for_workflow,
)
from api.services.configuration.registry import ServiceProviders
from api.services.managed_model_services import (
MPS_CORRELATION_ID_CONTEXT_KEY,
get_dograh_service_api_key,
uses_managed_model_services_v2,
)
from api.services.mps_service_key_client import mps_service_key_client
MINIMUM_DOGRAH_CREDITS_FOR_CALL = 0.10
LEGACY_QUOTA_EXCEEDED_MESSAGE = (
"You have exhausted your trial credits. "
"Please email founders@dograh.com for additional Dograh credits "
"or change providers in Models configurations."
)
BILLING_V2_QUOTA_EXCEEDED_MESSAGE = (
"You have exhausted your Dograh credits. "
"Please purchase more credits from /billing "
"or change providers in Models configurations."
)
@dataclass
class QuotaCheckResult:
@ -26,116 +47,359 @@ class QuotaCheckResult:
error_code: str = ""
async def check_dograh_quota(
user: UserModel, workflow_id: int | None = None
) -> QuotaCheckResult:
"""Check if user has sufficient Dograh quota for making a call.
This function checks if the user is using any Dograh services (LLM, STT, TTS)
and validates that they have sufficient credits remaining.
When ``workflow_id`` is provided, the workflow's per-workflow
``model_overrides`` are merged onto the user's global config so the quota
check runs against the credentials that will actually be used for the call
(rather than always falling back to the user's defaults).
Args:
user: The user to check quota for
workflow_id: Optional workflow whose ``model_overrides`` should be
applied when resolving the effective service config.
Returns:
QuotaCheckResult with has_quota=True if user has sufficient quota or
is not using Dograh services, or has_quota=False with error_message
if quota is insufficient.
"""
def _safe_float(value: Any, default: float = 0.0) -> float:
try:
organization_id = user.selected_organization_id
workflow_configurations = None
return float(value)
except (TypeError, ValueError):
return default
if workflow_id is not None:
workflow = await db_client.get_workflow_by_id(workflow_id)
if workflow:
organization_id = workflow.organization_id
workflow_configurations = workflow.workflow_configurations
user_config = await get_effective_ai_model_configuration_for_workflow(
user_id=user.id,
organization_id=organization_id,
workflow_configurations=workflow_configurations,
def _insufficient_billing_v2_quota_result() -> QuotaCheckResult:
return QuotaCheckResult(
has_quota=False,
error_code="insufficient_credits",
error_message=BILLING_V2_QUOTA_EXCEEDED_MESSAGE,
)
def _insufficient_legacy_quota_result() -> QuotaCheckResult:
return QuotaCheckResult(
has_quota=False,
error_code="quota_exceeded",
error_message=LEGACY_QUOTA_EXCEEDED_MESSAGE,
)
def _service_uses_dograh(service: Any) -> bool:
provider = getattr(service, "provider", None)
return (
provider == ServiceProviders.DOGRAH or provider == ServiceProviders.DOGRAH.value
)
def _dograh_api_keys(user_config: Any) -> set[str]:
api_keys: set[str] = set()
for section_name in ("llm", "stt", "tts", "embeddings"):
service = getattr(user_config, section_name, None)
if not _service_uses_dograh(service):
continue
if hasattr(service, "get_all_api_keys"):
all_api_keys = [
api_key
for api_key in service.get_all_api_keys()
if isinstance(api_key, str) and api_key
]
if all_api_keys:
api_keys.update(all_api_keys)
continue
api_key = getattr(service, "api_key", None)
if api_key:
api_keys.add(api_key)
return api_keys
async def _store_run_correlation_id(
workflow_run_id: int | None,
correlation_id: str | None,
) -> None:
if not workflow_run_id or not correlation_id:
return
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
if not workflow_run:
logger.warning(
"Could not store MPS correlation id for missing workflow run {}",
workflow_run_id,
)
return
initial_context = dict(workflow_run.initial_context or {})
if initial_context.get(MPS_CORRELATION_ID_CONTEXT_KEY) == correlation_id:
return
initial_context[MPS_CORRELATION_ID_CONTEXT_KEY] = correlation_id
await db_client.update_workflow_run(
workflow_run_id,
initial_context=initial_context,
)
async def _authorize_hosted_workflow_run_start(
*,
workflow_owner: UserModel,
organization_id: int | None,
workflow_id: int | None,
workflow_run_id: int | None,
user_config: Any,
) -> tuple[QuotaCheckResult, bool]:
"""Authorize hosted v2 billing and return whether MPS handled enforcement."""
if DEPLOYMENT_MODE == "oss" or organization_id is None:
return QuotaCheckResult(has_quota=True), False
requires_correlation = bool(
workflow_run_id and uses_managed_model_services_v2(user_config)
)
service_key = (
get_dograh_service_api_key(user_config) if requires_correlation else None
)
if requires_correlation and not service_key:
return (
QuotaCheckResult(
has_quota=False,
error_code="invalid_service_key",
error_message=(
"You have invalid keys in your model configuration. "
"Please validate the service keys."
),
),
True,
)
# Check if user is using any Dograh service
using_dograh = False
dograh_api_keys = set()
try:
authorization = await mps_service_key_client.authorize_workflow_run_start(
organization_id=organization_id,
workflow_run_id=workflow_run_id,
service_key=service_key,
require_correlation_id=requires_correlation,
minimum_credits=MINIMUM_DOGRAH_CREDITS_FOR_CALL,
created_by=(
str(workflow_owner.provider_id)
if workflow_owner.provider_id is not None
else None
),
metadata={
"dograh_user_id": str(workflow_owner.id),
"workflow_id": workflow_id,
},
)
except Exception as e:
logger.error(
"Failed to authorize workflow start with MPS for org {}: {}",
organization_id,
e,
)
return (
QuotaCheckResult(
has_quota=False,
error_code="quota_check_failed",
error_message="Could not verify Dograh credits. Please try again.",
),
True,
)
if user_config.llm and user_config.llm.provider == ServiceProviders.DOGRAH:
using_dograh = True
dograh_api_keys.add(user_config.llm.api_key)
billing_mode = authorization.get("billing_mode")
if billing_mode != "v2":
return QuotaCheckResult(has_quota=True), False
if user_config.stt and user_config.stt.provider == ServiceProviders.DOGRAH:
using_dograh = True
dograh_api_keys.add(user_config.stt.api_key)
remaining = _safe_float(authorization.get("remaining_credits"))
if (
not authorization.get("allowed", False)
or remaining < MINIMUM_DOGRAH_CREDITS_FOR_CALL
):
logger.warning(
"Insufficient Dograh billing v2 credits for org {}: {:.2f} credits remaining",
organization_id,
remaining,
)
return _insufficient_billing_v2_quota_result(), True
if user_config.tts and user_config.tts.provider == ServiceProviders.DOGRAH:
using_dograh = True
dograh_api_keys.add(user_config.tts.api_key)
try:
await _store_run_correlation_id(
workflow_run_id,
authorization.get("correlation_id"),
)
except Exception as e:
logger.error(
"Failed to store MPS correlation id for workflow_run_id {}: {}",
workflow_run_id,
e,
)
return (
QuotaCheckResult(
has_quota=False,
error_code="quota_check_failed",
error_message="Could not verify Dograh credits. Please try again.",
),
True,
)
logger.info(
"Dograh billing v2 run authorization passed for org {}: {:.2f} credits remaining",
organization_id,
remaining,
)
return QuotaCheckResult(has_quota=True), True
if (
user_config.embeddings
and user_config.embeddings.provider == ServiceProviders.DOGRAH
):
using_dograh = True
dograh_api_keys.add(user_config.embeddings.api_key)
# If not using Dograh, quota check passes
if not using_dograh:
return QuotaCheckResult(has_quota=True)
async def _authorize_legacy_dograh_keys(
*,
dograh_api_keys: set[str],
organization_id: int | None,
workflow_owner: UserModel,
) -> QuotaCheckResult:
for api_key in dograh_api_keys:
try:
usage = await mps_service_key_client.check_service_key_usage(
api_key,
organization_id=organization_id,
created_by=workflow_owner.provider_id,
)
remaining = usage.get("remaining_credits", 0.0)
# Check quota for ALL Dograh keys
for api_key in dograh_api_keys:
try:
usage = await mps_service_key_client.check_service_key_usage(
api_key,
organization_id=organization_id,
created_by=user.provider_id,
# Require at least $0.10 for a short call
if remaining < MINIMUM_DOGRAH_CREDITS_FOR_CALL:
logger.warning(
f"Insufficient Dograh credits for key ...{api_key[-8:]}: "
f"${remaining:.2f} remaining"
)
remaining = usage.get("remaining_credits", 0.0)
return _insufficient_legacy_quota_result()
# Require at least $0.10 for a short call
if remaining < 0.10:
logger.warning(
f"Insufficient Dograh credits for key ...{api_key[-8:]}: "
f"${remaining:.2f} remaining"
)
return QuotaCheckResult(
has_quota=False,
error_code="quota_exceeded",
error_message=(
"You have exhausted your trial credits. "
"Please email founders@dograh.com for additional Dograh credits "
"or change providers in Models configurations."
),
)
logger.info(
f"Dograh quota check passed for key ...{api_key[-8:]}: "
f"{remaining:.2f} credits remaining"
)
except Exception as e:
logger.error(f"Failed to check quota for Dograh key: {str(e)}")
error_str = str(e)
if "404" in error_str or "not found" in error_str.lower():
return QuotaCheckResult(
has_quota=False,
error_code="invalid_service_key",
error_message="You have invalid keys in your model configuration. Please validate the service keys.",
)
logger.info(
f"Dograh quota check passed for key ...{api_key[-8:]}: "
f"{remaining:.2f} credits remaining"
)
except Exception as e:
logger.error(f"Failed to check quota for Dograh key: {str(e)}")
error_str = str(e)
if "404" in error_str or "not found" in error_str.lower():
return QuotaCheckResult(
has_quota=False,
error_code="quota_check_failed",
error_message="Could not verify Dograh credits. Please try again.",
error_code="invalid_service_key",
error_message="You have invalid keys in your model configuration. Please validate the service keys.",
)
return QuotaCheckResult(
has_quota=False,
error_code="quota_check_failed",
error_message="Could not verify Dograh credits. Please try again.",
)
return QuotaCheckResult(has_quota=True)
async def _authorize_oss_managed_v2_correlation(
*,
workflow_id: int,
workflow_run_id: int | None,
user_config: Any,
) -> QuotaCheckResult:
if not workflow_run_id or not uses_managed_model_services_v2(user_config):
return QuotaCheckResult(has_quota=True)
service_key = get_dograh_service_api_key(user_config)
if not service_key:
return QuotaCheckResult(
has_quota=False,
error_code="invalid_service_key",
error_message=(
"You have invalid keys in your model configuration. "
"Please validate the service keys."
),
)
try:
response = await mps_service_key_client.create_correlation_id(
service_key=service_key,
workflow_run_id=workflow_run_id,
)
await _store_run_correlation_id(
workflow_run_id,
response.get("correlation_id"),
)
except Exception as e:
logger.error(
"Failed to authorize OSS managed v2 workflow start for workflow {} run {}: {}",
workflow_id,
workflow_run_id,
e,
)
return QuotaCheckResult(
has_quota=False,
error_code="quota_check_failed",
error_message="Could not verify Dograh credits. Please try again.",
)
return QuotaCheckResult(has_quota=True)
async def authorize_workflow_run_start(
*,
workflow_id: int,
workflow_run_id: int | None = None,
actor_user: UserModel | None = None,
) -> QuotaCheckResult:
"""Authorize a workflow run before any billable call/text runtime starts.
The workflow organization is the billing subject for hosted v2. The workflow
owner is used only to resolve the effective model configuration and legacy
service-key metadata.
"""
try:
workflow = await db_client.get_workflow_by_id(workflow_id)
if not workflow:
return QuotaCheckResult(
has_quota=False,
error_code="workflow_not_found",
error_message="Workflow not found",
)
actor_org_id = getattr(actor_user, "selected_organization_id", None)
if actor_org_id is not None and actor_org_id != workflow.organization_id:
logger.warning(
"Workflow start authorization denied: actor org {} does not match workflow {} org {}",
actor_org_id,
workflow_id,
workflow.organization_id,
)
return QuotaCheckResult(
has_quota=False,
error_code="workflow_not_found",
error_message="Workflow not found",
)
workflow_owner = await db_client.get_user_by_id(workflow.user_id)
if not workflow_owner:
return QuotaCheckResult(
has_quota=False,
error_code="user_not_found",
error_message="User not found",
)
user_config = await get_effective_ai_model_configuration_for_workflow(
user_id=workflow_owner.id,
organization_id=workflow.organization_id,
workflow_configurations=workflow.workflow_configurations,
)
if DEPLOYMENT_MODE != "oss":
hosted_result, hosted_enforced = await _authorize_hosted_workflow_run_start(
workflow_owner=workflow_owner,
organization_id=workflow.organization_id,
workflow_id=workflow.id,
workflow_run_id=workflow_run_id,
user_config=user_config,
)
if hosted_enforced or not hosted_result.has_quota:
return hosted_result
dograh_api_keys = _dograh_api_keys(user_config)
if not dograh_api_keys:
return QuotaCheckResult(has_quota=True)
legacy_result = await _authorize_legacy_dograh_keys(
dograh_api_keys=dograh_api_keys,
organization_id=(
None if DEPLOYMENT_MODE == "oss" else workflow.organization_id
),
workflow_owner=workflow_owner,
)
if not legacy_result.has_quota:
return legacy_result
if DEPLOYMENT_MODE == "oss":
return await _authorize_oss_managed_v2_correlation(
workflow_id=workflow.id,
workflow_run_id=workflow_run_id,
user_config=user_config,
)
return QuotaCheckResult(has_quota=True)
@ -143,30 +407,3 @@ async def check_dograh_quota(
logger.error(f"Error during quota check: {str(e)}")
# On unexpected error, allow the call to proceed
return QuotaCheckResult(has_quota=True)
async def check_dograh_quota_by_user_id(
user_id: int, workflow_id: int | None = None
) -> QuotaCheckResult:
"""Check Dograh quota by user ID.
Convenience function that fetches the user and then checks quota. When
``workflow_id`` is provided, the workflow's ``model_overrides`` are
applied so the quota check evaluates the credentials that will actually
be used for the call.
Args:
user_id: The ID of the user to check quota for
workflow_id: Optional workflow whose per-workflow overrides should
be applied to the user's config before checking quota.
Returns:
QuotaCheckResult with quota status
"""
user = await db_client.get_user_by_id(user_id)
if not user:
return QuotaCheckResult(
has_quota=False,
error_message="User not found",
)
return await check_dograh_quota(user, workflow_id=workflow_id)

View file

@ -26,7 +26,7 @@ from loguru import logger
from api.constants import REDIS_URL
from api.db import db_client
from api.enums import CallType, WorkflowRunMode
from api.services.quota_service import check_dograh_quota_by_user_id
from api.services.quota_service import authorize_workflow_run_start
from api.services.telephony.call_transfer_manager import get_call_transfer_manager
from api.services.telephony.transfer_event_protocol import (
TransferEvent,
@ -564,19 +564,7 @@ class ARIConnection:
user_id = workflow.user_id
# 3. Check quota (apply per-workflow model_overrides).
quota_result = await check_dograh_quota_by_user_id(
user_id, workflow_id=inbound_workflow_id
)
if not quota_result.has_quota:
logger.warning(
f"[ARI org={self.organization_id}] Quota exceeded for user {user_id} "
f"— hanging up inbound call {channel_id}"
)
await self._delete_channel(channel_id)
return
# 4. Create workflow run
# 3. Create workflow run
call_id = channel_id
workflow_run = await db_client.create_workflow_run(
name=f"ARI Inbound {caller_number}",
@ -602,6 +590,20 @@ class ARIConnection:
f"(caller={caller_number}, called={called_number})"
)
# 4. Check quota after the run exists so hosted v2 can mint and
# store the MPS correlation id before the pipeline starts.
quota_result = await authorize_workflow_run_start(
workflow_id=inbound_workflow_id,
workflow_run_id=workflow_run.id,
)
if not quota_result.has_quota:
logger.warning(
f"[ARI org={self.organization_id}] Quota exceeded for user {user_id} "
f"— hanging up inbound call {channel_id}"
)
await self._delete_channel(channel_id)
return
# 5. Answer the inbound channel
await self._answer_channel(channel_id)

View file

@ -33,7 +33,9 @@ def test_create_cartesia_tts_service_passes_selected_model():
transport_in_sample_rate=16000,
)
with patch("api.services.pipecat.service_factory.CartesiaTTSService") as mock_service:
with patch(
"api.services.pipecat.service_factory.CartesiaTTSService"
) as mock_service:
create_tts_service(user_config, audio_config)
assert mock_service.call_count == 1

View file

@ -270,6 +270,12 @@ class TestDispatcherThreadsTelephonyConfig:
"api.services.campaign.campaign_call_dispatcher.get_backend_endpoints",
AsyncMock(return_value=("https://example.com", None)),
),
patch(
"api.services.campaign.campaign_call_dispatcher.authorize_workflow_run_start",
AsyncMock(
return_value=SimpleNamespace(has_quota=True, error_message="")
),
),
):
mock_db.get_workflow_by_id = AsyncMock(return_value=SimpleNamespace(id=1))
mock_db.create_workflow_run = AsyncMock(return_value=workflow_run)

View file

@ -175,6 +175,76 @@ async def test_get_billing_account_status_uses_hosted_org_auth(monkeypatch):
]
@pytest.mark.asyncio
async def test_authorize_workflow_run_start_uses_hosted_org_auth(monkeypatch):
calls = []
class FakeAsyncClient:
def __init__(self, timeout):
self.timeout = timeout
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return None
async def post(self, url, json, headers):
calls.append(("POST", url, json, headers))
return _Response(
200,
{
"allowed": True,
"billing_mode": "v2",
"remaining_credits": "25.0000",
"correlation_id": "mps-corr-123",
},
)
monkeypatch.setattr(
"api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient
)
monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas")
monkeypatch.setattr(
"api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret"
)
client = MPSServiceKeyClient()
assert await client.authorize_workflow_run_start(
organization_id=42,
workflow_run_id=88,
service_key="mps_sk_paid",
require_correlation_id=True,
minimum_credits=0.1,
metadata={"workflow_id": 7},
created_by="provider-123",
) == {
"allowed": True,
"billing_mode": "v2",
"remaining_credits": "25.0000",
"correlation_id": "mps-corr-123",
}
assert calls == [
(
"POST",
f"{client.base_url}/api/v1/billing/accounts/42/run-authorization",
{
"workflow_run_id": 88,
"service_key": "mps_sk_paid",
"require_correlation_id": True,
"minimum_credits": 0.1,
"metadata": {"workflow_id": 7},
},
{
"Content-Type": "application/json",
"X-Secret-Key": "mps-secret",
"X-Organization-Id": "42",
},
)
]
@pytest.mark.asyncio
async def test_ensure_billing_account_v2_uses_balance_endpoint(monkeypatch):
calls = []

View file

@ -57,7 +57,7 @@ def test_trigger_route_executes_as_workflow_owner():
with (
patch("api.routes.public_agent.db_client") as mock_db,
patch(
"api.routes.public_agent.check_dograh_quota_by_user_id",
"api.routes.public_agent.authorize_workflow_run_start",
new=quota_mock,
),
patch(
@ -92,7 +92,10 @@ def test_trigger_route_executes_as_workflow_owner():
)
assert response.status_code == 200
quota_mock.assert_awaited_once_with(workflow.user_id, workflow_id=workflow.id)
quota_mock.assert_awaited_once_with(
workflow_id=workflow.id,
workflow_run_id=501,
)
mock_db.get_workflow.assert_awaited_once_with(workflow.id, organization_id=11)
create_kwargs = mock_db.create_workflow_run.await_args.kwargs
@ -124,7 +127,7 @@ def test_workflow_uuid_route_uses_scoped_lookup_and_shared_execution():
with (
patch("api.routes.public_agent.db_client") as mock_db,
patch(
"api.routes.public_agent.check_dograh_quota_by_user_id",
"api.routes.public_agent.authorize_workflow_run_start",
new=quota_mock,
),
patch(

View file

@ -0,0 +1,369 @@
from types import SimpleNamespace
from unittest.mock import AsyncMock
import pytest
from api.services import quota_service
from api.services.configuration.registry import ServiceProviders
from api.services.managed_model_services import MPS_CORRELATION_ID_CONTEXT_KEY
def _dograh_config(
api_key: str = "mps_sk_12345678",
*,
managed_service_version: int = 1,
):
return SimpleNamespace(
managed_service_version=managed_service_version,
llm=SimpleNamespace(provider=ServiceProviders.DOGRAH, api_key=api_key),
stt=None,
tts=None,
embeddings=None,
)
def _byok_config():
return SimpleNamespace(
managed_service_version=2,
llm=SimpleNamespace(provider="openai", api_key="sk-openai"),
stt=None,
tts=None,
embeddings=None,
)
def _workflow():
return SimpleNamespace(
id=7,
user_id=123,
organization_id=42,
workflow_configurations={"model_overrides": {}},
)
def _workflow_owner():
return SimpleNamespace(
id=123,
provider_id="provider-123",
)
def _actor():
return SimpleNamespace(
id=456,
provider_id="actor-456",
selected_organization_id=42,
)
def _patch_workflow_context(monkeypatch, *, workflow=None, owner=None):
monkeypatch.setattr(
quota_service.db_client,
"get_workflow_by_id",
AsyncMock(return_value=workflow or _workflow()),
)
monkeypatch.setattr(
quota_service.db_client,
"get_user_by_id",
AsyncMock(return_value=owner or _workflow_owner()),
)
@pytest.mark.asyncio
async def test_authorize_workflow_run_uses_workflow_org_for_hosted_v2(
monkeypatch,
):
get_config = AsyncMock(return_value=_dograh_config())
authorize = AsyncMock(
return_value={
"allowed": True,
"billing_mode": "v2",
"remaining_credits": "25.0000",
}
)
check_usage = AsyncMock()
monkeypatch.setattr(quota_service, "DEPLOYMENT_MODE", "saas")
_patch_workflow_context(monkeypatch)
monkeypatch.setattr(
quota_service,
"get_effective_ai_model_configuration_for_workflow",
get_config,
)
monkeypatch.setattr(
quota_service.mps_service_key_client,
"authorize_workflow_run_start",
authorize,
)
monkeypatch.setattr(
quota_service.mps_service_key_client,
"check_service_key_usage",
check_usage,
)
result = await quota_service.authorize_workflow_run_start(workflow_id=7)
assert result.has_quota is True
get_config.assert_awaited_once_with(
user_id=123,
organization_id=42,
workflow_configurations={"model_overrides": {}},
)
authorize.assert_awaited_once_with(
organization_id=42,
workflow_run_id=None,
service_key=None,
require_correlation_id=False,
minimum_credits=quota_service.MINIMUM_DOGRAH_CREDITS_FOR_CALL,
created_by="provider-123",
metadata={"dograh_user_id": "123", "workflow_id": 7},
)
check_usage.assert_not_awaited()
@pytest.mark.asyncio
async def test_authorize_workflow_run_v2_insufficient_credits_prompts_billing(
monkeypatch,
):
get_config = AsyncMock(return_value=_byok_config())
authorize = AsyncMock(
return_value={
"allowed": False,
"billing_mode": "v2",
"remaining_credits": "0.0000",
"error": "insufficient_credits",
}
)
check_usage = AsyncMock()
monkeypatch.setattr(quota_service, "DEPLOYMENT_MODE", "saas")
_patch_workflow_context(monkeypatch)
monkeypatch.setattr(
quota_service,
"get_effective_ai_model_configuration_for_workflow",
get_config,
)
monkeypatch.setattr(
quota_service.mps_service_key_client,
"authorize_workflow_run_start",
authorize,
)
monkeypatch.setattr(
quota_service.mps_service_key_client,
"check_service_key_usage",
check_usage,
)
result = await quota_service.authorize_workflow_run_start(workflow_id=7)
assert result.has_quota is False
assert result.error_code == "insufficient_credits"
assert "/billing" in result.error_message
assert "founders@dograh.com" not in result.error_message
authorize.assert_awaited_once()
check_usage.assert_not_awaited()
@pytest.mark.asyncio
async def test_authorize_workflow_run_v1_uses_legacy_key_usage(
monkeypatch,
):
api_key = "mps_sk_12345678"
get_config = AsyncMock(return_value=_dograh_config(api_key))
authorize = AsyncMock(
return_value={
"allowed": True,
"billing_mode": "v1",
"remaining_credits": "0.0000",
}
)
check_usage = AsyncMock(
return_value={"total_credits_used": 500.0, "remaining_credits": 0.0}
)
monkeypatch.setattr(quota_service, "DEPLOYMENT_MODE", "saas")
_patch_workflow_context(monkeypatch)
monkeypatch.setattr(
quota_service,
"get_effective_ai_model_configuration_for_workflow",
get_config,
)
monkeypatch.setattr(
quota_service.mps_service_key_client,
"authorize_workflow_run_start",
authorize,
)
monkeypatch.setattr(
quota_service.mps_service_key_client,
"check_service_key_usage",
check_usage,
)
result = await quota_service.authorize_workflow_run_start(workflow_id=7)
assert result.has_quota is False
assert result.error_code == "quota_exceeded"
assert "founders@dograh.com" in result.error_message
assert "/billing" not in result.error_message
authorize.assert_awaited_once()
check_usage.assert_awaited_once_with(
api_key,
organization_id=42,
created_by="provider-123",
)
@pytest.mark.asyncio
async def test_authorize_workflow_run_managed_v2_stores_hosted_correlation(
monkeypatch,
):
api_key = "mps_sk_12345678"
workflow_run = SimpleNamespace(initial_context={"existing": "value"})
get_config = AsyncMock(
return_value=_dograh_config(api_key, managed_service_version=2)
)
authorize = AsyncMock(
return_value={
"allowed": True,
"billing_mode": "v2",
"remaining_credits": "25.0000",
"correlation_id": "mps-corr-123",
}
)
update_workflow_run = AsyncMock()
monkeypatch.setattr(quota_service, "DEPLOYMENT_MODE", "saas")
_patch_workflow_context(monkeypatch)
monkeypatch.setattr(
quota_service.db_client,
"get_workflow_run_by_id",
AsyncMock(return_value=workflow_run),
)
monkeypatch.setattr(
quota_service.db_client,
"update_workflow_run",
update_workflow_run,
)
monkeypatch.setattr(
quota_service,
"get_effective_ai_model_configuration_for_workflow",
get_config,
)
monkeypatch.setattr(
quota_service.mps_service_key_client,
"authorize_workflow_run_start",
authorize,
)
monkeypatch.setattr(
quota_service.mps_service_key_client,
"check_service_key_usage",
AsyncMock(),
)
result = await quota_service.authorize_workflow_run_start(
workflow_id=7,
workflow_run_id=88,
)
assert result.has_quota is True
authorize.assert_awaited_once_with(
organization_id=42,
workflow_run_id=88,
service_key=api_key,
require_correlation_id=True,
minimum_credits=quota_service.MINIMUM_DOGRAH_CREDITS_FOR_CALL,
created_by="provider-123",
metadata={"dograh_user_id": "123", "workflow_id": 7},
)
update_workflow_run.assert_awaited_once_with(
88,
initial_context={
"existing": "value",
MPS_CORRELATION_ID_CONTEXT_KEY: "mps-corr-123",
},
)
@pytest.mark.asyncio
async def test_authorize_workflow_run_oss_uses_key_paths_not_workflow_org(
monkeypatch,
):
api_key = "mps_sk_12345678"
workflow_run = SimpleNamespace(initial_context={})
get_config = AsyncMock(
return_value=_dograh_config(api_key, managed_service_version=2)
)
hosted_authorize = AsyncMock()
check_usage = AsyncMock(
return_value={"total_credits_used": 1.0, "remaining_credits": 499.0}
)
create_correlation = AsyncMock(return_value={"correlation_id": "oss-corr-123"})
update_workflow_run = AsyncMock()
monkeypatch.setattr(quota_service, "DEPLOYMENT_MODE", "oss")
_patch_workflow_context(monkeypatch)
monkeypatch.setattr(
quota_service.db_client,
"get_workflow_run_by_id",
AsyncMock(return_value=workflow_run),
)
monkeypatch.setattr(
quota_service.db_client,
"update_workflow_run",
update_workflow_run,
)
monkeypatch.setattr(
quota_service,
"get_effective_ai_model_configuration_for_workflow",
get_config,
)
monkeypatch.setattr(
quota_service.mps_service_key_client,
"authorize_workflow_run_start",
hosted_authorize,
)
monkeypatch.setattr(
quota_service.mps_service_key_client,
"check_service_key_usage",
check_usage,
)
monkeypatch.setattr(
quota_service.mps_service_key_client,
"create_correlation_id",
create_correlation,
)
result = await quota_service.authorize_workflow_run_start(
workflow_id=7,
workflow_run_id=88,
)
assert result.has_quota is True
hosted_authorize.assert_not_awaited()
check_usage.assert_awaited_once_with(
api_key,
organization_id=None,
created_by="provider-123",
)
create_correlation.assert_awaited_once_with(
service_key=api_key,
workflow_run_id=88,
)
update_workflow_run.assert_awaited_once_with(
88,
initial_context={MPS_CORRELATION_ID_CONTEXT_KEY: "oss-corr-123"},
)
@pytest.mark.asyncio
async def test_authorize_workflow_run_rejects_actor_from_another_org(monkeypatch):
monkeypatch.setattr(quota_service, "DEPLOYMENT_MODE", "saas")
_patch_workflow_context(monkeypatch)
result = await quota_service.authorize_workflow_run_start(
workflow_id=7,
actor_user=SimpleNamespace(selected_organization_id=999),
)
assert result.has_quota is False
assert result.error_code == "workflow_not_found"

View file

@ -1,5 +1,5 @@
from types import SimpleNamespace
from unittest.mock import AsyncMock, Mock, patch
from unittest.mock import ANY, AsyncMock, Mock, patch
from fastapi import FastAPI
from fastapi.testclient import TestClient
@ -54,7 +54,7 @@ def test_initiate_call_executes_as_workflow_owner_for_shared_org_workflow():
with (
patch("api.routes.telephony.db_client") as mock_db,
patch(
"api.routes.telephony.check_dograh_quota_by_user_id",
"api.routes.telephony.authorize_workflow_run_start",
new=quota_mock,
),
patch(
@ -88,7 +88,11 @@ def test_initiate_call_executes_as_workflow_owner_for_shared_org_workflow():
)
assert response.status_code == 200
quota_mock.assert_awaited_once_with(workflow.user_id, workflow_id=workflow.id)
quota_mock.assert_awaited_once_with(
workflow_id=workflow.id,
workflow_run_id=501,
actor_user=ANY,
)
mock_db.get_workflow.assert_awaited_once_with(workflow.id, organization_id=11)
create_call = mock_db.create_workflow_run.await_args
@ -119,7 +123,7 @@ def test_initiate_call_uses_organization_preference_phone_number():
with (
patch("api.routes.telephony.db_client") as mock_db,
patch(
"api.routes.telephony.check_dograh_quota_by_user_id",
"api.routes.telephony.authorize_workflow_run_start",
new=quota_mock,
),
patch(
@ -173,7 +177,7 @@ def test_initiate_call_rejects_existing_run_for_different_workflow():
with (
patch("api.routes.telephony.db_client") as mock_db,
patch(
"api.routes.telephony.check_dograh_quota_by_user_id",
"api.routes.telephony.authorize_workflow_run_start",
new=quota_mock,
),
patch(

View file

@ -1105,7 +1105,7 @@ async def test_text_chat_session_creation_rejects_quota_before_creating_run(
async with test_client_factory(user) as client:
with patch(
"api.routes.workflow_text_chat.check_dograh_quota",
"api.routes.workflow_text_chat.authorize_workflow_run_start",
new=AsyncMock(
return_value=SimpleNamespace(
has_quota=False,
@ -1120,11 +1120,16 @@ async def test_text_chat_session_creation_rejects_quota_before_creating_run(
assert create_response.status_code == 402
assert create_response.json()["detail"] == "Quota exceeded"
_, total_count = await db_session.get_workflow_runs_by_workflow_id(
runs, total_count = await db_session.get_workflow_runs_by_workflow_id(
workflow.id,
organization_id=workflow.organization_id,
)
assert total_count == 0
assert total_count == 1
text_session = await db_session.get_workflow_run_text_session(
runs[0].id,
organization_id=workflow.organization_id,
)
assert text_session is None
@pytest.mark.asyncio
@ -1168,7 +1173,7 @@ async def test_text_chat_append_rejects_quota_without_mutating_session(
async with test_client_factory(user) as client:
with (
patch(
"api.routes.workflow_text_chat.check_dograh_quota",
"api.routes.workflow_text_chat.authorize_workflow_run_start",
new=AsyncMock(
side_effect=[
SimpleNamespace(has_quota=True, error_message=""),

View file

@ -147,7 +147,7 @@ export function EmbeddedVoiceTester({
onOpenChange={setApiKeyModalOpen}
error={apiKeyError}
errorCode={apiKeyErrorCode}
onNavigateToCredits={() => router.push("/api-keys")}
onNavigateToBilling={() => router.push("/billing")}
onNavigateToModelConfig={() => router.push("/model-configurations")}
/>

View file

@ -8,7 +8,7 @@ interface ApiKeyErrorDialogProps {
onOpenChange: (open: boolean) => void;
error: string | null;
errorCode: string | null;
onNavigateToCredits: () => void;
onNavigateToBilling: () => void;
onNavigateToModelConfig: () => void;
}
@ -17,15 +17,16 @@ export const ApiKeyErrorDialog = ({
onOpenChange,
error,
errorCode,
onNavigateToCredits,
onNavigateToBilling,
onNavigateToModelConfig,
}: ApiKeyErrorDialogProps) => {
const isQuotaError = errorCode === 'quota_exceeded';
const isBillingCreditsError = errorCode === 'insufficient_credits';
const isQuotaError = isBillingCreditsError || errorCode === 'quota_exceeded';
const title = isQuotaError ? "Insufficient Credits" : "API Configuration Error";
const icon = isQuotaError ? <CreditCard className="h-5 w-5 text-orange-500" /> : <Key className="h-5 w-5 text-red-500" />;
const buttonText = isQuotaError ? "Add Credits" : "Go to Model Configurations";
const onNavigate = isQuotaError ? onNavigateToCredits : onNavigateToModelConfig;
const buttonText = isBillingCreditsError ? "Go to Billing" : "Go to Model Configurations";
const onNavigate = isBillingCreditsError ? onNavigateToBilling : onNavigateToModelConfig;
return (
<Dialog open={open} onOpenChange={onOpenChange}>
@ -40,9 +41,9 @@ export const ApiKeyErrorDialog = ({
<AlertCircle className="h-4 w-4 text-muted-foreground mt-0.5 flex-shrink-0" />
<div className="text-sm space-y-1">
<p className="font-medium text-foreground">{error}</p>
{isQuotaError && (
{isBillingCreditsError && (
<p className="text-muted-foreground">
Your Dograh service credits are too low to start a call.
Purchase credits from Billing to continue using Dograh-managed models.
</p>
)}
</div>

View file

@ -19,6 +19,13 @@ interface UseWebSocketRTCProps {
onNodeTransition?: (transition: ConversationNodeTransitionItem) => void;
}
const HANDLED_SERVICE_ERROR_TYPES = new Set([
'quota_exceeded',
'insufficient_credits',
'invalid_service_key',
'quota_check_failed',
]);
export const useWebSocketRTC = ({ workflowId, workflowRunId, accessToken, initialContextVariables, onNodeTransition }: UseWebSocketRTCProps) => {
const [connectionStatus, setConnectionStatus] = useState<'idle' | 'connecting' | 'connected' | 'failed'>('idle');
const [connectionActive, setConnectionActive] = useState(false);
@ -265,9 +272,7 @@ export const useWebSocketRTC = ({ workflowId, workflowRunId, accessToken, initia
case 'error':
// Check if this is a quota/service key error
if (message.payload?.error_type === 'quota_exceeded' ||
message.payload?.error_type === 'invalid_service_key' ||
message.payload?.error_type === 'quota_check_failed') {
if (HANDLED_SERVICE_ERROR_TYPES.has(message.payload?.error_type)) {
// Log as info since it's a handled business logic case
logger.info('Quota/service key error, showing user dialog:', message.payload.message);