diff --git a/api/routes/agent_stream.py b/api/routes/agent_stream.py index b593a318..32bf5743 100644 --- a/api/routes/agent_stream.py +++ b/api/routes/agent_stream.py @@ -22,7 +22,7 @@ from starlette.websockets import WebSocketDisconnect from api.db import db_client from api.enums import CallType, WorkflowRunState -from api.services.quota_service import check_dograh_quota_by_user_id +from api.services.quota_service import authorize_workflow_run_start from api.services.telephony import registry as telephony_registry router = APIRouter(prefix="/agent-stream") @@ -67,19 +67,6 @@ async def agent_stream_websocket( await websocket.close(code=1008, reason="Workflow not found") return - quota_result = await check_dograh_quota_by_user_id( - workflow.user_id, workflow_id=workflow.id - ) - if not quota_result.has_quota: - logger.warning( - f"agent-stream quota exceeded for user {workflow.user_id}: " - f"{quota_result.error_message}" - ) - await websocket.close( - code=1008, reason=quota_result.error_message or "Quota exceeded" - ) - return - numeric_suffix = int(str(uuid.uuid4()).replace("-", "")[:8], 16) % 100000000 workflow_run_name = f"WR-AGS-{numeric_suffix:08d}" call_id = params.get("callId") or params.get("CallSid") @@ -108,6 +95,20 @@ async def agent_stream_websocket( set_current_run_id(workflow_run.id) set_current_org_id(workflow.organization_id) + quota_result = await authorize_workflow_run_start( + workflow_id=workflow.id, + workflow_run_id=workflow_run.id, + ) + if not quota_result.has_quota: + logger.warning( + f"agent-stream quota exceeded for user {workflow.user_id}: " + f"{quota_result.error_message}" + ) + await websocket.close( + code=1008, reason=quota_result.error_message or "Quota exceeded" + ) + return + await db_client.update_workflow_run( run_id=workflow_run.id, state=WorkflowRunState.RUNNING.value ) diff --git a/api/routes/campaign.py b/api/routes/campaign.py index cb5f541c..90697f91 100644 --- a/api/routes/campaign.py +++ b/api/routes/campaign.py @@ -18,7 +18,7 @@ from api.services.auth.depends import get_user from api.services.campaign.runner import campaign_runner_service from api.services.campaign.source_sync import CampaignSourceSyncService from api.services.campaign.source_sync_factory import get_sync_service -from api.services.quota_service import check_dograh_quota +from api.services.quota_service import authorize_workflow_run_start from api.services.reports import generate_campaign_report_csv from api.services.storage import storage_fs @@ -550,7 +550,10 @@ async def start_campaign( # Check Dograh quota before starting campaign (apply per-workflow # model_overrides so we evaluate the keys this campaign will use). - quota_result = await check_dograh_quota(user, workflow_id=campaign.workflow_id) + quota_result = await authorize_workflow_run_start( + workflow_id=campaign.workflow_id, + actor_user=user, + ) if not quota_result.has_quota: raise HTTPException(status_code=402, detail=quota_result.error_message) @@ -872,7 +875,10 @@ async def resume_campaign( # Check Dograh quota before resuming campaign (apply per-workflow # model_overrides so we evaluate the keys this campaign will use). - quota_result = await check_dograh_quota(user, workflow_id=campaign.workflow_id) + quota_result = await authorize_workflow_run_start( + workflow_id=campaign.workflow_id, + actor_user=user, + ) if not quota_result.has_quota: raise HTTPException(status_code=402, detail=quota_result.error_message) diff --git a/api/routes/public_agent.py b/api/routes/public_agent.py index 93d3f1e8..64706fb5 100644 --- a/api/routes/public_agent.py +++ b/api/routes/public_agent.py @@ -14,7 +14,7 @@ from pydantic import BaseModel from api.db import db_client from api.enums import TriggerState, WorkflowStatus -from api.services.quota_service import check_dograh_quota_by_user_id +from api.services.quota_service import authorize_workflow_run_start from api.services.telephony.factory import ( get_default_telephony_provider, get_telephony_provider_by_id, @@ -179,14 +179,6 @@ async def _execute_resolved_target( """Shared execution path once the target workflow has been resolved.""" execution_user_id = _get_execution_user_id(target.workflow) - # Check Dograh quota using the workflow owner's config and model overrides. - quota_result = await check_dograh_quota_by_user_id( - execution_user_id, - workflow_id=target.workflow.id, - ) - if not quota_result.has_quota: - raise HTTPException(status_code=402, detail=quota_result.error_message) - # Get telephony provider — either the caller-specified config (validated # against the workflow's org) or the org's default config. if request.telephony_configuration_id is not None: @@ -268,6 +260,15 @@ async def _execute_resolved_target( f"to phone number {request.phone_number}" ) + # Check Dograh quota after the run exists so hosted v2 can mint and store + # the MPS correlation id before the provider starts the call. + quota_result = await authorize_workflow_run_start( + workflow_id=target.workflow.id, + workflow_run_id=workflow_run.id, + ) + if not quota_result.has_quota: + raise HTTPException(status_code=402, detail=quota_result.error_message) + # 9. Construct webhook URL for telephony provider callback backend_endpoint, _ = await get_backend_endpoints() webhook_endpoint = provider.WEBHOOK_ENDPOINT diff --git a/api/routes/telephony.py b/api/routes/telephony.py index 7dbeab93..c9ffd0df 100644 --- a/api/routes/telephony.py +++ b/api/routes/telephony.py @@ -25,7 +25,7 @@ from api.enums import CallType, WorkflowRunState from api.errors.telephony_errors import TelephonyError from api.sdk_expose import sdk_expose from api.services.auth.depends import get_user -from api.services.quota_service import check_dograh_quota_by_user_id +from api.services.quota_service import authorize_workflow_run_start from api.services.telephony.call_transfer_manager import get_call_transfer_manager from api.services.telephony.factory import ( get_all_telephony_providers, @@ -136,14 +136,6 @@ async def initiate_call( raise HTTPException(status_code=404, detail="Workflow not found") execution_user_id = _get_execution_user_id(workflow) - # Check Dograh quota before initiating the call (apply per-workflow - # model_overrides so the keys we will actually use are the ones checked). - quota_result = await check_dograh_quota_by_user_id( - execution_user_id, workflow_id=workflow.id - ) - if not quota_result.has_quota: - raise HTTPException(status_code=402, detail=quota_result.error_message) - # Determine the workflow run mode based on provider type workflow_run_mode = provider.PROVIDER_NAME @@ -186,6 +178,16 @@ async def initiate_call( ) workflow_run_name = workflow_run.name + # Check Dograh quota after the run exists so hosted v2 can mint and store + # the MPS correlation id before initiating the call. + quota_result = await authorize_workflow_run_start( + workflow_id=workflow.id, + workflow_run_id=workflow_run_id, + actor_user=user, + ) + if not quota_result.has_quota: + raise HTTPException(status_code=402, detail=quota_result.error_message) + # Construct webhook URL based on provider type backend_endpoint, _ = await get_backend_endpoints() @@ -739,19 +741,8 @@ async def handle_inbound_run(request: Request): TelephonyError.SIGNATURE_VALIDATION_FAILED ) - # 4. Quota check (use the workflow's model_overrides if set). - quota_result = await check_dograh_quota_by_user_id( - user_id, workflow_id=workflow_id - ) - if not quota_result.has_quota: - logger.warning( - f"User {user_id} has exceeded quota: {quota_result.error_message}" - ) - return provider_class.generate_validation_error_response( - TelephonyError.QUOTA_EXCEEDED - ) - - # 5. Create workflow run + return provider-shaped response. + # 5. Create workflow run + authorize quota before returning provider + # stream instructions. workflow_run_id = await _create_inbound_workflow_run( workflow_id, user_id, @@ -760,6 +751,17 @@ async def handle_inbound_run(request: Request): telephony_configuration_id=telephony_configuration_id, from_phone_number_id=phone_row.id, ) + quota_result = await authorize_workflow_run_start( + workflow_id=workflow_id, + workflow_run_id=workflow_run_id, + ) + if not quota_result.has_quota: + logger.warning( + f"User {user_id} has exceeded quota: {quota_result.error_message}" + ) + return provider_class.generate_validation_error_response( + TelephonyError.QUOTA_EXCEEDED + ) backend_endpoint, wss_backend_endpoint = await get_backend_endpoints() websocket_url = ( @@ -874,20 +876,8 @@ async def handle_inbound_telephony( logger.error(f"Request validation failed: {error_type}") return provider_class.generate_validation_error_response(error_type) - # Check quota before processing (apply per-workflow model_overrides). + # Create workflow run. user_id = workflow_context["user_id"] - quota_result = await check_dograh_quota_by_user_id( - user_id, workflow_id=workflow_id - ) - if not quota_result.has_quota: - logger.warning( - f"User {user_id} has exceeded quota for inbound calls: {quota_result.error_message}" - ) - return provider_class.generate_validation_error_response( - TelephonyError.QUOTA_EXCEEDED - ) - - # Create workflow run workflow_run_id = await _create_inbound_workflow_run( workflow_id, workflow_context["user_id"], @@ -896,6 +886,17 @@ async def handle_inbound_telephony( telephony_configuration_id=workflow_context["telephony_configuration_id"], from_phone_number_id=workflow_context.get("from_phone_number_id"), ) + quota_result = await authorize_workflow_run_start( + workflow_id=workflow_id, + workflow_run_id=workflow_run_id, + ) + if not quota_result.has_quota: + logger.warning( + f"User {user_id} has exceeded quota for inbound calls: {quota_result.error_message}" + ) + return provider_class.generate_validation_error_response( + TelephonyError.QUOTA_EXCEEDED + ) # Generate response URLs backend_endpoint, wss_backend_endpoint = await get_backend_endpoints() diff --git a/api/routes/webrtc_signaling.py b/api/routes/webrtc_signaling.py index f7b4eeb3..ca8d3038 100644 --- a/api/routes/webrtc_signaling.py +++ b/api/routes/webrtc_signaling.py @@ -45,7 +45,7 @@ from api.services.pipecat.ws_sender_registry import ( register_ws_sender, unregister_ws_sender, ) -from api.services.quota_service import check_dograh_quota +from api.services.quota_service import authorize_workflow_run_start router = APIRouter(prefix="/ws") @@ -329,7 +329,11 @@ class SignalingManager: # Check Dograh quota before initiating the call (apply per-workflow # model_overrides so we evaluate the keys this workflow will use). - quota_result = await check_dograh_quota(user, workflow_id=workflow_id) + quota_result = await authorize_workflow_run_start( + workflow_id=workflow_id, + workflow_run_id=workflow_run_id, + actor_user=user, + ) if not quota_result.has_quota: # Send error response for quota issues await ws.send_json( diff --git a/api/routes/workflow_text_chat.py b/api/routes/workflow_text_chat.py index b4650118..47254330 100644 --- a/api/routes/workflow_text_chat.py +++ b/api/routes/workflow_text_chat.py @@ -10,7 +10,7 @@ from api.db import db_client from api.db.models import UserModel, WorkflowRunTextSessionModel from api.enums import WorkflowRunMode from api.services.auth.depends import get_user_with_selected_organization -from api.services.quota_service import check_dograh_quota +from api.services.quota_service import authorize_workflow_run_start from api.services.workflow.text_chat_session_service import ( TextChatPendingTurnLostError, TextChatSessionExecutionError, @@ -96,8 +96,16 @@ def _revision_conflict_detail(e: Any) -> dict[str, Any]: } -async def _ensure_text_chat_quota(user: UserModel, workflow_id: int) -> None: - quota_result = await check_dograh_quota(user, workflow_id=workflow_id) +async def _ensure_text_chat_quota( + user: UserModel, + workflow_id: int, + workflow_run_id: int, +) -> None: + quota_result = await authorize_workflow_run_start( + workflow_id=workflow_id, + workflow_run_id=workflow_run_id, + actor_user=user, + ) if not quota_result.has_quota: raise HTTPException(status_code=402, detail=quota_result.error_message) @@ -153,8 +161,6 @@ async def create_text_chat_session( request: CreateTextChatSessionRequest, user: UserModel = Depends(get_user_with_selected_organization), ) -> WorkflowRunTextSessionResponse: - await _ensure_text_chat_quota(user, workflow_id) - session_name = request.name or f"WR-TEXT-{uuid4().hex[:6].upper()}" try: workflow_run = await db_client.create_workflow_run( @@ -170,6 +176,7 @@ async def create_text_chat_session( raise HTTPException(status_code=404, detail=str(e)) set_current_run_id(workflow_run.id) + await _ensure_text_chat_quota(user, workflow_id, workflow_run.id) annotations = { "tester": { @@ -229,7 +236,7 @@ async def append_text_chat_message( user: UserModel = Depends(get_user_with_selected_organization), ) -> WorkflowRunTextSessionResponse: text_session = await _load_text_session_or_404(workflow_id, run_id, user) - await _ensure_text_chat_quota(user, workflow_id) + await _ensure_text_chat_quota(user, workflow_id, run_id) try: text_session = await append_text_chat_user_message( diff --git a/api/services/campaign/campaign_call_dispatcher.py b/api/services/campaign/campaign_call_dispatcher.py index 27fc2355..84a419be 100644 --- a/api/services/campaign/campaign_call_dispatcher.py +++ b/api/services/campaign/campaign_call_dispatcher.py @@ -15,6 +15,7 @@ from api.services.campaign.errors import ( PhoneNumberPoolExhaustedError, ) from api.services.campaign.rate_limiter import rate_limiter +from api.services.quota_service import authorize_workflow_run_start from api.utils.common import get_backend_endpoints if TYPE_CHECKING: @@ -339,6 +340,41 @@ class CampaignCallDispatcher: }, ) + quota_result = await authorize_workflow_run_start( + workflow_id=campaign.workflow_id, + workflow_run_id=workflow_run.id, + ) + if not quota_result.has_quota: + error_message = quota_result.error_message or "Quota exceeded" + logger.warning( + f"Campaign {campaign.id} quota check failed for workflow run " + f"{workflow_run.id}: {error_message}" + ) + await db_client.update_workflow_run( + run_id=workflow_run.id, + is_completed=True, + state=WorkflowRunState.COMPLETED.value, + gathered_context={"error": error_message}, + ) + + mapping = await rate_limiter.get_workflow_slot_mapping(workflow_run.id) + if mapping: + org_id, mapped_slot_id = mapping + await rate_limiter.release_concurrent_slot(org_id, mapped_slot_id) + await rate_limiter.delete_workflow_slot_mapping(workflow_run.id) + + from_number_mapping = await rate_limiter.get_workflow_from_number_mapping( + workflow_run.id + ) + if from_number_mapping: + fn_org_id, fn_number, fn_tcid = from_number_mapping + await rate_limiter.release_from_number( + fn_org_id, fn_number, telephony_configuration_id=fn_tcid + ) + await rate_limiter.delete_workflow_from_number_mapping(workflow_run.id) + + raise ValueError(error_message) + # Initiate call via telephony provider try: # Construct webhook URL with parameters diff --git a/api/services/managed_model_services.py b/api/services/managed_model_services.py index b6992aaf..00c776ff 100644 --- a/api/services/managed_model_services.py +++ b/api/services/managed_model_services.py @@ -2,11 +2,8 @@ from __future__ import annotations from typing import Any -from loguru import logger - from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration from api.services.configuration.registry import ServiceProviders -from api.services.mps_service_key_client import mps_service_key_client MPS_CORRELATION_ID_CONTEXT_KEY = "mps_correlation_id" @@ -48,27 +45,10 @@ async def ensure_mps_correlation_id( if not uses_managed_model_services_v2(ai_model_config): return None - service_key = _get_dograh_service_api_key(ai_model_config) - if not service_key: - raise ValueError( - "Managed model services v2 requires a Dograh service key before the run starts." - ) - - response = await mps_service_key_client.create_correlation_id( - service_key=service_key, - workflow_run_id=workflow_run_id, + raise ValueError( + "Managed model services v2 requires workflow run authorization before " + f"the run starts. Missing correlation id for workflow_run_id={workflow_run_id}." ) - correlation_id = response.get("correlation_id") - if not correlation_id: - raise ValueError("MPS correlation-id response did not include correlation_id") - - correlation_id = str(correlation_id) - logger.info( - "Minted MPS correlation id {} for workflow run {}", - correlation_id, - workflow_run_id, - ) - return correlation_id def _is_dograh_service(service: Any) -> bool: @@ -78,7 +58,7 @@ def _is_dograh_service(service: Any) -> bool: ) -def _get_dograh_service_api_key( +def get_dograh_service_api_key( ai_model_config: EffectiveAIModelConfiguration, ) -> str | None: for section_name in ("llm", "tts", "stt", "embeddings"): diff --git a/api/services/mps_service_key_client.py b/api/services/mps_service_key_client.py index 4f30341d..5f90380f 100644 --- a/api/services/mps_service_key_client.py +++ b/api/services/mps_service_key_client.py @@ -478,6 +478,50 @@ class MPSServiceKeyClient: response=response, ) + async def authorize_workflow_run_start( + self, + *, + organization_id: int, + workflow_run_id: int | None = None, + service_key: Optional[str] = None, + require_correlation_id: bool = False, + minimum_credits: float | None = None, + metadata: Optional[dict] = None, + created_by: Optional[str] = None, + ) -> dict: + """Authorize a hosted workflow run and optionally mint its MPS correlation.""" + payload = { + "workflow_run_id": workflow_run_id, + "service_key": service_key, + "require_correlation_id": require_correlation_id, + "metadata": metadata or {}, + } + if minimum_credits is not None: + payload["minimum_credits"] = minimum_credits + + async with httpx.AsyncClient(timeout=self.timeout) as client: + response = await client.post( + f"{self.base_url}/api/v1/billing/accounts/{organization_id}/run-authorization", + json=payload, + headers=self._get_headers( + organization_id=organization_id, + created_by=created_by, + ), + ) + + if response.status_code == 200: + return response.json() + + logger.error( + "Failed to authorize MPS workflow run start: " + f"{response.status_code} - {response.text}" + ) + raise httpx.HTTPStatusError( + f"Failed to authorize MPS workflow run start: {response.text}", + request=response.request, + response=response, + ) + async def create_correlation_id( self, *, diff --git a/api/services/quota_service.py b/api/services/quota_service.py index 6114ae99..6633736e 100644 --- a/api/services/quota_service.py +++ b/api/services/quota_service.py @@ -5,17 +5,38 @@ across different endpoints (WebRTC signaling, telephony, public API triggers). """ from dataclasses import dataclass +from typing import Any from loguru import logger +from api.constants import DEPLOYMENT_MODE from api.db import db_client from api.db.models import UserModel from api.services.configuration.ai_model_configuration import ( get_effective_ai_model_configuration_for_workflow, ) from api.services.configuration.registry import ServiceProviders +from api.services.managed_model_services import ( + MPS_CORRELATION_ID_CONTEXT_KEY, + get_dograh_service_api_key, + uses_managed_model_services_v2, +) from api.services.mps_service_key_client import mps_service_key_client +MINIMUM_DOGRAH_CREDITS_FOR_CALL = 0.10 + +LEGACY_QUOTA_EXCEEDED_MESSAGE = ( + "You have exhausted your trial credits. " + "Please email founders@dograh.com for additional Dograh credits " + "or change providers in Models configurations." +) + +BILLING_V2_QUOTA_EXCEEDED_MESSAGE = ( + "You have exhausted your Dograh credits. " + "Please purchase more credits from /billing " + "or change providers in Models configurations." +) + @dataclass class QuotaCheckResult: @@ -26,116 +47,359 @@ class QuotaCheckResult: error_code: str = "" -async def check_dograh_quota( - user: UserModel, workflow_id: int | None = None -) -> QuotaCheckResult: - """Check if user has sufficient Dograh quota for making a call. - - This function checks if the user is using any Dograh services (LLM, STT, TTS) - and validates that they have sufficient credits remaining. - - When ``workflow_id`` is provided, the workflow's per-workflow - ``model_overrides`` are merged onto the user's global config so the quota - check runs against the credentials that will actually be used for the call - (rather than always falling back to the user's defaults). - - Args: - user: The user to check quota for - workflow_id: Optional workflow whose ``model_overrides`` should be - applied when resolving the effective service config. - - Returns: - QuotaCheckResult with has_quota=True if user has sufficient quota or - is not using Dograh services, or has_quota=False with error_message - if quota is insufficient. - """ +def _safe_float(value: Any, default: float = 0.0) -> float: try: - organization_id = user.selected_organization_id - workflow_configurations = None + return float(value) + except (TypeError, ValueError): + return default - if workflow_id is not None: - workflow = await db_client.get_workflow_by_id(workflow_id) - if workflow: - organization_id = workflow.organization_id - workflow_configurations = workflow.workflow_configurations - user_config = await get_effective_ai_model_configuration_for_workflow( - user_id=user.id, - organization_id=organization_id, - workflow_configurations=workflow_configurations, +def _insufficient_billing_v2_quota_result() -> QuotaCheckResult: + return QuotaCheckResult( + has_quota=False, + error_code="insufficient_credits", + error_message=BILLING_V2_QUOTA_EXCEEDED_MESSAGE, + ) + + +def _insufficient_legacy_quota_result() -> QuotaCheckResult: + return QuotaCheckResult( + has_quota=False, + error_code="quota_exceeded", + error_message=LEGACY_QUOTA_EXCEEDED_MESSAGE, + ) + + +def _service_uses_dograh(service: Any) -> bool: + provider = getattr(service, "provider", None) + return ( + provider == ServiceProviders.DOGRAH or provider == ServiceProviders.DOGRAH.value + ) + + +def _dograh_api_keys(user_config: Any) -> set[str]: + api_keys: set[str] = set() + for section_name in ("llm", "stt", "tts", "embeddings"): + service = getattr(user_config, section_name, None) + if not _service_uses_dograh(service): + continue + if hasattr(service, "get_all_api_keys"): + all_api_keys = [ + api_key + for api_key in service.get_all_api_keys() + if isinstance(api_key, str) and api_key + ] + if all_api_keys: + api_keys.update(all_api_keys) + continue + api_key = getattr(service, "api_key", None) + if api_key: + api_keys.add(api_key) + return api_keys + + +async def _store_run_correlation_id( + workflow_run_id: int | None, + correlation_id: str | None, +) -> None: + if not workflow_run_id or not correlation_id: + return + + workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id) + if not workflow_run: + logger.warning( + "Could not store MPS correlation id for missing workflow run {}", + workflow_run_id, + ) + return + + initial_context = dict(workflow_run.initial_context or {}) + if initial_context.get(MPS_CORRELATION_ID_CONTEXT_KEY) == correlation_id: + return + + initial_context[MPS_CORRELATION_ID_CONTEXT_KEY] = correlation_id + await db_client.update_workflow_run( + workflow_run_id, + initial_context=initial_context, + ) + + +async def _authorize_hosted_workflow_run_start( + *, + workflow_owner: UserModel, + organization_id: int | None, + workflow_id: int | None, + workflow_run_id: int | None, + user_config: Any, +) -> tuple[QuotaCheckResult, bool]: + """Authorize hosted v2 billing and return whether MPS handled enforcement.""" + if DEPLOYMENT_MODE == "oss" or organization_id is None: + return QuotaCheckResult(has_quota=True), False + + requires_correlation = bool( + workflow_run_id and uses_managed_model_services_v2(user_config) + ) + service_key = ( + get_dograh_service_api_key(user_config) if requires_correlation else None + ) + if requires_correlation and not service_key: + return ( + QuotaCheckResult( + has_quota=False, + error_code="invalid_service_key", + error_message=( + "You have invalid keys in your model configuration. " + "Please validate the service keys." + ), + ), + True, ) - # Check if user is using any Dograh service - using_dograh = False - dograh_api_keys = set() + try: + authorization = await mps_service_key_client.authorize_workflow_run_start( + organization_id=organization_id, + workflow_run_id=workflow_run_id, + service_key=service_key, + require_correlation_id=requires_correlation, + minimum_credits=MINIMUM_DOGRAH_CREDITS_FOR_CALL, + created_by=( + str(workflow_owner.provider_id) + if workflow_owner.provider_id is not None + else None + ), + metadata={ + "dograh_user_id": str(workflow_owner.id), + "workflow_id": workflow_id, + }, + ) + except Exception as e: + logger.error( + "Failed to authorize workflow start with MPS for org {}: {}", + organization_id, + e, + ) + return ( + QuotaCheckResult( + has_quota=False, + error_code="quota_check_failed", + error_message="Could not verify Dograh credits. Please try again.", + ), + True, + ) - if user_config.llm and user_config.llm.provider == ServiceProviders.DOGRAH: - using_dograh = True - dograh_api_keys.add(user_config.llm.api_key) + billing_mode = authorization.get("billing_mode") + if billing_mode != "v2": + return QuotaCheckResult(has_quota=True), False - if user_config.stt and user_config.stt.provider == ServiceProviders.DOGRAH: - using_dograh = True - dograh_api_keys.add(user_config.stt.api_key) + remaining = _safe_float(authorization.get("remaining_credits")) + if ( + not authorization.get("allowed", False) + or remaining < MINIMUM_DOGRAH_CREDITS_FOR_CALL + ): + logger.warning( + "Insufficient Dograh billing v2 credits for org {}: {:.2f} credits remaining", + organization_id, + remaining, + ) + return _insufficient_billing_v2_quota_result(), True - if user_config.tts and user_config.tts.provider == ServiceProviders.DOGRAH: - using_dograh = True - dograh_api_keys.add(user_config.tts.api_key) + try: + await _store_run_correlation_id( + workflow_run_id, + authorization.get("correlation_id"), + ) + except Exception as e: + logger.error( + "Failed to store MPS correlation id for workflow_run_id {}: {}", + workflow_run_id, + e, + ) + return ( + QuotaCheckResult( + has_quota=False, + error_code="quota_check_failed", + error_message="Could not verify Dograh credits. Please try again.", + ), + True, + ) + logger.info( + "Dograh billing v2 run authorization passed for org {}: {:.2f} credits remaining", + organization_id, + remaining, + ) + return QuotaCheckResult(has_quota=True), True - if ( - user_config.embeddings - and user_config.embeddings.provider == ServiceProviders.DOGRAH - ): - using_dograh = True - dograh_api_keys.add(user_config.embeddings.api_key) - # If not using Dograh, quota check passes - if not using_dograh: - return QuotaCheckResult(has_quota=True) +async def _authorize_legacy_dograh_keys( + *, + dograh_api_keys: set[str], + organization_id: int | None, + workflow_owner: UserModel, +) -> QuotaCheckResult: + for api_key in dograh_api_keys: + try: + usage = await mps_service_key_client.check_service_key_usage( + api_key, + organization_id=organization_id, + created_by=workflow_owner.provider_id, + ) + remaining = usage.get("remaining_credits", 0.0) - # Check quota for ALL Dograh keys - for api_key in dograh_api_keys: - try: - usage = await mps_service_key_client.check_service_key_usage( - api_key, - organization_id=organization_id, - created_by=user.provider_id, + # Require at least $0.10 for a short call + if remaining < MINIMUM_DOGRAH_CREDITS_FOR_CALL: + logger.warning( + f"Insufficient Dograh credits for key ...{api_key[-8:]}: " + f"${remaining:.2f} remaining" ) - remaining = usage.get("remaining_credits", 0.0) + return _insufficient_legacy_quota_result() - # Require at least $0.10 for a short call - if remaining < 0.10: - logger.warning( - f"Insufficient Dograh credits for key ...{api_key[-8:]}: " - f"${remaining:.2f} remaining" - ) - return QuotaCheckResult( - has_quota=False, - error_code="quota_exceeded", - error_message=( - "You have exhausted your trial credits. " - "Please email founders@dograh.com for additional Dograh credits " - "or change providers in Models configurations." - ), - ) - - logger.info( - f"Dograh quota check passed for key ...{api_key[-8:]}: " - f"{remaining:.2f} credits remaining" - ) - except Exception as e: - logger.error(f"Failed to check quota for Dograh key: {str(e)}") - error_str = str(e) - if "404" in error_str or "not found" in error_str.lower(): - return QuotaCheckResult( - has_quota=False, - error_code="invalid_service_key", - error_message="You have invalid keys in your model configuration. Please validate the service keys.", - ) + logger.info( + f"Dograh quota check passed for key ...{api_key[-8:]}: " + f"{remaining:.2f} credits remaining" + ) + except Exception as e: + logger.error(f"Failed to check quota for Dograh key: {str(e)}") + error_str = str(e) + if "404" in error_str or "not found" in error_str.lower(): return QuotaCheckResult( has_quota=False, - error_code="quota_check_failed", - error_message="Could not verify Dograh credits. Please try again.", + error_code="invalid_service_key", + error_message="You have invalid keys in your model configuration. Please validate the service keys.", ) + return QuotaCheckResult( + has_quota=False, + error_code="quota_check_failed", + error_message="Could not verify Dograh credits. Please try again.", + ) + + return QuotaCheckResult(has_quota=True) + + +async def _authorize_oss_managed_v2_correlation( + *, + workflow_id: int, + workflow_run_id: int | None, + user_config: Any, +) -> QuotaCheckResult: + if not workflow_run_id or not uses_managed_model_services_v2(user_config): + return QuotaCheckResult(has_quota=True) + + service_key = get_dograh_service_api_key(user_config) + if not service_key: + return QuotaCheckResult( + has_quota=False, + error_code="invalid_service_key", + error_message=( + "You have invalid keys in your model configuration. " + "Please validate the service keys." + ), + ) + + try: + response = await mps_service_key_client.create_correlation_id( + service_key=service_key, + workflow_run_id=workflow_run_id, + ) + await _store_run_correlation_id( + workflow_run_id, + response.get("correlation_id"), + ) + except Exception as e: + logger.error( + "Failed to authorize OSS managed v2 workflow start for workflow {} run {}: {}", + workflow_id, + workflow_run_id, + e, + ) + return QuotaCheckResult( + has_quota=False, + error_code="quota_check_failed", + error_message="Could not verify Dograh credits. Please try again.", + ) + + return QuotaCheckResult(has_quota=True) + + +async def authorize_workflow_run_start( + *, + workflow_id: int, + workflow_run_id: int | None = None, + actor_user: UserModel | None = None, +) -> QuotaCheckResult: + """Authorize a workflow run before any billable call/text runtime starts. + + The workflow organization is the billing subject for hosted v2. The workflow + owner is used only to resolve the effective model configuration and legacy + service-key metadata. + """ + try: + workflow = await db_client.get_workflow_by_id(workflow_id) + if not workflow: + return QuotaCheckResult( + has_quota=False, + error_code="workflow_not_found", + error_message="Workflow not found", + ) + + actor_org_id = getattr(actor_user, "selected_organization_id", None) + if actor_org_id is not None and actor_org_id != workflow.organization_id: + logger.warning( + "Workflow start authorization denied: actor org {} does not match workflow {} org {}", + actor_org_id, + workflow_id, + workflow.organization_id, + ) + return QuotaCheckResult( + has_quota=False, + error_code="workflow_not_found", + error_message="Workflow not found", + ) + + workflow_owner = await db_client.get_user_by_id(workflow.user_id) + if not workflow_owner: + return QuotaCheckResult( + has_quota=False, + error_code="user_not_found", + error_message="User not found", + ) + + user_config = await get_effective_ai_model_configuration_for_workflow( + user_id=workflow_owner.id, + organization_id=workflow.organization_id, + workflow_configurations=workflow.workflow_configurations, + ) + + if DEPLOYMENT_MODE != "oss": + hosted_result, hosted_enforced = await _authorize_hosted_workflow_run_start( + workflow_owner=workflow_owner, + organization_id=workflow.organization_id, + workflow_id=workflow.id, + workflow_run_id=workflow_run_id, + user_config=user_config, + ) + if hosted_enforced or not hosted_result.has_quota: + return hosted_result + + dograh_api_keys = _dograh_api_keys(user_config) + if not dograh_api_keys: + return QuotaCheckResult(has_quota=True) + + legacy_result = await _authorize_legacy_dograh_keys( + dograh_api_keys=dograh_api_keys, + organization_id=( + None if DEPLOYMENT_MODE == "oss" else workflow.organization_id + ), + workflow_owner=workflow_owner, + ) + if not legacy_result.has_quota: + return legacy_result + + if DEPLOYMENT_MODE == "oss": + return await _authorize_oss_managed_v2_correlation( + workflow_id=workflow.id, + workflow_run_id=workflow_run_id, + user_config=user_config, + ) return QuotaCheckResult(has_quota=True) @@ -143,30 +407,3 @@ async def check_dograh_quota( logger.error(f"Error during quota check: {str(e)}") # On unexpected error, allow the call to proceed return QuotaCheckResult(has_quota=True) - - -async def check_dograh_quota_by_user_id( - user_id: int, workflow_id: int | None = None -) -> QuotaCheckResult: - """Check Dograh quota by user ID. - - Convenience function that fetches the user and then checks quota. When - ``workflow_id`` is provided, the workflow's ``model_overrides`` are - applied so the quota check evaluates the credentials that will actually - be used for the call. - - Args: - user_id: The ID of the user to check quota for - workflow_id: Optional workflow whose per-workflow overrides should - be applied to the user's config before checking quota. - - Returns: - QuotaCheckResult with quota status - """ - user = await db_client.get_user_by_id(user_id) - if not user: - return QuotaCheckResult( - has_quota=False, - error_message="User not found", - ) - return await check_dograh_quota(user, workflow_id=workflow_id) diff --git a/api/services/telephony/ari_manager.py b/api/services/telephony/ari_manager.py index a10c05dc..2648affd 100644 --- a/api/services/telephony/ari_manager.py +++ b/api/services/telephony/ari_manager.py @@ -26,7 +26,7 @@ from loguru import logger from api.constants import REDIS_URL from api.db import db_client from api.enums import CallType, WorkflowRunMode -from api.services.quota_service import check_dograh_quota_by_user_id +from api.services.quota_service import authorize_workflow_run_start from api.services.telephony.call_transfer_manager import get_call_transfer_manager from api.services.telephony.transfer_event_protocol import ( TransferEvent, @@ -564,19 +564,7 @@ class ARIConnection: user_id = workflow.user_id - # 3. Check quota (apply per-workflow model_overrides). - quota_result = await check_dograh_quota_by_user_id( - user_id, workflow_id=inbound_workflow_id - ) - if not quota_result.has_quota: - logger.warning( - f"[ARI org={self.organization_id}] Quota exceeded for user {user_id} " - f"— hanging up inbound call {channel_id}" - ) - await self._delete_channel(channel_id) - return - - # 4. Create workflow run + # 3. Create workflow run call_id = channel_id workflow_run = await db_client.create_workflow_run( name=f"ARI Inbound {caller_number}", @@ -602,6 +590,20 @@ class ARIConnection: f"(caller={caller_number}, called={called_number})" ) + # 4. Check quota after the run exists so hosted v2 can mint and + # store the MPS correlation id before the pipeline starts. + quota_result = await authorize_workflow_run_start( + workflow_id=inbound_workflow_id, + workflow_run_id=workflow_run.id, + ) + if not quota_result.has_quota: + logger.warning( + f"[ARI org={self.organization_id}] Quota exceeded for user {user_id} " + f"— hanging up inbound call {channel_id}" + ) + await self._delete_channel(channel_id) + return + # 5. Answer the inbound channel await self._answer_channel(channel_id) diff --git a/api/tests/test_cartesia_tts_service_factory.py b/api/tests/test_cartesia_tts_service_factory.py index bcc12359..71e2acab 100644 --- a/api/tests/test_cartesia_tts_service_factory.py +++ b/api/tests/test_cartesia_tts_service_factory.py @@ -33,7 +33,9 @@ def test_create_cartesia_tts_service_passes_selected_model(): transport_in_sample_rate=16000, ) - with patch("api.services.pipecat.service_factory.CartesiaTTSService") as mock_service: + with patch( + "api.services.pipecat.service_factory.CartesiaTTSService" + ) as mock_service: create_tts_service(user_config, audio_config) assert mock_service.call_count == 1 diff --git a/api/tests/test_from_number_pool_isolation.py b/api/tests/test_from_number_pool_isolation.py index 3c65d10f..ae3dffbc 100644 --- a/api/tests/test_from_number_pool_isolation.py +++ b/api/tests/test_from_number_pool_isolation.py @@ -270,6 +270,12 @@ class TestDispatcherThreadsTelephonyConfig: "api.services.campaign.campaign_call_dispatcher.get_backend_endpoints", AsyncMock(return_value=("https://example.com", None)), ), + patch( + "api.services.campaign.campaign_call_dispatcher.authorize_workflow_run_start", + AsyncMock( + return_value=SimpleNamespace(has_quota=True, error_message="") + ), + ), ): mock_db.get_workflow_by_id = AsyncMock(return_value=SimpleNamespace(id=1)) mock_db.create_workflow_run = AsyncMock(return_value=workflow_run) diff --git a/api/tests/test_mps_service_key_client.py b/api/tests/test_mps_service_key_client.py index 032f07bf..f51f2aa9 100644 --- a/api/tests/test_mps_service_key_client.py +++ b/api/tests/test_mps_service_key_client.py @@ -175,6 +175,76 @@ async def test_get_billing_account_status_uses_hosted_org_auth(monkeypatch): ] +@pytest.mark.asyncio +async def test_authorize_workflow_run_start_uses_hosted_org_auth(monkeypatch): + calls = [] + + class FakeAsyncClient: + def __init__(self, timeout): + self.timeout = timeout + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return None + + async def post(self, url, json, headers): + calls.append(("POST", url, json, headers)) + return _Response( + 200, + { + "allowed": True, + "billing_mode": "v2", + "remaining_credits": "25.0000", + "correlation_id": "mps-corr-123", + }, + ) + + monkeypatch.setattr( + "api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient + ) + monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas") + monkeypatch.setattr( + "api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret" + ) + + client = MPSServiceKeyClient() + + assert await client.authorize_workflow_run_start( + organization_id=42, + workflow_run_id=88, + service_key="mps_sk_paid", + require_correlation_id=True, + minimum_credits=0.1, + metadata={"workflow_id": 7}, + created_by="provider-123", + ) == { + "allowed": True, + "billing_mode": "v2", + "remaining_credits": "25.0000", + "correlation_id": "mps-corr-123", + } + assert calls == [ + ( + "POST", + f"{client.base_url}/api/v1/billing/accounts/42/run-authorization", + { + "workflow_run_id": 88, + "service_key": "mps_sk_paid", + "require_correlation_id": True, + "minimum_credits": 0.1, + "metadata": {"workflow_id": 7}, + }, + { + "Content-Type": "application/json", + "X-Secret-Key": "mps-secret", + "X-Organization-Id": "42", + }, + ) + ] + + @pytest.mark.asyncio async def test_ensure_billing_account_v2_uses_balance_endpoint(monkeypatch): calls = [] diff --git a/api/tests/test_public_agent_routes.py b/api/tests/test_public_agent_routes.py index a7849fbe..3b7ea409 100644 --- a/api/tests/test_public_agent_routes.py +++ b/api/tests/test_public_agent_routes.py @@ -57,7 +57,7 @@ def test_trigger_route_executes_as_workflow_owner(): with ( patch("api.routes.public_agent.db_client") as mock_db, patch( - "api.routes.public_agent.check_dograh_quota_by_user_id", + "api.routes.public_agent.authorize_workflow_run_start", new=quota_mock, ), patch( @@ -92,7 +92,10 @@ def test_trigger_route_executes_as_workflow_owner(): ) assert response.status_code == 200 - quota_mock.assert_awaited_once_with(workflow.user_id, workflow_id=workflow.id) + quota_mock.assert_awaited_once_with( + workflow_id=workflow.id, + workflow_run_id=501, + ) mock_db.get_workflow.assert_awaited_once_with(workflow.id, organization_id=11) create_kwargs = mock_db.create_workflow_run.await_args.kwargs @@ -124,7 +127,7 @@ def test_workflow_uuid_route_uses_scoped_lookup_and_shared_execution(): with ( patch("api.routes.public_agent.db_client") as mock_db, patch( - "api.routes.public_agent.check_dograh_quota_by_user_id", + "api.routes.public_agent.authorize_workflow_run_start", new=quota_mock, ), patch( diff --git a/api/tests/test_quota_service.py b/api/tests/test_quota_service.py new file mode 100644 index 00000000..8e2ee6f5 --- /dev/null +++ b/api/tests/test_quota_service.py @@ -0,0 +1,369 @@ +from types import SimpleNamespace +from unittest.mock import AsyncMock + +import pytest + +from api.services import quota_service +from api.services.configuration.registry import ServiceProviders +from api.services.managed_model_services import MPS_CORRELATION_ID_CONTEXT_KEY + + +def _dograh_config( + api_key: str = "mps_sk_12345678", + *, + managed_service_version: int = 1, +): + return SimpleNamespace( + managed_service_version=managed_service_version, + llm=SimpleNamespace(provider=ServiceProviders.DOGRAH, api_key=api_key), + stt=None, + tts=None, + embeddings=None, + ) + + +def _byok_config(): + return SimpleNamespace( + managed_service_version=2, + llm=SimpleNamespace(provider="openai", api_key="sk-openai"), + stt=None, + tts=None, + embeddings=None, + ) + + +def _workflow(): + return SimpleNamespace( + id=7, + user_id=123, + organization_id=42, + workflow_configurations={"model_overrides": {}}, + ) + + +def _workflow_owner(): + return SimpleNamespace( + id=123, + provider_id="provider-123", + ) + + +def _actor(): + return SimpleNamespace( + id=456, + provider_id="actor-456", + selected_organization_id=42, + ) + + +def _patch_workflow_context(monkeypatch, *, workflow=None, owner=None): + monkeypatch.setattr( + quota_service.db_client, + "get_workflow_by_id", + AsyncMock(return_value=workflow or _workflow()), + ) + monkeypatch.setattr( + quota_service.db_client, + "get_user_by_id", + AsyncMock(return_value=owner or _workflow_owner()), + ) + + +@pytest.mark.asyncio +async def test_authorize_workflow_run_uses_workflow_org_for_hosted_v2( + monkeypatch, +): + get_config = AsyncMock(return_value=_dograh_config()) + authorize = AsyncMock( + return_value={ + "allowed": True, + "billing_mode": "v2", + "remaining_credits": "25.0000", + } + ) + check_usage = AsyncMock() + + monkeypatch.setattr(quota_service, "DEPLOYMENT_MODE", "saas") + _patch_workflow_context(monkeypatch) + monkeypatch.setattr( + quota_service, + "get_effective_ai_model_configuration_for_workflow", + get_config, + ) + monkeypatch.setattr( + quota_service.mps_service_key_client, + "authorize_workflow_run_start", + authorize, + ) + monkeypatch.setattr( + quota_service.mps_service_key_client, + "check_service_key_usage", + check_usage, + ) + + result = await quota_service.authorize_workflow_run_start(workflow_id=7) + + assert result.has_quota is True + get_config.assert_awaited_once_with( + user_id=123, + organization_id=42, + workflow_configurations={"model_overrides": {}}, + ) + authorize.assert_awaited_once_with( + organization_id=42, + workflow_run_id=None, + service_key=None, + require_correlation_id=False, + minimum_credits=quota_service.MINIMUM_DOGRAH_CREDITS_FOR_CALL, + created_by="provider-123", + metadata={"dograh_user_id": "123", "workflow_id": 7}, + ) + check_usage.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_authorize_workflow_run_v2_insufficient_credits_prompts_billing( + monkeypatch, +): + get_config = AsyncMock(return_value=_byok_config()) + authorize = AsyncMock( + return_value={ + "allowed": False, + "billing_mode": "v2", + "remaining_credits": "0.0000", + "error": "insufficient_credits", + } + ) + check_usage = AsyncMock() + + monkeypatch.setattr(quota_service, "DEPLOYMENT_MODE", "saas") + _patch_workflow_context(monkeypatch) + monkeypatch.setattr( + quota_service, + "get_effective_ai_model_configuration_for_workflow", + get_config, + ) + monkeypatch.setattr( + quota_service.mps_service_key_client, + "authorize_workflow_run_start", + authorize, + ) + monkeypatch.setattr( + quota_service.mps_service_key_client, + "check_service_key_usage", + check_usage, + ) + + result = await quota_service.authorize_workflow_run_start(workflow_id=7) + + assert result.has_quota is False + assert result.error_code == "insufficient_credits" + assert "/billing" in result.error_message + assert "founders@dograh.com" not in result.error_message + authorize.assert_awaited_once() + check_usage.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_authorize_workflow_run_v1_uses_legacy_key_usage( + monkeypatch, +): + api_key = "mps_sk_12345678" + get_config = AsyncMock(return_value=_dograh_config(api_key)) + authorize = AsyncMock( + return_value={ + "allowed": True, + "billing_mode": "v1", + "remaining_credits": "0.0000", + } + ) + check_usage = AsyncMock( + return_value={"total_credits_used": 500.0, "remaining_credits": 0.0} + ) + + monkeypatch.setattr(quota_service, "DEPLOYMENT_MODE", "saas") + _patch_workflow_context(monkeypatch) + monkeypatch.setattr( + quota_service, + "get_effective_ai_model_configuration_for_workflow", + get_config, + ) + monkeypatch.setattr( + quota_service.mps_service_key_client, + "authorize_workflow_run_start", + authorize, + ) + monkeypatch.setattr( + quota_service.mps_service_key_client, + "check_service_key_usage", + check_usage, + ) + + result = await quota_service.authorize_workflow_run_start(workflow_id=7) + + assert result.has_quota is False + assert result.error_code == "quota_exceeded" + assert "founders@dograh.com" in result.error_message + assert "/billing" not in result.error_message + authorize.assert_awaited_once() + check_usage.assert_awaited_once_with( + api_key, + organization_id=42, + created_by="provider-123", + ) + + +@pytest.mark.asyncio +async def test_authorize_workflow_run_managed_v2_stores_hosted_correlation( + monkeypatch, +): + api_key = "mps_sk_12345678" + workflow_run = SimpleNamespace(initial_context={"existing": "value"}) + get_config = AsyncMock( + return_value=_dograh_config(api_key, managed_service_version=2) + ) + authorize = AsyncMock( + return_value={ + "allowed": True, + "billing_mode": "v2", + "remaining_credits": "25.0000", + "correlation_id": "mps-corr-123", + } + ) + update_workflow_run = AsyncMock() + + monkeypatch.setattr(quota_service, "DEPLOYMENT_MODE", "saas") + _patch_workflow_context(monkeypatch) + monkeypatch.setattr( + quota_service.db_client, + "get_workflow_run_by_id", + AsyncMock(return_value=workflow_run), + ) + monkeypatch.setattr( + quota_service.db_client, + "update_workflow_run", + update_workflow_run, + ) + monkeypatch.setattr( + quota_service, + "get_effective_ai_model_configuration_for_workflow", + get_config, + ) + monkeypatch.setattr( + quota_service.mps_service_key_client, + "authorize_workflow_run_start", + authorize, + ) + monkeypatch.setattr( + quota_service.mps_service_key_client, + "check_service_key_usage", + AsyncMock(), + ) + + result = await quota_service.authorize_workflow_run_start( + workflow_id=7, + workflow_run_id=88, + ) + + assert result.has_quota is True + authorize.assert_awaited_once_with( + organization_id=42, + workflow_run_id=88, + service_key=api_key, + require_correlation_id=True, + minimum_credits=quota_service.MINIMUM_DOGRAH_CREDITS_FOR_CALL, + created_by="provider-123", + metadata={"dograh_user_id": "123", "workflow_id": 7}, + ) + update_workflow_run.assert_awaited_once_with( + 88, + initial_context={ + "existing": "value", + MPS_CORRELATION_ID_CONTEXT_KEY: "mps-corr-123", + }, + ) + + +@pytest.mark.asyncio +async def test_authorize_workflow_run_oss_uses_key_paths_not_workflow_org( + monkeypatch, +): + api_key = "mps_sk_12345678" + workflow_run = SimpleNamespace(initial_context={}) + get_config = AsyncMock( + return_value=_dograh_config(api_key, managed_service_version=2) + ) + hosted_authorize = AsyncMock() + check_usage = AsyncMock( + return_value={"total_credits_used": 1.0, "remaining_credits": 499.0} + ) + create_correlation = AsyncMock(return_value={"correlation_id": "oss-corr-123"}) + update_workflow_run = AsyncMock() + + monkeypatch.setattr(quota_service, "DEPLOYMENT_MODE", "oss") + _patch_workflow_context(monkeypatch) + monkeypatch.setattr( + quota_service.db_client, + "get_workflow_run_by_id", + AsyncMock(return_value=workflow_run), + ) + monkeypatch.setattr( + quota_service.db_client, + "update_workflow_run", + update_workflow_run, + ) + monkeypatch.setattr( + quota_service, + "get_effective_ai_model_configuration_for_workflow", + get_config, + ) + monkeypatch.setattr( + quota_service.mps_service_key_client, + "authorize_workflow_run_start", + hosted_authorize, + ) + monkeypatch.setattr( + quota_service.mps_service_key_client, + "check_service_key_usage", + check_usage, + ) + monkeypatch.setattr( + quota_service.mps_service_key_client, + "create_correlation_id", + create_correlation, + ) + + result = await quota_service.authorize_workflow_run_start( + workflow_id=7, + workflow_run_id=88, + ) + + assert result.has_quota is True + hosted_authorize.assert_not_awaited() + check_usage.assert_awaited_once_with( + api_key, + organization_id=None, + created_by="provider-123", + ) + create_correlation.assert_awaited_once_with( + service_key=api_key, + workflow_run_id=88, + ) + update_workflow_run.assert_awaited_once_with( + 88, + initial_context={MPS_CORRELATION_ID_CONTEXT_KEY: "oss-corr-123"}, + ) + + +@pytest.mark.asyncio +async def test_authorize_workflow_run_rejects_actor_from_another_org(monkeypatch): + monkeypatch.setattr(quota_service, "DEPLOYMENT_MODE", "saas") + _patch_workflow_context(monkeypatch) + + result = await quota_service.authorize_workflow_run_start( + workflow_id=7, + actor_user=SimpleNamespace(selected_organization_id=999), + ) + + assert result.has_quota is False + assert result.error_code == "workflow_not_found" diff --git a/api/tests/test_telephony_routes.py b/api/tests/test_telephony_routes.py index 76c8a54d..03a4cc48 100644 --- a/api/tests/test_telephony_routes.py +++ b/api/tests/test_telephony_routes.py @@ -1,5 +1,5 @@ from types import SimpleNamespace -from unittest.mock import AsyncMock, Mock, patch +from unittest.mock import ANY, AsyncMock, Mock, patch from fastapi import FastAPI from fastapi.testclient import TestClient @@ -54,7 +54,7 @@ def test_initiate_call_executes_as_workflow_owner_for_shared_org_workflow(): with ( patch("api.routes.telephony.db_client") as mock_db, patch( - "api.routes.telephony.check_dograh_quota_by_user_id", + "api.routes.telephony.authorize_workflow_run_start", new=quota_mock, ), patch( @@ -88,7 +88,11 @@ def test_initiate_call_executes_as_workflow_owner_for_shared_org_workflow(): ) assert response.status_code == 200 - quota_mock.assert_awaited_once_with(workflow.user_id, workflow_id=workflow.id) + quota_mock.assert_awaited_once_with( + workflow_id=workflow.id, + workflow_run_id=501, + actor_user=ANY, + ) mock_db.get_workflow.assert_awaited_once_with(workflow.id, organization_id=11) create_call = mock_db.create_workflow_run.await_args @@ -119,7 +123,7 @@ def test_initiate_call_uses_organization_preference_phone_number(): with ( patch("api.routes.telephony.db_client") as mock_db, patch( - "api.routes.telephony.check_dograh_quota_by_user_id", + "api.routes.telephony.authorize_workflow_run_start", new=quota_mock, ), patch( @@ -173,7 +177,7 @@ def test_initiate_call_rejects_existing_run_for_different_workflow(): with ( patch("api.routes.telephony.db_client") as mock_db, patch( - "api.routes.telephony.check_dograh_quota_by_user_id", + "api.routes.telephony.authorize_workflow_run_start", new=quota_mock, ), patch( diff --git a/api/tests/test_workflow_text_chat.py b/api/tests/test_workflow_text_chat.py index 3be8a613..40afdcfb 100644 --- a/api/tests/test_workflow_text_chat.py +++ b/api/tests/test_workflow_text_chat.py @@ -1105,7 +1105,7 @@ async def test_text_chat_session_creation_rejects_quota_before_creating_run( async with test_client_factory(user) as client: with patch( - "api.routes.workflow_text_chat.check_dograh_quota", + "api.routes.workflow_text_chat.authorize_workflow_run_start", new=AsyncMock( return_value=SimpleNamespace( has_quota=False, @@ -1120,11 +1120,16 @@ async def test_text_chat_session_creation_rejects_quota_before_creating_run( assert create_response.status_code == 402 assert create_response.json()["detail"] == "Quota exceeded" - _, total_count = await db_session.get_workflow_runs_by_workflow_id( + runs, total_count = await db_session.get_workflow_runs_by_workflow_id( workflow.id, organization_id=workflow.organization_id, ) - assert total_count == 0 + assert total_count == 1 + text_session = await db_session.get_workflow_run_text_session( + runs[0].id, + organization_id=workflow.organization_id, + ) + assert text_session is None @pytest.mark.asyncio @@ -1168,7 +1173,7 @@ async def test_text_chat_append_rejects_quota_without_mutating_session( async with test_client_factory(user) as client: with ( patch( - "api.routes.workflow_text_chat.check_dograh_quota", + "api.routes.workflow_text_chat.authorize_workflow_run_start", new=AsyncMock( side_effect=[ SimpleNamespace(has_quota=True, error_message=""), diff --git a/ui/src/app/workflow/[workflowId]/components/workflow-tester/EmbeddedVoiceTester.tsx b/ui/src/app/workflow/[workflowId]/components/workflow-tester/EmbeddedVoiceTester.tsx index 96cfb994..9a0ff85d 100644 --- a/ui/src/app/workflow/[workflowId]/components/workflow-tester/EmbeddedVoiceTester.tsx +++ b/ui/src/app/workflow/[workflowId]/components/workflow-tester/EmbeddedVoiceTester.tsx @@ -147,7 +147,7 @@ export function EmbeddedVoiceTester({ onOpenChange={setApiKeyModalOpen} error={apiKeyError} errorCode={apiKeyErrorCode} - onNavigateToCredits={() => router.push("/api-keys")} + onNavigateToBilling={() => router.push("/billing")} onNavigateToModelConfig={() => router.push("/model-configurations")} /> diff --git a/ui/src/app/workflow/[workflowId]/run/[runId]/components/ApiKeyErrorDialog.tsx b/ui/src/app/workflow/[workflowId]/run/[runId]/components/ApiKeyErrorDialog.tsx index 4704d3a5..29672545 100644 --- a/ui/src/app/workflow/[workflowId]/run/[runId]/components/ApiKeyErrorDialog.tsx +++ b/ui/src/app/workflow/[workflowId]/run/[runId]/components/ApiKeyErrorDialog.tsx @@ -8,7 +8,7 @@ interface ApiKeyErrorDialogProps { onOpenChange: (open: boolean) => void; error: string | null; errorCode: string | null; - onNavigateToCredits: () => void; + onNavigateToBilling: () => void; onNavigateToModelConfig: () => void; } @@ -17,15 +17,16 @@ export const ApiKeyErrorDialog = ({ onOpenChange, error, errorCode, - onNavigateToCredits, + onNavigateToBilling, onNavigateToModelConfig, }: ApiKeyErrorDialogProps) => { - const isQuotaError = errorCode === 'quota_exceeded'; + const isBillingCreditsError = errorCode === 'insufficient_credits'; + const isQuotaError = isBillingCreditsError || errorCode === 'quota_exceeded'; const title = isQuotaError ? "Insufficient Credits" : "API Configuration Error"; const icon = isQuotaError ? : ; - const buttonText = isQuotaError ? "Add Credits" : "Go to Model Configurations"; - const onNavigate = isQuotaError ? onNavigateToCredits : onNavigateToModelConfig; + const buttonText = isBillingCreditsError ? "Go to Billing" : "Go to Model Configurations"; + const onNavigate = isBillingCreditsError ? onNavigateToBilling : onNavigateToModelConfig; return ( @@ -40,9 +41,9 @@ export const ApiKeyErrorDialog = ({

{error}

- {isQuotaError && ( + {isBillingCreditsError && (

- Your Dograh service credits are too low to start a call. + Purchase credits from Billing to continue using Dograh-managed models.

)}
diff --git a/ui/src/app/workflow/[workflowId]/run/[runId]/hooks/useWebSocketRTC.tsx b/ui/src/app/workflow/[workflowId]/run/[runId]/hooks/useWebSocketRTC.tsx index 5121fdf9..2aaf65c7 100644 --- a/ui/src/app/workflow/[workflowId]/run/[runId]/hooks/useWebSocketRTC.tsx +++ b/ui/src/app/workflow/[workflowId]/run/[runId]/hooks/useWebSocketRTC.tsx @@ -19,6 +19,13 @@ interface UseWebSocketRTCProps { onNodeTransition?: (transition: ConversationNodeTransitionItem) => void; } +const HANDLED_SERVICE_ERROR_TYPES = new Set([ + 'quota_exceeded', + 'insufficient_credits', + 'invalid_service_key', + 'quota_check_failed', +]); + export const useWebSocketRTC = ({ workflowId, workflowRunId, accessToken, initialContextVariables, onNodeTransition }: UseWebSocketRTCProps) => { const [connectionStatus, setConnectionStatus] = useState<'idle' | 'connecting' | 'connected' | 'failed'>('idle'); const [connectionActive, setConnectionActive] = useState(false); @@ -265,9 +272,7 @@ export const useWebSocketRTC = ({ workflowId, workflowRunId, accessToken, initia case 'error': // Check if this is a quota/service key error - if (message.payload?.error_type === 'quota_exceeded' || - message.payload?.error_type === 'invalid_service_key' || - message.payload?.error_type === 'quota_check_failed') { + if (HANDLED_SERVICE_ERROR_TYPES.has(message.payload?.error_type)) { // Log as info since it's a handled business logic case logger.info('Quota/service key error, showing user dialog:', message.payload.message);