mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-13 08:15:21 +02:00
feat: centralise workflow run authorization
This commit is contained in:
parent
5bf7518829
commit
281656b960
21 changed files with 1036 additions and 252 deletions
|
|
@ -22,7 +22,7 @@ from starlette.websockets import WebSocketDisconnect
|
|||
|
||||
from api.db import db_client
|
||||
from api.enums import CallType, WorkflowRunState
|
||||
from api.services.quota_service import check_dograh_quota_by_user_id
|
||||
from api.services.quota_service import authorize_workflow_run_start
|
||||
from api.services.telephony import registry as telephony_registry
|
||||
|
||||
router = APIRouter(prefix="/agent-stream")
|
||||
|
|
@ -67,19 +67,6 @@ async def agent_stream_websocket(
|
|||
await websocket.close(code=1008, reason="Workflow not found")
|
||||
return
|
||||
|
||||
quota_result = await check_dograh_quota_by_user_id(
|
||||
workflow.user_id, workflow_id=workflow.id
|
||||
)
|
||||
if not quota_result.has_quota:
|
||||
logger.warning(
|
||||
f"agent-stream quota exceeded for user {workflow.user_id}: "
|
||||
f"{quota_result.error_message}"
|
||||
)
|
||||
await websocket.close(
|
||||
code=1008, reason=quota_result.error_message or "Quota exceeded"
|
||||
)
|
||||
return
|
||||
|
||||
numeric_suffix = int(str(uuid.uuid4()).replace("-", "")[:8], 16) % 100000000
|
||||
workflow_run_name = f"WR-AGS-{numeric_suffix:08d}"
|
||||
call_id = params.get("callId") or params.get("CallSid")
|
||||
|
|
@ -108,6 +95,20 @@ async def agent_stream_websocket(
|
|||
set_current_run_id(workflow_run.id)
|
||||
set_current_org_id(workflow.organization_id)
|
||||
|
||||
quota_result = await authorize_workflow_run_start(
|
||||
workflow_id=workflow.id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
)
|
||||
if not quota_result.has_quota:
|
||||
logger.warning(
|
||||
f"agent-stream quota exceeded for user {workflow.user_id}: "
|
||||
f"{quota_result.error_message}"
|
||||
)
|
||||
await websocket.close(
|
||||
code=1008, reason=quota_result.error_message or "Quota exceeded"
|
||||
)
|
||||
return
|
||||
|
||||
await db_client.update_workflow_run(
|
||||
run_id=workflow_run.id, state=WorkflowRunState.RUNNING.value
|
||||
)
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ from api.services.auth.depends import get_user
|
|||
from api.services.campaign.runner import campaign_runner_service
|
||||
from api.services.campaign.source_sync import CampaignSourceSyncService
|
||||
from api.services.campaign.source_sync_factory import get_sync_service
|
||||
from api.services.quota_service import check_dograh_quota
|
||||
from api.services.quota_service import authorize_workflow_run_start
|
||||
from api.services.reports import generate_campaign_report_csv
|
||||
from api.services.storage import storage_fs
|
||||
|
||||
|
|
@ -550,7 +550,10 @@ async def start_campaign(
|
|||
|
||||
# Check Dograh quota before starting campaign (apply per-workflow
|
||||
# model_overrides so we evaluate the keys this campaign will use).
|
||||
quota_result = await check_dograh_quota(user, workflow_id=campaign.workflow_id)
|
||||
quota_result = await authorize_workflow_run_start(
|
||||
workflow_id=campaign.workflow_id,
|
||||
actor_user=user,
|
||||
)
|
||||
if not quota_result.has_quota:
|
||||
raise HTTPException(status_code=402, detail=quota_result.error_message)
|
||||
|
||||
|
|
@ -872,7 +875,10 @@ async def resume_campaign(
|
|||
|
||||
# Check Dograh quota before resuming campaign (apply per-workflow
|
||||
# model_overrides so we evaluate the keys this campaign will use).
|
||||
quota_result = await check_dograh_quota(user, workflow_id=campaign.workflow_id)
|
||||
quota_result = await authorize_workflow_run_start(
|
||||
workflow_id=campaign.workflow_id,
|
||||
actor_user=user,
|
||||
)
|
||||
if not quota_result.has_quota:
|
||||
raise HTTPException(status_code=402, detail=quota_result.error_message)
|
||||
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from pydantic import BaseModel
|
|||
|
||||
from api.db import db_client
|
||||
from api.enums import TriggerState, WorkflowStatus
|
||||
from api.services.quota_service import check_dograh_quota_by_user_id
|
||||
from api.services.quota_service import authorize_workflow_run_start
|
||||
from api.services.telephony.factory import (
|
||||
get_default_telephony_provider,
|
||||
get_telephony_provider_by_id,
|
||||
|
|
@ -179,14 +179,6 @@ async def _execute_resolved_target(
|
|||
"""Shared execution path once the target workflow has been resolved."""
|
||||
execution_user_id = _get_execution_user_id(target.workflow)
|
||||
|
||||
# Check Dograh quota using the workflow owner's config and model overrides.
|
||||
quota_result = await check_dograh_quota_by_user_id(
|
||||
execution_user_id,
|
||||
workflow_id=target.workflow.id,
|
||||
)
|
||||
if not quota_result.has_quota:
|
||||
raise HTTPException(status_code=402, detail=quota_result.error_message)
|
||||
|
||||
# Get telephony provider — either the caller-specified config (validated
|
||||
# against the workflow's org) or the org's default config.
|
||||
if request.telephony_configuration_id is not None:
|
||||
|
|
@ -268,6 +260,15 @@ async def _execute_resolved_target(
|
|||
f"to phone number {request.phone_number}"
|
||||
)
|
||||
|
||||
# Check Dograh quota after the run exists so hosted v2 can mint and store
|
||||
# the MPS correlation id before the provider starts the call.
|
||||
quota_result = await authorize_workflow_run_start(
|
||||
workflow_id=target.workflow.id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
)
|
||||
if not quota_result.has_quota:
|
||||
raise HTTPException(status_code=402, detail=quota_result.error_message)
|
||||
|
||||
# 9. Construct webhook URL for telephony provider callback
|
||||
backend_endpoint, _ = await get_backend_endpoints()
|
||||
webhook_endpoint = provider.WEBHOOK_ENDPOINT
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ from api.enums import CallType, WorkflowRunState
|
|||
from api.errors.telephony_errors import TelephonyError
|
||||
from api.sdk_expose import sdk_expose
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.quota_service import check_dograh_quota_by_user_id
|
||||
from api.services.quota_service import authorize_workflow_run_start
|
||||
from api.services.telephony.call_transfer_manager import get_call_transfer_manager
|
||||
from api.services.telephony.factory import (
|
||||
get_all_telephony_providers,
|
||||
|
|
@ -136,14 +136,6 @@ async def initiate_call(
|
|||
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||
execution_user_id = _get_execution_user_id(workflow)
|
||||
|
||||
# Check Dograh quota before initiating the call (apply per-workflow
|
||||
# model_overrides so the keys we will actually use are the ones checked).
|
||||
quota_result = await check_dograh_quota_by_user_id(
|
||||
execution_user_id, workflow_id=workflow.id
|
||||
)
|
||||
if not quota_result.has_quota:
|
||||
raise HTTPException(status_code=402, detail=quota_result.error_message)
|
||||
|
||||
# Determine the workflow run mode based on provider type
|
||||
workflow_run_mode = provider.PROVIDER_NAME
|
||||
|
||||
|
|
@ -186,6 +178,16 @@ async def initiate_call(
|
|||
)
|
||||
workflow_run_name = workflow_run.name
|
||||
|
||||
# Check Dograh quota after the run exists so hosted v2 can mint and store
|
||||
# the MPS correlation id before initiating the call.
|
||||
quota_result = await authorize_workflow_run_start(
|
||||
workflow_id=workflow.id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
actor_user=user,
|
||||
)
|
||||
if not quota_result.has_quota:
|
||||
raise HTTPException(status_code=402, detail=quota_result.error_message)
|
||||
|
||||
# Construct webhook URL based on provider type
|
||||
backend_endpoint, _ = await get_backend_endpoints()
|
||||
|
||||
|
|
@ -739,19 +741,8 @@ async def handle_inbound_run(request: Request):
|
|||
TelephonyError.SIGNATURE_VALIDATION_FAILED
|
||||
)
|
||||
|
||||
# 4. Quota check (use the workflow's model_overrides if set).
|
||||
quota_result = await check_dograh_quota_by_user_id(
|
||||
user_id, workflow_id=workflow_id
|
||||
)
|
||||
if not quota_result.has_quota:
|
||||
logger.warning(
|
||||
f"User {user_id} has exceeded quota: {quota_result.error_message}"
|
||||
)
|
||||
return provider_class.generate_validation_error_response(
|
||||
TelephonyError.QUOTA_EXCEEDED
|
||||
)
|
||||
|
||||
# 5. Create workflow run + return provider-shaped response.
|
||||
# 5. Create workflow run + authorize quota before returning provider
|
||||
# stream instructions.
|
||||
workflow_run_id = await _create_inbound_workflow_run(
|
||||
workflow_id,
|
||||
user_id,
|
||||
|
|
@ -760,6 +751,17 @@ async def handle_inbound_run(request: Request):
|
|||
telephony_configuration_id=telephony_configuration_id,
|
||||
from_phone_number_id=phone_row.id,
|
||||
)
|
||||
quota_result = await authorize_workflow_run_start(
|
||||
workflow_id=workflow_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
)
|
||||
if not quota_result.has_quota:
|
||||
logger.warning(
|
||||
f"User {user_id} has exceeded quota: {quota_result.error_message}"
|
||||
)
|
||||
return provider_class.generate_validation_error_response(
|
||||
TelephonyError.QUOTA_EXCEEDED
|
||||
)
|
||||
|
||||
backend_endpoint, wss_backend_endpoint = await get_backend_endpoints()
|
||||
websocket_url = (
|
||||
|
|
@ -874,20 +876,8 @@ async def handle_inbound_telephony(
|
|||
logger.error(f"Request validation failed: {error_type}")
|
||||
return provider_class.generate_validation_error_response(error_type)
|
||||
|
||||
# Check quota before processing (apply per-workflow model_overrides).
|
||||
# Create workflow run.
|
||||
user_id = workflow_context["user_id"]
|
||||
quota_result = await check_dograh_quota_by_user_id(
|
||||
user_id, workflow_id=workflow_id
|
||||
)
|
||||
if not quota_result.has_quota:
|
||||
logger.warning(
|
||||
f"User {user_id} has exceeded quota for inbound calls: {quota_result.error_message}"
|
||||
)
|
||||
return provider_class.generate_validation_error_response(
|
||||
TelephonyError.QUOTA_EXCEEDED
|
||||
)
|
||||
|
||||
# Create workflow run
|
||||
workflow_run_id = await _create_inbound_workflow_run(
|
||||
workflow_id,
|
||||
workflow_context["user_id"],
|
||||
|
|
@ -896,6 +886,17 @@ async def handle_inbound_telephony(
|
|||
telephony_configuration_id=workflow_context["telephony_configuration_id"],
|
||||
from_phone_number_id=workflow_context.get("from_phone_number_id"),
|
||||
)
|
||||
quota_result = await authorize_workflow_run_start(
|
||||
workflow_id=workflow_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
)
|
||||
if not quota_result.has_quota:
|
||||
logger.warning(
|
||||
f"User {user_id} has exceeded quota for inbound calls: {quota_result.error_message}"
|
||||
)
|
||||
return provider_class.generate_validation_error_response(
|
||||
TelephonyError.QUOTA_EXCEEDED
|
||||
)
|
||||
|
||||
# Generate response URLs
|
||||
backend_endpoint, wss_backend_endpoint = await get_backend_endpoints()
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ from api.services.pipecat.ws_sender_registry import (
|
|||
register_ws_sender,
|
||||
unregister_ws_sender,
|
||||
)
|
||||
from api.services.quota_service import check_dograh_quota
|
||||
from api.services.quota_service import authorize_workflow_run_start
|
||||
|
||||
router = APIRouter(prefix="/ws")
|
||||
|
||||
|
|
@ -329,7 +329,11 @@ class SignalingManager:
|
|||
|
||||
# Check Dograh quota before initiating the call (apply per-workflow
|
||||
# model_overrides so we evaluate the keys this workflow will use).
|
||||
quota_result = await check_dograh_quota(user, workflow_id=workflow_id)
|
||||
quota_result = await authorize_workflow_run_start(
|
||||
workflow_id=workflow_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
actor_user=user,
|
||||
)
|
||||
if not quota_result.has_quota:
|
||||
# Send error response for quota issues
|
||||
await ws.send_json(
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from api.db import db_client
|
|||
from api.db.models import UserModel, WorkflowRunTextSessionModel
|
||||
from api.enums import WorkflowRunMode
|
||||
from api.services.auth.depends import get_user_with_selected_organization
|
||||
from api.services.quota_service import check_dograh_quota
|
||||
from api.services.quota_service import authorize_workflow_run_start
|
||||
from api.services.workflow.text_chat_session_service import (
|
||||
TextChatPendingTurnLostError,
|
||||
TextChatSessionExecutionError,
|
||||
|
|
@ -96,8 +96,16 @@ def _revision_conflict_detail(e: Any) -> dict[str, Any]:
|
|||
}
|
||||
|
||||
|
||||
async def _ensure_text_chat_quota(user: UserModel, workflow_id: int) -> None:
|
||||
quota_result = await check_dograh_quota(user, workflow_id=workflow_id)
|
||||
async def _ensure_text_chat_quota(
|
||||
user: UserModel,
|
||||
workflow_id: int,
|
||||
workflow_run_id: int,
|
||||
) -> None:
|
||||
quota_result = await authorize_workflow_run_start(
|
||||
workflow_id=workflow_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
actor_user=user,
|
||||
)
|
||||
if not quota_result.has_quota:
|
||||
raise HTTPException(status_code=402, detail=quota_result.error_message)
|
||||
|
||||
|
|
@ -153,8 +161,6 @@ async def create_text_chat_session(
|
|||
request: CreateTextChatSessionRequest,
|
||||
user: UserModel = Depends(get_user_with_selected_organization),
|
||||
) -> WorkflowRunTextSessionResponse:
|
||||
await _ensure_text_chat_quota(user, workflow_id)
|
||||
|
||||
session_name = request.name or f"WR-TEXT-{uuid4().hex[:6].upper()}"
|
||||
try:
|
||||
workflow_run = await db_client.create_workflow_run(
|
||||
|
|
@ -170,6 +176,7 @@ async def create_text_chat_session(
|
|||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
set_current_run_id(workflow_run.id)
|
||||
await _ensure_text_chat_quota(user, workflow_id, workflow_run.id)
|
||||
|
||||
annotations = {
|
||||
"tester": {
|
||||
|
|
@ -229,7 +236,7 @@ async def append_text_chat_message(
|
|||
user: UserModel = Depends(get_user_with_selected_organization),
|
||||
) -> WorkflowRunTextSessionResponse:
|
||||
text_session = await _load_text_session_or_404(workflow_id, run_id, user)
|
||||
await _ensure_text_chat_quota(user, workflow_id)
|
||||
await _ensure_text_chat_quota(user, workflow_id, run_id)
|
||||
|
||||
try:
|
||||
text_session = await append_text_chat_user_message(
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from api.services.campaign.errors import (
|
|||
PhoneNumberPoolExhaustedError,
|
||||
)
|
||||
from api.services.campaign.rate_limiter import rate_limiter
|
||||
from api.services.quota_service import authorize_workflow_run_start
|
||||
from api.utils.common import get_backend_endpoints
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -339,6 +340,41 @@ class CampaignCallDispatcher:
|
|||
},
|
||||
)
|
||||
|
||||
quota_result = await authorize_workflow_run_start(
|
||||
workflow_id=campaign.workflow_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
)
|
||||
if not quota_result.has_quota:
|
||||
error_message = quota_result.error_message or "Quota exceeded"
|
||||
logger.warning(
|
||||
f"Campaign {campaign.id} quota check failed for workflow run "
|
||||
f"{workflow_run.id}: {error_message}"
|
||||
)
|
||||
await db_client.update_workflow_run(
|
||||
run_id=workflow_run.id,
|
||||
is_completed=True,
|
||||
state=WorkflowRunState.COMPLETED.value,
|
||||
gathered_context={"error": error_message},
|
||||
)
|
||||
|
||||
mapping = await rate_limiter.get_workflow_slot_mapping(workflow_run.id)
|
||||
if mapping:
|
||||
org_id, mapped_slot_id = mapping
|
||||
await rate_limiter.release_concurrent_slot(org_id, mapped_slot_id)
|
||||
await rate_limiter.delete_workflow_slot_mapping(workflow_run.id)
|
||||
|
||||
from_number_mapping = await rate_limiter.get_workflow_from_number_mapping(
|
||||
workflow_run.id
|
||||
)
|
||||
if from_number_mapping:
|
||||
fn_org_id, fn_number, fn_tcid = from_number_mapping
|
||||
await rate_limiter.release_from_number(
|
||||
fn_org_id, fn_number, telephony_configuration_id=fn_tcid
|
||||
)
|
||||
await rate_limiter.delete_workflow_from_number_mapping(workflow_run.id)
|
||||
|
||||
raise ValueError(error_message)
|
||||
|
||||
# Initiate call via telephony provider
|
||||
try:
|
||||
# Construct webhook URL with parameters
|
||||
|
|
|
|||
|
|
@ -2,11 +2,8 @@ from __future__ import annotations
|
|||
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from api.services.mps_service_key_client import mps_service_key_client
|
||||
|
||||
MPS_CORRELATION_ID_CONTEXT_KEY = "mps_correlation_id"
|
||||
|
||||
|
|
@ -48,27 +45,10 @@ async def ensure_mps_correlation_id(
|
|||
if not uses_managed_model_services_v2(ai_model_config):
|
||||
return None
|
||||
|
||||
service_key = _get_dograh_service_api_key(ai_model_config)
|
||||
if not service_key:
|
||||
raise ValueError(
|
||||
"Managed model services v2 requires a Dograh service key before the run starts."
|
||||
)
|
||||
|
||||
response = await mps_service_key_client.create_correlation_id(
|
||||
service_key=service_key,
|
||||
workflow_run_id=workflow_run_id,
|
||||
raise ValueError(
|
||||
"Managed model services v2 requires workflow run authorization before "
|
||||
f"the run starts. Missing correlation id for workflow_run_id={workflow_run_id}."
|
||||
)
|
||||
correlation_id = response.get("correlation_id")
|
||||
if not correlation_id:
|
||||
raise ValueError("MPS correlation-id response did not include correlation_id")
|
||||
|
||||
correlation_id = str(correlation_id)
|
||||
logger.info(
|
||||
"Minted MPS correlation id {} for workflow run {}",
|
||||
correlation_id,
|
||||
workflow_run_id,
|
||||
)
|
||||
return correlation_id
|
||||
|
||||
|
||||
def _is_dograh_service(service: Any) -> bool:
|
||||
|
|
@ -78,7 +58,7 @@ def _is_dograh_service(service: Any) -> bool:
|
|||
)
|
||||
|
||||
|
||||
def _get_dograh_service_api_key(
|
||||
def get_dograh_service_api_key(
|
||||
ai_model_config: EffectiveAIModelConfiguration,
|
||||
) -> str | None:
|
||||
for section_name in ("llm", "tts", "stt", "embeddings"):
|
||||
|
|
|
|||
|
|
@ -478,6 +478,50 @@ class MPSServiceKeyClient:
|
|||
response=response,
|
||||
)
|
||||
|
||||
async def authorize_workflow_run_start(
|
||||
self,
|
||||
*,
|
||||
organization_id: int,
|
||||
workflow_run_id: int | None = None,
|
||||
service_key: Optional[str] = None,
|
||||
require_correlation_id: bool = False,
|
||||
minimum_credits: float | None = None,
|
||||
metadata: Optional[dict] = None,
|
||||
created_by: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""Authorize a hosted workflow run and optionally mint its MPS correlation."""
|
||||
payload = {
|
||||
"workflow_run_id": workflow_run_id,
|
||||
"service_key": service_key,
|
||||
"require_correlation_id": require_correlation_id,
|
||||
"metadata": metadata or {},
|
||||
}
|
||||
if minimum_credits is not None:
|
||||
payload["minimum_credits"] = minimum_credits
|
||||
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
response = await client.post(
|
||||
f"{self.base_url}/api/v1/billing/accounts/{organization_id}/run-authorization",
|
||||
json=payload,
|
||||
headers=self._get_headers(
|
||||
organization_id=organization_id,
|
||||
created_by=created_by,
|
||||
),
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
|
||||
logger.error(
|
||||
"Failed to authorize MPS workflow run start: "
|
||||
f"{response.status_code} - {response.text}"
|
||||
)
|
||||
raise httpx.HTTPStatusError(
|
||||
f"Failed to authorize MPS workflow run start: {response.text}",
|
||||
request=response.request,
|
||||
response=response,
|
||||
)
|
||||
|
||||
async def create_correlation_id(
|
||||
self,
|
||||
*,
|
||||
|
|
|
|||
|
|
@ -5,17 +5,38 @@ across different endpoints (WebRTC signaling, telephony, public API triggers).
|
|||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from api.constants import DEPLOYMENT_MODE
|
||||
from api.db import db_client
|
||||
from api.db.models import UserModel
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
get_effective_ai_model_configuration_for_workflow,
|
||||
)
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from api.services.managed_model_services import (
|
||||
MPS_CORRELATION_ID_CONTEXT_KEY,
|
||||
get_dograh_service_api_key,
|
||||
uses_managed_model_services_v2,
|
||||
)
|
||||
from api.services.mps_service_key_client import mps_service_key_client
|
||||
|
||||
MINIMUM_DOGRAH_CREDITS_FOR_CALL = 0.10
|
||||
|
||||
LEGACY_QUOTA_EXCEEDED_MESSAGE = (
|
||||
"You have exhausted your trial credits. "
|
||||
"Please email founders@dograh.com for additional Dograh credits "
|
||||
"or change providers in Models configurations."
|
||||
)
|
||||
|
||||
BILLING_V2_QUOTA_EXCEEDED_MESSAGE = (
|
||||
"You have exhausted your Dograh credits. "
|
||||
"Please purchase more credits from /billing "
|
||||
"or change providers in Models configurations."
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuotaCheckResult:
|
||||
|
|
@ -26,116 +47,359 @@ class QuotaCheckResult:
|
|||
error_code: str = ""
|
||||
|
||||
|
||||
async def check_dograh_quota(
|
||||
user: UserModel, workflow_id: int | None = None
|
||||
) -> QuotaCheckResult:
|
||||
"""Check if user has sufficient Dograh quota for making a call.
|
||||
|
||||
This function checks if the user is using any Dograh services (LLM, STT, TTS)
|
||||
and validates that they have sufficient credits remaining.
|
||||
|
||||
When ``workflow_id`` is provided, the workflow's per-workflow
|
||||
``model_overrides`` are merged onto the user's global config so the quota
|
||||
check runs against the credentials that will actually be used for the call
|
||||
(rather than always falling back to the user's defaults).
|
||||
|
||||
Args:
|
||||
user: The user to check quota for
|
||||
workflow_id: Optional workflow whose ``model_overrides`` should be
|
||||
applied when resolving the effective service config.
|
||||
|
||||
Returns:
|
||||
QuotaCheckResult with has_quota=True if user has sufficient quota or
|
||||
is not using Dograh services, or has_quota=False with error_message
|
||||
if quota is insufficient.
|
||||
"""
|
||||
def _safe_float(value: Any, default: float = 0.0) -> float:
|
||||
try:
|
||||
organization_id = user.selected_organization_id
|
||||
workflow_configurations = None
|
||||
return float(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
if workflow_id is not None:
|
||||
workflow = await db_client.get_workflow_by_id(workflow_id)
|
||||
if workflow:
|
||||
organization_id = workflow.organization_id
|
||||
workflow_configurations = workflow.workflow_configurations
|
||||
|
||||
user_config = await get_effective_ai_model_configuration_for_workflow(
|
||||
user_id=user.id,
|
||||
organization_id=organization_id,
|
||||
workflow_configurations=workflow_configurations,
|
||||
def _insufficient_billing_v2_quota_result() -> QuotaCheckResult:
|
||||
return QuotaCheckResult(
|
||||
has_quota=False,
|
||||
error_code="insufficient_credits",
|
||||
error_message=BILLING_V2_QUOTA_EXCEEDED_MESSAGE,
|
||||
)
|
||||
|
||||
|
||||
def _insufficient_legacy_quota_result() -> QuotaCheckResult:
|
||||
return QuotaCheckResult(
|
||||
has_quota=False,
|
||||
error_code="quota_exceeded",
|
||||
error_message=LEGACY_QUOTA_EXCEEDED_MESSAGE,
|
||||
)
|
||||
|
||||
|
||||
def _service_uses_dograh(service: Any) -> bool:
|
||||
provider = getattr(service, "provider", None)
|
||||
return (
|
||||
provider == ServiceProviders.DOGRAH or provider == ServiceProviders.DOGRAH.value
|
||||
)
|
||||
|
||||
|
||||
def _dograh_api_keys(user_config: Any) -> set[str]:
|
||||
api_keys: set[str] = set()
|
||||
for section_name in ("llm", "stt", "tts", "embeddings"):
|
||||
service = getattr(user_config, section_name, None)
|
||||
if not _service_uses_dograh(service):
|
||||
continue
|
||||
if hasattr(service, "get_all_api_keys"):
|
||||
all_api_keys = [
|
||||
api_key
|
||||
for api_key in service.get_all_api_keys()
|
||||
if isinstance(api_key, str) and api_key
|
||||
]
|
||||
if all_api_keys:
|
||||
api_keys.update(all_api_keys)
|
||||
continue
|
||||
api_key = getattr(service, "api_key", None)
|
||||
if api_key:
|
||||
api_keys.add(api_key)
|
||||
return api_keys
|
||||
|
||||
|
||||
async def _store_run_correlation_id(
|
||||
workflow_run_id: int | None,
|
||||
correlation_id: str | None,
|
||||
) -> None:
|
||||
if not workflow_run_id or not correlation_id:
|
||||
return
|
||||
|
||||
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
|
||||
if not workflow_run:
|
||||
logger.warning(
|
||||
"Could not store MPS correlation id for missing workflow run {}",
|
||||
workflow_run_id,
|
||||
)
|
||||
return
|
||||
|
||||
initial_context = dict(workflow_run.initial_context or {})
|
||||
if initial_context.get(MPS_CORRELATION_ID_CONTEXT_KEY) == correlation_id:
|
||||
return
|
||||
|
||||
initial_context[MPS_CORRELATION_ID_CONTEXT_KEY] = correlation_id
|
||||
await db_client.update_workflow_run(
|
||||
workflow_run_id,
|
||||
initial_context=initial_context,
|
||||
)
|
||||
|
||||
|
||||
async def _authorize_hosted_workflow_run_start(
|
||||
*,
|
||||
workflow_owner: UserModel,
|
||||
organization_id: int | None,
|
||||
workflow_id: int | None,
|
||||
workflow_run_id: int | None,
|
||||
user_config: Any,
|
||||
) -> tuple[QuotaCheckResult, bool]:
|
||||
"""Authorize hosted v2 billing and return whether MPS handled enforcement."""
|
||||
if DEPLOYMENT_MODE == "oss" or organization_id is None:
|
||||
return QuotaCheckResult(has_quota=True), False
|
||||
|
||||
requires_correlation = bool(
|
||||
workflow_run_id and uses_managed_model_services_v2(user_config)
|
||||
)
|
||||
service_key = (
|
||||
get_dograh_service_api_key(user_config) if requires_correlation else None
|
||||
)
|
||||
if requires_correlation and not service_key:
|
||||
return (
|
||||
QuotaCheckResult(
|
||||
has_quota=False,
|
||||
error_code="invalid_service_key",
|
||||
error_message=(
|
||||
"You have invalid keys in your model configuration. "
|
||||
"Please validate the service keys."
|
||||
),
|
||||
),
|
||||
True,
|
||||
)
|
||||
|
||||
# Check if user is using any Dograh service
|
||||
using_dograh = False
|
||||
dograh_api_keys = set()
|
||||
try:
|
||||
authorization = await mps_service_key_client.authorize_workflow_run_start(
|
||||
organization_id=organization_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
service_key=service_key,
|
||||
require_correlation_id=requires_correlation,
|
||||
minimum_credits=MINIMUM_DOGRAH_CREDITS_FOR_CALL,
|
||||
created_by=(
|
||||
str(workflow_owner.provider_id)
|
||||
if workflow_owner.provider_id is not None
|
||||
else None
|
||||
),
|
||||
metadata={
|
||||
"dograh_user_id": str(workflow_owner.id),
|
||||
"workflow_id": workflow_id,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to authorize workflow start with MPS for org {}: {}",
|
||||
organization_id,
|
||||
e,
|
||||
)
|
||||
return (
|
||||
QuotaCheckResult(
|
||||
has_quota=False,
|
||||
error_code="quota_check_failed",
|
||||
error_message="Could not verify Dograh credits. Please try again.",
|
||||
),
|
||||
True,
|
||||
)
|
||||
|
||||
if user_config.llm and user_config.llm.provider == ServiceProviders.DOGRAH:
|
||||
using_dograh = True
|
||||
dograh_api_keys.add(user_config.llm.api_key)
|
||||
billing_mode = authorization.get("billing_mode")
|
||||
if billing_mode != "v2":
|
||||
return QuotaCheckResult(has_quota=True), False
|
||||
|
||||
if user_config.stt and user_config.stt.provider == ServiceProviders.DOGRAH:
|
||||
using_dograh = True
|
||||
dograh_api_keys.add(user_config.stt.api_key)
|
||||
remaining = _safe_float(authorization.get("remaining_credits"))
|
||||
if (
|
||||
not authorization.get("allowed", False)
|
||||
or remaining < MINIMUM_DOGRAH_CREDITS_FOR_CALL
|
||||
):
|
||||
logger.warning(
|
||||
"Insufficient Dograh billing v2 credits for org {}: {:.2f} credits remaining",
|
||||
organization_id,
|
||||
remaining,
|
||||
)
|
||||
return _insufficient_billing_v2_quota_result(), True
|
||||
|
||||
if user_config.tts and user_config.tts.provider == ServiceProviders.DOGRAH:
|
||||
using_dograh = True
|
||||
dograh_api_keys.add(user_config.tts.api_key)
|
||||
try:
|
||||
await _store_run_correlation_id(
|
||||
workflow_run_id,
|
||||
authorization.get("correlation_id"),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to store MPS correlation id for workflow_run_id {}: {}",
|
||||
workflow_run_id,
|
||||
e,
|
||||
)
|
||||
return (
|
||||
QuotaCheckResult(
|
||||
has_quota=False,
|
||||
error_code="quota_check_failed",
|
||||
error_message="Could not verify Dograh credits. Please try again.",
|
||||
),
|
||||
True,
|
||||
)
|
||||
logger.info(
|
||||
"Dograh billing v2 run authorization passed for org {}: {:.2f} credits remaining",
|
||||
organization_id,
|
||||
remaining,
|
||||
)
|
||||
return QuotaCheckResult(has_quota=True), True
|
||||
|
||||
if (
|
||||
user_config.embeddings
|
||||
and user_config.embeddings.provider == ServiceProviders.DOGRAH
|
||||
):
|
||||
using_dograh = True
|
||||
dograh_api_keys.add(user_config.embeddings.api_key)
|
||||
|
||||
# If not using Dograh, quota check passes
|
||||
if not using_dograh:
|
||||
return QuotaCheckResult(has_quota=True)
|
||||
async def _authorize_legacy_dograh_keys(
|
||||
*,
|
||||
dograh_api_keys: set[str],
|
||||
organization_id: int | None,
|
||||
workflow_owner: UserModel,
|
||||
) -> QuotaCheckResult:
|
||||
for api_key in dograh_api_keys:
|
||||
try:
|
||||
usage = await mps_service_key_client.check_service_key_usage(
|
||||
api_key,
|
||||
organization_id=organization_id,
|
||||
created_by=workflow_owner.provider_id,
|
||||
)
|
||||
remaining = usage.get("remaining_credits", 0.0)
|
||||
|
||||
# Check quota for ALL Dograh keys
|
||||
for api_key in dograh_api_keys:
|
||||
try:
|
||||
usage = await mps_service_key_client.check_service_key_usage(
|
||||
api_key,
|
||||
organization_id=organization_id,
|
||||
created_by=user.provider_id,
|
||||
# Require at least $0.10 for a short call
|
||||
if remaining < MINIMUM_DOGRAH_CREDITS_FOR_CALL:
|
||||
logger.warning(
|
||||
f"Insufficient Dograh credits for key ...{api_key[-8:]}: "
|
||||
f"${remaining:.2f} remaining"
|
||||
)
|
||||
remaining = usage.get("remaining_credits", 0.0)
|
||||
return _insufficient_legacy_quota_result()
|
||||
|
||||
# Require at least $0.10 for a short call
|
||||
if remaining < 0.10:
|
||||
logger.warning(
|
||||
f"Insufficient Dograh credits for key ...{api_key[-8:]}: "
|
||||
f"${remaining:.2f} remaining"
|
||||
)
|
||||
return QuotaCheckResult(
|
||||
has_quota=False,
|
||||
error_code="quota_exceeded",
|
||||
error_message=(
|
||||
"You have exhausted your trial credits. "
|
||||
"Please email founders@dograh.com for additional Dograh credits "
|
||||
"or change providers in Models configurations."
|
||||
),
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Dograh quota check passed for key ...{api_key[-8:]}: "
|
||||
f"{remaining:.2f} credits remaining"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check quota for Dograh key: {str(e)}")
|
||||
error_str = str(e)
|
||||
if "404" in error_str or "not found" in error_str.lower():
|
||||
return QuotaCheckResult(
|
||||
has_quota=False,
|
||||
error_code="invalid_service_key",
|
||||
error_message="You have invalid keys in your model configuration. Please validate the service keys.",
|
||||
)
|
||||
logger.info(
|
||||
f"Dograh quota check passed for key ...{api_key[-8:]}: "
|
||||
f"{remaining:.2f} credits remaining"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check quota for Dograh key: {str(e)}")
|
||||
error_str = str(e)
|
||||
if "404" in error_str or "not found" in error_str.lower():
|
||||
return QuotaCheckResult(
|
||||
has_quota=False,
|
||||
error_code="quota_check_failed",
|
||||
error_message="Could not verify Dograh credits. Please try again.",
|
||||
error_code="invalid_service_key",
|
||||
error_message="You have invalid keys in your model configuration. Please validate the service keys.",
|
||||
)
|
||||
return QuotaCheckResult(
|
||||
has_quota=False,
|
||||
error_code="quota_check_failed",
|
||||
error_message="Could not verify Dograh credits. Please try again.",
|
||||
)
|
||||
|
||||
return QuotaCheckResult(has_quota=True)
|
||||
|
||||
|
||||
async def _authorize_oss_managed_v2_correlation(
|
||||
*,
|
||||
workflow_id: int,
|
||||
workflow_run_id: int | None,
|
||||
user_config: Any,
|
||||
) -> QuotaCheckResult:
|
||||
if not workflow_run_id or not uses_managed_model_services_v2(user_config):
|
||||
return QuotaCheckResult(has_quota=True)
|
||||
|
||||
service_key = get_dograh_service_api_key(user_config)
|
||||
if not service_key:
|
||||
return QuotaCheckResult(
|
||||
has_quota=False,
|
||||
error_code="invalid_service_key",
|
||||
error_message=(
|
||||
"You have invalid keys in your model configuration. "
|
||||
"Please validate the service keys."
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
response = await mps_service_key_client.create_correlation_id(
|
||||
service_key=service_key,
|
||||
workflow_run_id=workflow_run_id,
|
||||
)
|
||||
await _store_run_correlation_id(
|
||||
workflow_run_id,
|
||||
response.get("correlation_id"),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to authorize OSS managed v2 workflow start for workflow {} run {}: {}",
|
||||
workflow_id,
|
||||
workflow_run_id,
|
||||
e,
|
||||
)
|
||||
return QuotaCheckResult(
|
||||
has_quota=False,
|
||||
error_code="quota_check_failed",
|
||||
error_message="Could not verify Dograh credits. Please try again.",
|
||||
)
|
||||
|
||||
return QuotaCheckResult(has_quota=True)
|
||||
|
||||
|
||||
async def authorize_workflow_run_start(
|
||||
*,
|
||||
workflow_id: int,
|
||||
workflow_run_id: int | None = None,
|
||||
actor_user: UserModel | None = None,
|
||||
) -> QuotaCheckResult:
|
||||
"""Authorize a workflow run before any billable call/text runtime starts.
|
||||
|
||||
The workflow organization is the billing subject for hosted v2. The workflow
|
||||
owner is used only to resolve the effective model configuration and legacy
|
||||
service-key metadata.
|
||||
"""
|
||||
try:
|
||||
workflow = await db_client.get_workflow_by_id(workflow_id)
|
||||
if not workflow:
|
||||
return QuotaCheckResult(
|
||||
has_quota=False,
|
||||
error_code="workflow_not_found",
|
||||
error_message="Workflow not found",
|
||||
)
|
||||
|
||||
actor_org_id = getattr(actor_user, "selected_organization_id", None)
|
||||
if actor_org_id is not None and actor_org_id != workflow.organization_id:
|
||||
logger.warning(
|
||||
"Workflow start authorization denied: actor org {} does not match workflow {} org {}",
|
||||
actor_org_id,
|
||||
workflow_id,
|
||||
workflow.organization_id,
|
||||
)
|
||||
return QuotaCheckResult(
|
||||
has_quota=False,
|
||||
error_code="workflow_not_found",
|
||||
error_message="Workflow not found",
|
||||
)
|
||||
|
||||
workflow_owner = await db_client.get_user_by_id(workflow.user_id)
|
||||
if not workflow_owner:
|
||||
return QuotaCheckResult(
|
||||
has_quota=False,
|
||||
error_code="user_not_found",
|
||||
error_message="User not found",
|
||||
)
|
||||
|
||||
user_config = await get_effective_ai_model_configuration_for_workflow(
|
||||
user_id=workflow_owner.id,
|
||||
organization_id=workflow.organization_id,
|
||||
workflow_configurations=workflow.workflow_configurations,
|
||||
)
|
||||
|
||||
if DEPLOYMENT_MODE != "oss":
|
||||
hosted_result, hosted_enforced = await _authorize_hosted_workflow_run_start(
|
||||
workflow_owner=workflow_owner,
|
||||
organization_id=workflow.organization_id,
|
||||
workflow_id=workflow.id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
user_config=user_config,
|
||||
)
|
||||
if hosted_enforced or not hosted_result.has_quota:
|
||||
return hosted_result
|
||||
|
||||
dograh_api_keys = _dograh_api_keys(user_config)
|
||||
if not dograh_api_keys:
|
||||
return QuotaCheckResult(has_quota=True)
|
||||
|
||||
legacy_result = await _authorize_legacy_dograh_keys(
|
||||
dograh_api_keys=dograh_api_keys,
|
||||
organization_id=(
|
||||
None if DEPLOYMENT_MODE == "oss" else workflow.organization_id
|
||||
),
|
||||
workflow_owner=workflow_owner,
|
||||
)
|
||||
if not legacy_result.has_quota:
|
||||
return legacy_result
|
||||
|
||||
if DEPLOYMENT_MODE == "oss":
|
||||
return await _authorize_oss_managed_v2_correlation(
|
||||
workflow_id=workflow.id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
user_config=user_config,
|
||||
)
|
||||
|
||||
return QuotaCheckResult(has_quota=True)
|
||||
|
||||
|
|
@ -143,30 +407,3 @@ async def check_dograh_quota(
|
|||
logger.error(f"Error during quota check: {str(e)}")
|
||||
# On unexpected error, allow the call to proceed
|
||||
return QuotaCheckResult(has_quota=True)
|
||||
|
||||
|
||||
async def check_dograh_quota_by_user_id(
|
||||
user_id: int, workflow_id: int | None = None
|
||||
) -> QuotaCheckResult:
|
||||
"""Check Dograh quota by user ID.
|
||||
|
||||
Convenience function that fetches the user and then checks quota. When
|
||||
``workflow_id`` is provided, the workflow's ``model_overrides`` are
|
||||
applied so the quota check evaluates the credentials that will actually
|
||||
be used for the call.
|
||||
|
||||
Args:
|
||||
user_id: The ID of the user to check quota for
|
||||
workflow_id: Optional workflow whose per-workflow overrides should
|
||||
be applied to the user's config before checking quota.
|
||||
|
||||
Returns:
|
||||
QuotaCheckResult with quota status
|
||||
"""
|
||||
user = await db_client.get_user_by_id(user_id)
|
||||
if not user:
|
||||
return QuotaCheckResult(
|
||||
has_quota=False,
|
||||
error_message="User not found",
|
||||
)
|
||||
return await check_dograh_quota(user, workflow_id=workflow_id)
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ from loguru import logger
|
|||
from api.constants import REDIS_URL
|
||||
from api.db import db_client
|
||||
from api.enums import CallType, WorkflowRunMode
|
||||
from api.services.quota_service import check_dograh_quota_by_user_id
|
||||
from api.services.quota_service import authorize_workflow_run_start
|
||||
from api.services.telephony.call_transfer_manager import get_call_transfer_manager
|
||||
from api.services.telephony.transfer_event_protocol import (
|
||||
TransferEvent,
|
||||
|
|
@ -564,19 +564,7 @@ class ARIConnection:
|
|||
|
||||
user_id = workflow.user_id
|
||||
|
||||
# 3. Check quota (apply per-workflow model_overrides).
|
||||
quota_result = await check_dograh_quota_by_user_id(
|
||||
user_id, workflow_id=inbound_workflow_id
|
||||
)
|
||||
if not quota_result.has_quota:
|
||||
logger.warning(
|
||||
f"[ARI org={self.organization_id}] Quota exceeded for user {user_id} "
|
||||
f"— hanging up inbound call {channel_id}"
|
||||
)
|
||||
await self._delete_channel(channel_id)
|
||||
return
|
||||
|
||||
# 4. Create workflow run
|
||||
# 3. Create workflow run
|
||||
call_id = channel_id
|
||||
workflow_run = await db_client.create_workflow_run(
|
||||
name=f"ARI Inbound {caller_number}",
|
||||
|
|
@ -602,6 +590,20 @@ class ARIConnection:
|
|||
f"(caller={caller_number}, called={called_number})"
|
||||
)
|
||||
|
||||
# 4. Check quota after the run exists so hosted v2 can mint and
|
||||
# store the MPS correlation id before the pipeline starts.
|
||||
quota_result = await authorize_workflow_run_start(
|
||||
workflow_id=inbound_workflow_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
)
|
||||
if not quota_result.has_quota:
|
||||
logger.warning(
|
||||
f"[ARI org={self.organization_id}] Quota exceeded for user {user_id} "
|
||||
f"— hanging up inbound call {channel_id}"
|
||||
)
|
||||
await self._delete_channel(channel_id)
|
||||
return
|
||||
|
||||
# 5. Answer the inbound channel
|
||||
await self._answer_channel(channel_id)
|
||||
|
||||
|
|
|
|||
|
|
@ -33,7 +33,9 @@ def test_create_cartesia_tts_service_passes_selected_model():
|
|||
transport_in_sample_rate=16000,
|
||||
)
|
||||
|
||||
with patch("api.services.pipecat.service_factory.CartesiaTTSService") as mock_service:
|
||||
with patch(
|
||||
"api.services.pipecat.service_factory.CartesiaTTSService"
|
||||
) as mock_service:
|
||||
create_tts_service(user_config, audio_config)
|
||||
|
||||
assert mock_service.call_count == 1
|
||||
|
|
|
|||
|
|
@ -270,6 +270,12 @@ class TestDispatcherThreadsTelephonyConfig:
|
|||
"api.services.campaign.campaign_call_dispatcher.get_backend_endpoints",
|
||||
AsyncMock(return_value=("https://example.com", None)),
|
||||
),
|
||||
patch(
|
||||
"api.services.campaign.campaign_call_dispatcher.authorize_workflow_run_start",
|
||||
AsyncMock(
|
||||
return_value=SimpleNamespace(has_quota=True, error_message="")
|
||||
),
|
||||
),
|
||||
):
|
||||
mock_db.get_workflow_by_id = AsyncMock(return_value=SimpleNamespace(id=1))
|
||||
mock_db.create_workflow_run = AsyncMock(return_value=workflow_run)
|
||||
|
|
|
|||
|
|
@ -175,6 +175,76 @@ async def test_get_billing_account_status_uses_hosted_org_auth(monkeypatch):
|
|||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authorize_workflow_run_start_uses_hosted_org_auth(monkeypatch):
|
||||
calls = []
|
||||
|
||||
class FakeAsyncClient:
|
||||
def __init__(self, timeout):
|
||||
self.timeout = timeout
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return None
|
||||
|
||||
async def post(self, url, json, headers):
|
||||
calls.append(("POST", url, json, headers))
|
||||
return _Response(
|
||||
200,
|
||||
{
|
||||
"allowed": True,
|
||||
"billing_mode": "v2",
|
||||
"remaining_credits": "25.0000",
|
||||
"correlation_id": "mps-corr-123",
|
||||
},
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"api.services.mps_service_key_client.httpx.AsyncClient", FakeAsyncClient
|
||||
)
|
||||
monkeypatch.setattr("api.services.mps_service_key_client.DEPLOYMENT_MODE", "saas")
|
||||
monkeypatch.setattr(
|
||||
"api.services.mps_service_key_client.DOGRAH_MPS_SECRET_KEY", "mps-secret"
|
||||
)
|
||||
|
||||
client = MPSServiceKeyClient()
|
||||
|
||||
assert await client.authorize_workflow_run_start(
|
||||
organization_id=42,
|
||||
workflow_run_id=88,
|
||||
service_key="mps_sk_paid",
|
||||
require_correlation_id=True,
|
||||
minimum_credits=0.1,
|
||||
metadata={"workflow_id": 7},
|
||||
created_by="provider-123",
|
||||
) == {
|
||||
"allowed": True,
|
||||
"billing_mode": "v2",
|
||||
"remaining_credits": "25.0000",
|
||||
"correlation_id": "mps-corr-123",
|
||||
}
|
||||
assert calls == [
|
||||
(
|
||||
"POST",
|
||||
f"{client.base_url}/api/v1/billing/accounts/42/run-authorization",
|
||||
{
|
||||
"workflow_run_id": 88,
|
||||
"service_key": "mps_sk_paid",
|
||||
"require_correlation_id": True,
|
||||
"minimum_credits": 0.1,
|
||||
"metadata": {"workflow_id": 7},
|
||||
},
|
||||
{
|
||||
"Content-Type": "application/json",
|
||||
"X-Secret-Key": "mps-secret",
|
||||
"X-Organization-Id": "42",
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ensure_billing_account_v2_uses_balance_endpoint(monkeypatch):
|
||||
calls = []
|
||||
|
|
|
|||
|
|
@ -57,7 +57,7 @@ def test_trigger_route_executes_as_workflow_owner():
|
|||
with (
|
||||
patch("api.routes.public_agent.db_client") as mock_db,
|
||||
patch(
|
||||
"api.routes.public_agent.check_dograh_quota_by_user_id",
|
||||
"api.routes.public_agent.authorize_workflow_run_start",
|
||||
new=quota_mock,
|
||||
),
|
||||
patch(
|
||||
|
|
@ -92,7 +92,10 @@ def test_trigger_route_executes_as_workflow_owner():
|
|||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
quota_mock.assert_awaited_once_with(workflow.user_id, workflow_id=workflow.id)
|
||||
quota_mock.assert_awaited_once_with(
|
||||
workflow_id=workflow.id,
|
||||
workflow_run_id=501,
|
||||
)
|
||||
mock_db.get_workflow.assert_awaited_once_with(workflow.id, organization_id=11)
|
||||
|
||||
create_kwargs = mock_db.create_workflow_run.await_args.kwargs
|
||||
|
|
@ -124,7 +127,7 @@ def test_workflow_uuid_route_uses_scoped_lookup_and_shared_execution():
|
|||
with (
|
||||
patch("api.routes.public_agent.db_client") as mock_db,
|
||||
patch(
|
||||
"api.routes.public_agent.check_dograh_quota_by_user_id",
|
||||
"api.routes.public_agent.authorize_workflow_run_start",
|
||||
new=quota_mock,
|
||||
),
|
||||
patch(
|
||||
|
|
|
|||
369
api/tests/test_quota_service.py
Normal file
369
api/tests/test_quota_service.py
Normal file
|
|
@ -0,0 +1,369 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from api.services import quota_service
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from api.services.managed_model_services import MPS_CORRELATION_ID_CONTEXT_KEY
|
||||
|
||||
|
||||
def _dograh_config(
|
||||
api_key: str = "mps_sk_12345678",
|
||||
*,
|
||||
managed_service_version: int = 1,
|
||||
):
|
||||
return SimpleNamespace(
|
||||
managed_service_version=managed_service_version,
|
||||
llm=SimpleNamespace(provider=ServiceProviders.DOGRAH, api_key=api_key),
|
||||
stt=None,
|
||||
tts=None,
|
||||
embeddings=None,
|
||||
)
|
||||
|
||||
|
||||
def _byok_config():
|
||||
return SimpleNamespace(
|
||||
managed_service_version=2,
|
||||
llm=SimpleNamespace(provider="openai", api_key="sk-openai"),
|
||||
stt=None,
|
||||
tts=None,
|
||||
embeddings=None,
|
||||
)
|
||||
|
||||
|
||||
def _workflow():
|
||||
return SimpleNamespace(
|
||||
id=7,
|
||||
user_id=123,
|
||||
organization_id=42,
|
||||
workflow_configurations={"model_overrides": {}},
|
||||
)
|
||||
|
||||
|
||||
def _workflow_owner():
|
||||
return SimpleNamespace(
|
||||
id=123,
|
||||
provider_id="provider-123",
|
||||
)
|
||||
|
||||
|
||||
def _actor():
|
||||
return SimpleNamespace(
|
||||
id=456,
|
||||
provider_id="actor-456",
|
||||
selected_organization_id=42,
|
||||
)
|
||||
|
||||
|
||||
def _patch_workflow_context(monkeypatch, *, workflow=None, owner=None):
|
||||
monkeypatch.setattr(
|
||||
quota_service.db_client,
|
||||
"get_workflow_by_id",
|
||||
AsyncMock(return_value=workflow or _workflow()),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
quota_service.db_client,
|
||||
"get_user_by_id",
|
||||
AsyncMock(return_value=owner or _workflow_owner()),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authorize_workflow_run_uses_workflow_org_for_hosted_v2(
|
||||
monkeypatch,
|
||||
):
|
||||
get_config = AsyncMock(return_value=_dograh_config())
|
||||
authorize = AsyncMock(
|
||||
return_value={
|
||||
"allowed": True,
|
||||
"billing_mode": "v2",
|
||||
"remaining_credits": "25.0000",
|
||||
}
|
||||
)
|
||||
check_usage = AsyncMock()
|
||||
|
||||
monkeypatch.setattr(quota_service, "DEPLOYMENT_MODE", "saas")
|
||||
_patch_workflow_context(monkeypatch)
|
||||
monkeypatch.setattr(
|
||||
quota_service,
|
||||
"get_effective_ai_model_configuration_for_workflow",
|
||||
get_config,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
quota_service.mps_service_key_client,
|
||||
"authorize_workflow_run_start",
|
||||
authorize,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
quota_service.mps_service_key_client,
|
||||
"check_service_key_usage",
|
||||
check_usage,
|
||||
)
|
||||
|
||||
result = await quota_service.authorize_workflow_run_start(workflow_id=7)
|
||||
|
||||
assert result.has_quota is True
|
||||
get_config.assert_awaited_once_with(
|
||||
user_id=123,
|
||||
organization_id=42,
|
||||
workflow_configurations={"model_overrides": {}},
|
||||
)
|
||||
authorize.assert_awaited_once_with(
|
||||
organization_id=42,
|
||||
workflow_run_id=None,
|
||||
service_key=None,
|
||||
require_correlation_id=False,
|
||||
minimum_credits=quota_service.MINIMUM_DOGRAH_CREDITS_FOR_CALL,
|
||||
created_by="provider-123",
|
||||
metadata={"dograh_user_id": "123", "workflow_id": 7},
|
||||
)
|
||||
check_usage.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authorize_workflow_run_v2_insufficient_credits_prompts_billing(
|
||||
monkeypatch,
|
||||
):
|
||||
get_config = AsyncMock(return_value=_byok_config())
|
||||
authorize = AsyncMock(
|
||||
return_value={
|
||||
"allowed": False,
|
||||
"billing_mode": "v2",
|
||||
"remaining_credits": "0.0000",
|
||||
"error": "insufficient_credits",
|
||||
}
|
||||
)
|
||||
check_usage = AsyncMock()
|
||||
|
||||
monkeypatch.setattr(quota_service, "DEPLOYMENT_MODE", "saas")
|
||||
_patch_workflow_context(monkeypatch)
|
||||
monkeypatch.setattr(
|
||||
quota_service,
|
||||
"get_effective_ai_model_configuration_for_workflow",
|
||||
get_config,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
quota_service.mps_service_key_client,
|
||||
"authorize_workflow_run_start",
|
||||
authorize,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
quota_service.mps_service_key_client,
|
||||
"check_service_key_usage",
|
||||
check_usage,
|
||||
)
|
||||
|
||||
result = await quota_service.authorize_workflow_run_start(workflow_id=7)
|
||||
|
||||
assert result.has_quota is False
|
||||
assert result.error_code == "insufficient_credits"
|
||||
assert "/billing" in result.error_message
|
||||
assert "founders@dograh.com" not in result.error_message
|
||||
authorize.assert_awaited_once()
|
||||
check_usage.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authorize_workflow_run_v1_uses_legacy_key_usage(
|
||||
monkeypatch,
|
||||
):
|
||||
api_key = "mps_sk_12345678"
|
||||
get_config = AsyncMock(return_value=_dograh_config(api_key))
|
||||
authorize = AsyncMock(
|
||||
return_value={
|
||||
"allowed": True,
|
||||
"billing_mode": "v1",
|
||||
"remaining_credits": "0.0000",
|
||||
}
|
||||
)
|
||||
check_usage = AsyncMock(
|
||||
return_value={"total_credits_used": 500.0, "remaining_credits": 0.0}
|
||||
)
|
||||
|
||||
monkeypatch.setattr(quota_service, "DEPLOYMENT_MODE", "saas")
|
||||
_patch_workflow_context(monkeypatch)
|
||||
monkeypatch.setattr(
|
||||
quota_service,
|
||||
"get_effective_ai_model_configuration_for_workflow",
|
||||
get_config,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
quota_service.mps_service_key_client,
|
||||
"authorize_workflow_run_start",
|
||||
authorize,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
quota_service.mps_service_key_client,
|
||||
"check_service_key_usage",
|
||||
check_usage,
|
||||
)
|
||||
|
||||
result = await quota_service.authorize_workflow_run_start(workflow_id=7)
|
||||
|
||||
assert result.has_quota is False
|
||||
assert result.error_code == "quota_exceeded"
|
||||
assert "founders@dograh.com" in result.error_message
|
||||
assert "/billing" not in result.error_message
|
||||
authorize.assert_awaited_once()
|
||||
check_usage.assert_awaited_once_with(
|
||||
api_key,
|
||||
organization_id=42,
|
||||
created_by="provider-123",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authorize_workflow_run_managed_v2_stores_hosted_correlation(
|
||||
monkeypatch,
|
||||
):
|
||||
api_key = "mps_sk_12345678"
|
||||
workflow_run = SimpleNamespace(initial_context={"existing": "value"})
|
||||
get_config = AsyncMock(
|
||||
return_value=_dograh_config(api_key, managed_service_version=2)
|
||||
)
|
||||
authorize = AsyncMock(
|
||||
return_value={
|
||||
"allowed": True,
|
||||
"billing_mode": "v2",
|
||||
"remaining_credits": "25.0000",
|
||||
"correlation_id": "mps-corr-123",
|
||||
}
|
||||
)
|
||||
update_workflow_run = AsyncMock()
|
||||
|
||||
monkeypatch.setattr(quota_service, "DEPLOYMENT_MODE", "saas")
|
||||
_patch_workflow_context(monkeypatch)
|
||||
monkeypatch.setattr(
|
||||
quota_service.db_client,
|
||||
"get_workflow_run_by_id",
|
||||
AsyncMock(return_value=workflow_run),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
quota_service.db_client,
|
||||
"update_workflow_run",
|
||||
update_workflow_run,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
quota_service,
|
||||
"get_effective_ai_model_configuration_for_workflow",
|
||||
get_config,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
quota_service.mps_service_key_client,
|
||||
"authorize_workflow_run_start",
|
||||
authorize,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
quota_service.mps_service_key_client,
|
||||
"check_service_key_usage",
|
||||
AsyncMock(),
|
||||
)
|
||||
|
||||
result = await quota_service.authorize_workflow_run_start(
|
||||
workflow_id=7,
|
||||
workflow_run_id=88,
|
||||
)
|
||||
|
||||
assert result.has_quota is True
|
||||
authorize.assert_awaited_once_with(
|
||||
organization_id=42,
|
||||
workflow_run_id=88,
|
||||
service_key=api_key,
|
||||
require_correlation_id=True,
|
||||
minimum_credits=quota_service.MINIMUM_DOGRAH_CREDITS_FOR_CALL,
|
||||
created_by="provider-123",
|
||||
metadata={"dograh_user_id": "123", "workflow_id": 7},
|
||||
)
|
||||
update_workflow_run.assert_awaited_once_with(
|
||||
88,
|
||||
initial_context={
|
||||
"existing": "value",
|
||||
MPS_CORRELATION_ID_CONTEXT_KEY: "mps-corr-123",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authorize_workflow_run_oss_uses_key_paths_not_workflow_org(
|
||||
monkeypatch,
|
||||
):
|
||||
api_key = "mps_sk_12345678"
|
||||
workflow_run = SimpleNamespace(initial_context={})
|
||||
get_config = AsyncMock(
|
||||
return_value=_dograh_config(api_key, managed_service_version=2)
|
||||
)
|
||||
hosted_authorize = AsyncMock()
|
||||
check_usage = AsyncMock(
|
||||
return_value={"total_credits_used": 1.0, "remaining_credits": 499.0}
|
||||
)
|
||||
create_correlation = AsyncMock(return_value={"correlation_id": "oss-corr-123"})
|
||||
update_workflow_run = AsyncMock()
|
||||
|
||||
monkeypatch.setattr(quota_service, "DEPLOYMENT_MODE", "oss")
|
||||
_patch_workflow_context(monkeypatch)
|
||||
monkeypatch.setattr(
|
||||
quota_service.db_client,
|
||||
"get_workflow_run_by_id",
|
||||
AsyncMock(return_value=workflow_run),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
quota_service.db_client,
|
||||
"update_workflow_run",
|
||||
update_workflow_run,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
quota_service,
|
||||
"get_effective_ai_model_configuration_for_workflow",
|
||||
get_config,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
quota_service.mps_service_key_client,
|
||||
"authorize_workflow_run_start",
|
||||
hosted_authorize,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
quota_service.mps_service_key_client,
|
||||
"check_service_key_usage",
|
||||
check_usage,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
quota_service.mps_service_key_client,
|
||||
"create_correlation_id",
|
||||
create_correlation,
|
||||
)
|
||||
|
||||
result = await quota_service.authorize_workflow_run_start(
|
||||
workflow_id=7,
|
||||
workflow_run_id=88,
|
||||
)
|
||||
|
||||
assert result.has_quota is True
|
||||
hosted_authorize.assert_not_awaited()
|
||||
check_usage.assert_awaited_once_with(
|
||||
api_key,
|
||||
organization_id=None,
|
||||
created_by="provider-123",
|
||||
)
|
||||
create_correlation.assert_awaited_once_with(
|
||||
service_key=api_key,
|
||||
workflow_run_id=88,
|
||||
)
|
||||
update_workflow_run.assert_awaited_once_with(
|
||||
88,
|
||||
initial_context={MPS_CORRELATION_ID_CONTEXT_KEY: "oss-corr-123"},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authorize_workflow_run_rejects_actor_from_another_org(monkeypatch):
|
||||
monkeypatch.setattr(quota_service, "DEPLOYMENT_MODE", "saas")
|
||||
_patch_workflow_context(monkeypatch)
|
||||
|
||||
result = await quota_service.authorize_workflow_run_start(
|
||||
workflow_id=7,
|
||||
actor_user=SimpleNamespace(selected_organization_id=999),
|
||||
)
|
||||
|
||||
assert result.has_quota is False
|
||||
assert result.error_code == "workflow_not_found"
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
from unittest.mock import ANY, AsyncMock, Mock, patch
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
|
@ -54,7 +54,7 @@ def test_initiate_call_executes_as_workflow_owner_for_shared_org_workflow():
|
|||
with (
|
||||
patch("api.routes.telephony.db_client") as mock_db,
|
||||
patch(
|
||||
"api.routes.telephony.check_dograh_quota_by_user_id",
|
||||
"api.routes.telephony.authorize_workflow_run_start",
|
||||
new=quota_mock,
|
||||
),
|
||||
patch(
|
||||
|
|
@ -88,7 +88,11 @@ def test_initiate_call_executes_as_workflow_owner_for_shared_org_workflow():
|
|||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
quota_mock.assert_awaited_once_with(workflow.user_id, workflow_id=workflow.id)
|
||||
quota_mock.assert_awaited_once_with(
|
||||
workflow_id=workflow.id,
|
||||
workflow_run_id=501,
|
||||
actor_user=ANY,
|
||||
)
|
||||
mock_db.get_workflow.assert_awaited_once_with(workflow.id, organization_id=11)
|
||||
|
||||
create_call = mock_db.create_workflow_run.await_args
|
||||
|
|
@ -119,7 +123,7 @@ def test_initiate_call_uses_organization_preference_phone_number():
|
|||
with (
|
||||
patch("api.routes.telephony.db_client") as mock_db,
|
||||
patch(
|
||||
"api.routes.telephony.check_dograh_quota_by_user_id",
|
||||
"api.routes.telephony.authorize_workflow_run_start",
|
||||
new=quota_mock,
|
||||
),
|
||||
patch(
|
||||
|
|
@ -173,7 +177,7 @@ def test_initiate_call_rejects_existing_run_for_different_workflow():
|
|||
with (
|
||||
patch("api.routes.telephony.db_client") as mock_db,
|
||||
patch(
|
||||
"api.routes.telephony.check_dograh_quota_by_user_id",
|
||||
"api.routes.telephony.authorize_workflow_run_start",
|
||||
new=quota_mock,
|
||||
),
|
||||
patch(
|
||||
|
|
|
|||
|
|
@ -1105,7 +1105,7 @@ async def test_text_chat_session_creation_rejects_quota_before_creating_run(
|
|||
|
||||
async with test_client_factory(user) as client:
|
||||
with patch(
|
||||
"api.routes.workflow_text_chat.check_dograh_quota",
|
||||
"api.routes.workflow_text_chat.authorize_workflow_run_start",
|
||||
new=AsyncMock(
|
||||
return_value=SimpleNamespace(
|
||||
has_quota=False,
|
||||
|
|
@ -1120,11 +1120,16 @@ async def test_text_chat_session_creation_rejects_quota_before_creating_run(
|
|||
|
||||
assert create_response.status_code == 402
|
||||
assert create_response.json()["detail"] == "Quota exceeded"
|
||||
_, total_count = await db_session.get_workflow_runs_by_workflow_id(
|
||||
runs, total_count = await db_session.get_workflow_runs_by_workflow_id(
|
||||
workflow.id,
|
||||
organization_id=workflow.organization_id,
|
||||
)
|
||||
assert total_count == 0
|
||||
assert total_count == 1
|
||||
text_session = await db_session.get_workflow_run_text_session(
|
||||
runs[0].id,
|
||||
organization_id=workflow.organization_id,
|
||||
)
|
||||
assert text_session is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -1168,7 +1173,7 @@ async def test_text_chat_append_rejects_quota_without_mutating_session(
|
|||
async with test_client_factory(user) as client:
|
||||
with (
|
||||
patch(
|
||||
"api.routes.workflow_text_chat.check_dograh_quota",
|
||||
"api.routes.workflow_text_chat.authorize_workflow_run_start",
|
||||
new=AsyncMock(
|
||||
side_effect=[
|
||||
SimpleNamespace(has_quota=True, error_message=""),
|
||||
|
|
|
|||
|
|
@ -147,7 +147,7 @@ export function EmbeddedVoiceTester({
|
|||
onOpenChange={setApiKeyModalOpen}
|
||||
error={apiKeyError}
|
||||
errorCode={apiKeyErrorCode}
|
||||
onNavigateToCredits={() => router.push("/api-keys")}
|
||||
onNavigateToBilling={() => router.push("/billing")}
|
||||
onNavigateToModelConfig={() => router.push("/model-configurations")}
|
||||
/>
|
||||
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ interface ApiKeyErrorDialogProps {
|
|||
onOpenChange: (open: boolean) => void;
|
||||
error: string | null;
|
||||
errorCode: string | null;
|
||||
onNavigateToCredits: () => void;
|
||||
onNavigateToBilling: () => void;
|
||||
onNavigateToModelConfig: () => void;
|
||||
}
|
||||
|
||||
|
|
@ -17,15 +17,16 @@ export const ApiKeyErrorDialog = ({
|
|||
onOpenChange,
|
||||
error,
|
||||
errorCode,
|
||||
onNavigateToCredits,
|
||||
onNavigateToBilling,
|
||||
onNavigateToModelConfig,
|
||||
}: ApiKeyErrorDialogProps) => {
|
||||
const isQuotaError = errorCode === 'quota_exceeded';
|
||||
const isBillingCreditsError = errorCode === 'insufficient_credits';
|
||||
const isQuotaError = isBillingCreditsError || errorCode === 'quota_exceeded';
|
||||
|
||||
const title = isQuotaError ? "Insufficient Credits" : "API Configuration Error";
|
||||
const icon = isQuotaError ? <CreditCard className="h-5 w-5 text-orange-500" /> : <Key className="h-5 w-5 text-red-500" />;
|
||||
const buttonText = isQuotaError ? "Add Credits" : "Go to Model Configurations";
|
||||
const onNavigate = isQuotaError ? onNavigateToCredits : onNavigateToModelConfig;
|
||||
const buttonText = isBillingCreditsError ? "Go to Billing" : "Go to Model Configurations";
|
||||
const onNavigate = isBillingCreditsError ? onNavigateToBilling : onNavigateToModelConfig;
|
||||
|
||||
return (
|
||||
<Dialog open={open} onOpenChange={onOpenChange}>
|
||||
|
|
@ -40,9 +41,9 @@ export const ApiKeyErrorDialog = ({
|
|||
<AlertCircle className="h-4 w-4 text-muted-foreground mt-0.5 flex-shrink-0" />
|
||||
<div className="text-sm space-y-1">
|
||||
<p className="font-medium text-foreground">{error}</p>
|
||||
{isQuotaError && (
|
||||
{isBillingCreditsError && (
|
||||
<p className="text-muted-foreground">
|
||||
Your Dograh service credits are too low to start a call.
|
||||
Purchase credits from Billing to continue using Dograh-managed models.
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -19,6 +19,13 @@ interface UseWebSocketRTCProps {
|
|||
onNodeTransition?: (transition: ConversationNodeTransitionItem) => void;
|
||||
}
|
||||
|
||||
const HANDLED_SERVICE_ERROR_TYPES = new Set([
|
||||
'quota_exceeded',
|
||||
'insufficient_credits',
|
||||
'invalid_service_key',
|
||||
'quota_check_failed',
|
||||
]);
|
||||
|
||||
export const useWebSocketRTC = ({ workflowId, workflowRunId, accessToken, initialContextVariables, onNodeTransition }: UseWebSocketRTCProps) => {
|
||||
const [connectionStatus, setConnectionStatus] = useState<'idle' | 'connecting' | 'connected' | 'failed'>('idle');
|
||||
const [connectionActive, setConnectionActive] = useState(false);
|
||||
|
|
@ -265,9 +272,7 @@ export const useWebSocketRTC = ({ workflowId, workflowRunId, accessToken, initia
|
|||
|
||||
case 'error':
|
||||
// Check if this is a quota/service key error
|
||||
if (message.payload?.error_type === 'quota_exceeded' ||
|
||||
message.payload?.error_type === 'invalid_service_key' ||
|
||||
message.payload?.error_type === 'quota_check_failed') {
|
||||
if (HANDLED_SERVICE_ERROR_TYPES.has(message.payload?.error_type)) {
|
||||
// Log as info since it's a handled business logic case
|
||||
logger.info('Quota/service key error, showing user dialog:', message.payload.message);
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue