mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-28 08:49:42 +02:00
feat: centralise workflow run authorization
This commit is contained in:
parent
5bf7518829
commit
281656b960
21 changed files with 1036 additions and 252 deletions
|
|
@ -15,6 +15,7 @@ from api.services.campaign.errors import (
|
|||
PhoneNumberPoolExhaustedError,
|
||||
)
|
||||
from api.services.campaign.rate_limiter import rate_limiter
|
||||
from api.services.quota_service import authorize_workflow_run_start
|
||||
from api.utils.common import get_backend_endpoints
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -339,6 +340,41 @@ class CampaignCallDispatcher:
|
|||
},
|
||||
)
|
||||
|
||||
quota_result = await authorize_workflow_run_start(
|
||||
workflow_id=campaign.workflow_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
)
|
||||
if not quota_result.has_quota:
|
||||
error_message = quota_result.error_message or "Quota exceeded"
|
||||
logger.warning(
|
||||
f"Campaign {campaign.id} quota check failed for workflow run "
|
||||
f"{workflow_run.id}: {error_message}"
|
||||
)
|
||||
await db_client.update_workflow_run(
|
||||
run_id=workflow_run.id,
|
||||
is_completed=True,
|
||||
state=WorkflowRunState.COMPLETED.value,
|
||||
gathered_context={"error": error_message},
|
||||
)
|
||||
|
||||
mapping = await rate_limiter.get_workflow_slot_mapping(workflow_run.id)
|
||||
if mapping:
|
||||
org_id, mapped_slot_id = mapping
|
||||
await rate_limiter.release_concurrent_slot(org_id, mapped_slot_id)
|
||||
await rate_limiter.delete_workflow_slot_mapping(workflow_run.id)
|
||||
|
||||
from_number_mapping = await rate_limiter.get_workflow_from_number_mapping(
|
||||
workflow_run.id
|
||||
)
|
||||
if from_number_mapping:
|
||||
fn_org_id, fn_number, fn_tcid = from_number_mapping
|
||||
await rate_limiter.release_from_number(
|
||||
fn_org_id, fn_number, telephony_configuration_id=fn_tcid
|
||||
)
|
||||
await rate_limiter.delete_workflow_from_number_mapping(workflow_run.id)
|
||||
|
||||
raise ValueError(error_message)
|
||||
|
||||
# Initiate call via telephony provider
|
||||
try:
|
||||
# Construct webhook URL with parameters
|
||||
|
|
|
|||
|
|
@ -2,11 +2,8 @@ from __future__ import annotations
|
|||
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from api.services.mps_service_key_client import mps_service_key_client
|
||||
|
||||
MPS_CORRELATION_ID_CONTEXT_KEY = "mps_correlation_id"
|
||||
|
||||
|
|
@ -48,27 +45,10 @@ async def ensure_mps_correlation_id(
|
|||
if not uses_managed_model_services_v2(ai_model_config):
|
||||
return None
|
||||
|
||||
service_key = _get_dograh_service_api_key(ai_model_config)
|
||||
if not service_key:
|
||||
raise ValueError(
|
||||
"Managed model services v2 requires a Dograh service key before the run starts."
|
||||
)
|
||||
|
||||
response = await mps_service_key_client.create_correlation_id(
|
||||
service_key=service_key,
|
||||
workflow_run_id=workflow_run_id,
|
||||
raise ValueError(
|
||||
"Managed model services v2 requires workflow run authorization before "
|
||||
f"the run starts. Missing correlation id for workflow_run_id={workflow_run_id}."
|
||||
)
|
||||
correlation_id = response.get("correlation_id")
|
||||
if not correlation_id:
|
||||
raise ValueError("MPS correlation-id response did not include correlation_id")
|
||||
|
||||
correlation_id = str(correlation_id)
|
||||
logger.info(
|
||||
"Minted MPS correlation id {} for workflow run {}",
|
||||
correlation_id,
|
||||
workflow_run_id,
|
||||
)
|
||||
return correlation_id
|
||||
|
||||
|
||||
def _is_dograh_service(service: Any) -> bool:
|
||||
|
|
@ -78,7 +58,7 @@ def _is_dograh_service(service: Any) -> bool:
|
|||
)
|
||||
|
||||
|
||||
def _get_dograh_service_api_key(
|
||||
def get_dograh_service_api_key(
|
||||
ai_model_config: EffectiveAIModelConfiguration,
|
||||
) -> str | None:
|
||||
for section_name in ("llm", "tts", "stt", "embeddings"):
|
||||
|
|
|
|||
|
|
@ -478,6 +478,50 @@ class MPSServiceKeyClient:
|
|||
response=response,
|
||||
)
|
||||
|
||||
async def authorize_workflow_run_start(
|
||||
self,
|
||||
*,
|
||||
organization_id: int,
|
||||
workflow_run_id: int | None = None,
|
||||
service_key: Optional[str] = None,
|
||||
require_correlation_id: bool = False,
|
||||
minimum_credits: float | None = None,
|
||||
metadata: Optional[dict] = None,
|
||||
created_by: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""Authorize a hosted workflow run and optionally mint its MPS correlation."""
|
||||
payload = {
|
||||
"workflow_run_id": workflow_run_id,
|
||||
"service_key": service_key,
|
||||
"require_correlation_id": require_correlation_id,
|
||||
"metadata": metadata or {},
|
||||
}
|
||||
if minimum_credits is not None:
|
||||
payload["minimum_credits"] = minimum_credits
|
||||
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
response = await client.post(
|
||||
f"{self.base_url}/api/v1/billing/accounts/{organization_id}/run-authorization",
|
||||
json=payload,
|
||||
headers=self._get_headers(
|
||||
organization_id=organization_id,
|
||||
created_by=created_by,
|
||||
),
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
|
||||
logger.error(
|
||||
"Failed to authorize MPS workflow run start: "
|
||||
f"{response.status_code} - {response.text}"
|
||||
)
|
||||
raise httpx.HTTPStatusError(
|
||||
f"Failed to authorize MPS workflow run start: {response.text}",
|
||||
request=response.request,
|
||||
response=response,
|
||||
)
|
||||
|
||||
async def create_correlation_id(
|
||||
self,
|
||||
*,
|
||||
|
|
|
|||
|
|
@ -5,17 +5,38 @@ across different endpoints (WebRTC signaling, telephony, public API triggers).
|
|||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from api.constants import DEPLOYMENT_MODE
|
||||
from api.db import db_client
|
||||
from api.db.models import UserModel
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
get_effective_ai_model_configuration_for_workflow,
|
||||
)
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from api.services.managed_model_services import (
|
||||
MPS_CORRELATION_ID_CONTEXT_KEY,
|
||||
get_dograh_service_api_key,
|
||||
uses_managed_model_services_v2,
|
||||
)
|
||||
from api.services.mps_service_key_client import mps_service_key_client
|
||||
|
||||
MINIMUM_DOGRAH_CREDITS_FOR_CALL = 0.10
|
||||
|
||||
LEGACY_QUOTA_EXCEEDED_MESSAGE = (
|
||||
"You have exhausted your trial credits. "
|
||||
"Please email founders@dograh.com for additional Dograh credits "
|
||||
"or change providers in Models configurations."
|
||||
)
|
||||
|
||||
BILLING_V2_QUOTA_EXCEEDED_MESSAGE = (
|
||||
"You have exhausted your Dograh credits. "
|
||||
"Please purchase more credits from /billing "
|
||||
"or change providers in Models configurations."
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuotaCheckResult:
|
||||
|
|
@ -26,116 +47,359 @@ class QuotaCheckResult:
|
|||
error_code: str = ""
|
||||
|
||||
|
||||
async def check_dograh_quota(
|
||||
user: UserModel, workflow_id: int | None = None
|
||||
) -> QuotaCheckResult:
|
||||
"""Check if user has sufficient Dograh quota for making a call.
|
||||
|
||||
This function checks if the user is using any Dograh services (LLM, STT, TTS)
|
||||
and validates that they have sufficient credits remaining.
|
||||
|
||||
When ``workflow_id`` is provided, the workflow's per-workflow
|
||||
``model_overrides`` are merged onto the user's global config so the quota
|
||||
check runs against the credentials that will actually be used for the call
|
||||
(rather than always falling back to the user's defaults).
|
||||
|
||||
Args:
|
||||
user: The user to check quota for
|
||||
workflow_id: Optional workflow whose ``model_overrides`` should be
|
||||
applied when resolving the effective service config.
|
||||
|
||||
Returns:
|
||||
QuotaCheckResult with has_quota=True if user has sufficient quota or
|
||||
is not using Dograh services, or has_quota=False with error_message
|
||||
if quota is insufficient.
|
||||
"""
|
||||
def _safe_float(value: Any, default: float = 0.0) -> float:
|
||||
try:
|
||||
organization_id = user.selected_organization_id
|
||||
workflow_configurations = None
|
||||
return float(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
if workflow_id is not None:
|
||||
workflow = await db_client.get_workflow_by_id(workflow_id)
|
||||
if workflow:
|
||||
organization_id = workflow.organization_id
|
||||
workflow_configurations = workflow.workflow_configurations
|
||||
|
||||
user_config = await get_effective_ai_model_configuration_for_workflow(
|
||||
user_id=user.id,
|
||||
organization_id=organization_id,
|
||||
workflow_configurations=workflow_configurations,
|
||||
def _insufficient_billing_v2_quota_result() -> QuotaCheckResult:
|
||||
return QuotaCheckResult(
|
||||
has_quota=False,
|
||||
error_code="insufficient_credits",
|
||||
error_message=BILLING_V2_QUOTA_EXCEEDED_MESSAGE,
|
||||
)
|
||||
|
||||
|
||||
def _insufficient_legacy_quota_result() -> QuotaCheckResult:
|
||||
return QuotaCheckResult(
|
||||
has_quota=False,
|
||||
error_code="quota_exceeded",
|
||||
error_message=LEGACY_QUOTA_EXCEEDED_MESSAGE,
|
||||
)
|
||||
|
||||
|
||||
def _service_uses_dograh(service: Any) -> bool:
|
||||
provider = getattr(service, "provider", None)
|
||||
return (
|
||||
provider == ServiceProviders.DOGRAH or provider == ServiceProviders.DOGRAH.value
|
||||
)
|
||||
|
||||
|
||||
def _dograh_api_keys(user_config: Any) -> set[str]:
|
||||
api_keys: set[str] = set()
|
||||
for section_name in ("llm", "stt", "tts", "embeddings"):
|
||||
service = getattr(user_config, section_name, None)
|
||||
if not _service_uses_dograh(service):
|
||||
continue
|
||||
if hasattr(service, "get_all_api_keys"):
|
||||
all_api_keys = [
|
||||
api_key
|
||||
for api_key in service.get_all_api_keys()
|
||||
if isinstance(api_key, str) and api_key
|
||||
]
|
||||
if all_api_keys:
|
||||
api_keys.update(all_api_keys)
|
||||
continue
|
||||
api_key = getattr(service, "api_key", None)
|
||||
if api_key:
|
||||
api_keys.add(api_key)
|
||||
return api_keys
|
||||
|
||||
|
||||
async def _store_run_correlation_id(
|
||||
workflow_run_id: int | None,
|
||||
correlation_id: str | None,
|
||||
) -> None:
|
||||
if not workflow_run_id or not correlation_id:
|
||||
return
|
||||
|
||||
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
|
||||
if not workflow_run:
|
||||
logger.warning(
|
||||
"Could not store MPS correlation id for missing workflow run {}",
|
||||
workflow_run_id,
|
||||
)
|
||||
return
|
||||
|
||||
initial_context = dict(workflow_run.initial_context or {})
|
||||
if initial_context.get(MPS_CORRELATION_ID_CONTEXT_KEY) == correlation_id:
|
||||
return
|
||||
|
||||
initial_context[MPS_CORRELATION_ID_CONTEXT_KEY] = correlation_id
|
||||
await db_client.update_workflow_run(
|
||||
workflow_run_id,
|
||||
initial_context=initial_context,
|
||||
)
|
||||
|
||||
|
||||
async def _authorize_hosted_workflow_run_start(
|
||||
*,
|
||||
workflow_owner: UserModel,
|
||||
organization_id: int | None,
|
||||
workflow_id: int | None,
|
||||
workflow_run_id: int | None,
|
||||
user_config: Any,
|
||||
) -> tuple[QuotaCheckResult, bool]:
|
||||
"""Authorize hosted v2 billing and return whether MPS handled enforcement."""
|
||||
if DEPLOYMENT_MODE == "oss" or organization_id is None:
|
||||
return QuotaCheckResult(has_quota=True), False
|
||||
|
||||
requires_correlation = bool(
|
||||
workflow_run_id and uses_managed_model_services_v2(user_config)
|
||||
)
|
||||
service_key = (
|
||||
get_dograh_service_api_key(user_config) if requires_correlation else None
|
||||
)
|
||||
if requires_correlation and not service_key:
|
||||
return (
|
||||
QuotaCheckResult(
|
||||
has_quota=False,
|
||||
error_code="invalid_service_key",
|
||||
error_message=(
|
||||
"You have invalid keys in your model configuration. "
|
||||
"Please validate the service keys."
|
||||
),
|
||||
),
|
||||
True,
|
||||
)
|
||||
|
||||
# Check if user is using any Dograh service
|
||||
using_dograh = False
|
||||
dograh_api_keys = set()
|
||||
try:
|
||||
authorization = await mps_service_key_client.authorize_workflow_run_start(
|
||||
organization_id=organization_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
service_key=service_key,
|
||||
require_correlation_id=requires_correlation,
|
||||
minimum_credits=MINIMUM_DOGRAH_CREDITS_FOR_CALL,
|
||||
created_by=(
|
||||
str(workflow_owner.provider_id)
|
||||
if workflow_owner.provider_id is not None
|
||||
else None
|
||||
),
|
||||
metadata={
|
||||
"dograh_user_id": str(workflow_owner.id),
|
||||
"workflow_id": workflow_id,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to authorize workflow start with MPS for org {}: {}",
|
||||
organization_id,
|
||||
e,
|
||||
)
|
||||
return (
|
||||
QuotaCheckResult(
|
||||
has_quota=False,
|
||||
error_code="quota_check_failed",
|
||||
error_message="Could not verify Dograh credits. Please try again.",
|
||||
),
|
||||
True,
|
||||
)
|
||||
|
||||
if user_config.llm and user_config.llm.provider == ServiceProviders.DOGRAH:
|
||||
using_dograh = True
|
||||
dograh_api_keys.add(user_config.llm.api_key)
|
||||
billing_mode = authorization.get("billing_mode")
|
||||
if billing_mode != "v2":
|
||||
return QuotaCheckResult(has_quota=True), False
|
||||
|
||||
if user_config.stt and user_config.stt.provider == ServiceProviders.DOGRAH:
|
||||
using_dograh = True
|
||||
dograh_api_keys.add(user_config.stt.api_key)
|
||||
remaining = _safe_float(authorization.get("remaining_credits"))
|
||||
if (
|
||||
not authorization.get("allowed", False)
|
||||
or remaining < MINIMUM_DOGRAH_CREDITS_FOR_CALL
|
||||
):
|
||||
logger.warning(
|
||||
"Insufficient Dograh billing v2 credits for org {}: {:.2f} credits remaining",
|
||||
organization_id,
|
||||
remaining,
|
||||
)
|
||||
return _insufficient_billing_v2_quota_result(), True
|
||||
|
||||
if user_config.tts and user_config.tts.provider == ServiceProviders.DOGRAH:
|
||||
using_dograh = True
|
||||
dograh_api_keys.add(user_config.tts.api_key)
|
||||
try:
|
||||
await _store_run_correlation_id(
|
||||
workflow_run_id,
|
||||
authorization.get("correlation_id"),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to store MPS correlation id for workflow_run_id {}: {}",
|
||||
workflow_run_id,
|
||||
e,
|
||||
)
|
||||
return (
|
||||
QuotaCheckResult(
|
||||
has_quota=False,
|
||||
error_code="quota_check_failed",
|
||||
error_message="Could not verify Dograh credits. Please try again.",
|
||||
),
|
||||
True,
|
||||
)
|
||||
logger.info(
|
||||
"Dograh billing v2 run authorization passed for org {}: {:.2f} credits remaining",
|
||||
organization_id,
|
||||
remaining,
|
||||
)
|
||||
return QuotaCheckResult(has_quota=True), True
|
||||
|
||||
if (
|
||||
user_config.embeddings
|
||||
and user_config.embeddings.provider == ServiceProviders.DOGRAH
|
||||
):
|
||||
using_dograh = True
|
||||
dograh_api_keys.add(user_config.embeddings.api_key)
|
||||
|
||||
# If not using Dograh, quota check passes
|
||||
if not using_dograh:
|
||||
return QuotaCheckResult(has_quota=True)
|
||||
async def _authorize_legacy_dograh_keys(
|
||||
*,
|
||||
dograh_api_keys: set[str],
|
||||
organization_id: int | None,
|
||||
workflow_owner: UserModel,
|
||||
) -> QuotaCheckResult:
|
||||
for api_key in dograh_api_keys:
|
||||
try:
|
||||
usage = await mps_service_key_client.check_service_key_usage(
|
||||
api_key,
|
||||
organization_id=organization_id,
|
||||
created_by=workflow_owner.provider_id,
|
||||
)
|
||||
remaining = usage.get("remaining_credits", 0.0)
|
||||
|
||||
# Check quota for ALL Dograh keys
|
||||
for api_key in dograh_api_keys:
|
||||
try:
|
||||
usage = await mps_service_key_client.check_service_key_usage(
|
||||
api_key,
|
||||
organization_id=organization_id,
|
||||
created_by=user.provider_id,
|
||||
# Require at least $0.10 for a short call
|
||||
if remaining < MINIMUM_DOGRAH_CREDITS_FOR_CALL:
|
||||
logger.warning(
|
||||
f"Insufficient Dograh credits for key ...{api_key[-8:]}: "
|
||||
f"${remaining:.2f} remaining"
|
||||
)
|
||||
remaining = usage.get("remaining_credits", 0.0)
|
||||
return _insufficient_legacy_quota_result()
|
||||
|
||||
# Require at least $0.10 for a short call
|
||||
if remaining < 0.10:
|
||||
logger.warning(
|
||||
f"Insufficient Dograh credits for key ...{api_key[-8:]}: "
|
||||
f"${remaining:.2f} remaining"
|
||||
)
|
||||
return QuotaCheckResult(
|
||||
has_quota=False,
|
||||
error_code="quota_exceeded",
|
||||
error_message=(
|
||||
"You have exhausted your trial credits. "
|
||||
"Please email founders@dograh.com for additional Dograh credits "
|
||||
"or change providers in Models configurations."
|
||||
),
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Dograh quota check passed for key ...{api_key[-8:]}: "
|
||||
f"{remaining:.2f} credits remaining"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check quota for Dograh key: {str(e)}")
|
||||
error_str = str(e)
|
||||
if "404" in error_str or "not found" in error_str.lower():
|
||||
return QuotaCheckResult(
|
||||
has_quota=False,
|
||||
error_code="invalid_service_key",
|
||||
error_message="You have invalid keys in your model configuration. Please validate the service keys.",
|
||||
)
|
||||
logger.info(
|
||||
f"Dograh quota check passed for key ...{api_key[-8:]}: "
|
||||
f"{remaining:.2f} credits remaining"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check quota for Dograh key: {str(e)}")
|
||||
error_str = str(e)
|
||||
if "404" in error_str or "not found" in error_str.lower():
|
||||
return QuotaCheckResult(
|
||||
has_quota=False,
|
||||
error_code="quota_check_failed",
|
||||
error_message="Could not verify Dograh credits. Please try again.",
|
||||
error_code="invalid_service_key",
|
||||
error_message="You have invalid keys in your model configuration. Please validate the service keys.",
|
||||
)
|
||||
return QuotaCheckResult(
|
||||
has_quota=False,
|
||||
error_code="quota_check_failed",
|
||||
error_message="Could not verify Dograh credits. Please try again.",
|
||||
)
|
||||
|
||||
return QuotaCheckResult(has_quota=True)
|
||||
|
||||
|
||||
async def _authorize_oss_managed_v2_correlation(
|
||||
*,
|
||||
workflow_id: int,
|
||||
workflow_run_id: int | None,
|
||||
user_config: Any,
|
||||
) -> QuotaCheckResult:
|
||||
if not workflow_run_id or not uses_managed_model_services_v2(user_config):
|
||||
return QuotaCheckResult(has_quota=True)
|
||||
|
||||
service_key = get_dograh_service_api_key(user_config)
|
||||
if not service_key:
|
||||
return QuotaCheckResult(
|
||||
has_quota=False,
|
||||
error_code="invalid_service_key",
|
||||
error_message=(
|
||||
"You have invalid keys in your model configuration. "
|
||||
"Please validate the service keys."
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
response = await mps_service_key_client.create_correlation_id(
|
||||
service_key=service_key,
|
||||
workflow_run_id=workflow_run_id,
|
||||
)
|
||||
await _store_run_correlation_id(
|
||||
workflow_run_id,
|
||||
response.get("correlation_id"),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to authorize OSS managed v2 workflow start for workflow {} run {}: {}",
|
||||
workflow_id,
|
||||
workflow_run_id,
|
||||
e,
|
||||
)
|
||||
return QuotaCheckResult(
|
||||
has_quota=False,
|
||||
error_code="quota_check_failed",
|
||||
error_message="Could not verify Dograh credits. Please try again.",
|
||||
)
|
||||
|
||||
return QuotaCheckResult(has_quota=True)
|
||||
|
||||
|
||||
async def authorize_workflow_run_start(
|
||||
*,
|
||||
workflow_id: int,
|
||||
workflow_run_id: int | None = None,
|
||||
actor_user: UserModel | None = None,
|
||||
) -> QuotaCheckResult:
|
||||
"""Authorize a workflow run before any billable call/text runtime starts.
|
||||
|
||||
The workflow organization is the billing subject for hosted v2. The workflow
|
||||
owner is used only to resolve the effective model configuration and legacy
|
||||
service-key metadata.
|
||||
"""
|
||||
try:
|
||||
workflow = await db_client.get_workflow_by_id(workflow_id)
|
||||
if not workflow:
|
||||
return QuotaCheckResult(
|
||||
has_quota=False,
|
||||
error_code="workflow_not_found",
|
||||
error_message="Workflow not found",
|
||||
)
|
||||
|
||||
actor_org_id = getattr(actor_user, "selected_organization_id", None)
|
||||
if actor_org_id is not None and actor_org_id != workflow.organization_id:
|
||||
logger.warning(
|
||||
"Workflow start authorization denied: actor org {} does not match workflow {} org {}",
|
||||
actor_org_id,
|
||||
workflow_id,
|
||||
workflow.organization_id,
|
||||
)
|
||||
return QuotaCheckResult(
|
||||
has_quota=False,
|
||||
error_code="workflow_not_found",
|
||||
error_message="Workflow not found",
|
||||
)
|
||||
|
||||
workflow_owner = await db_client.get_user_by_id(workflow.user_id)
|
||||
if not workflow_owner:
|
||||
return QuotaCheckResult(
|
||||
has_quota=False,
|
||||
error_code="user_not_found",
|
||||
error_message="User not found",
|
||||
)
|
||||
|
||||
user_config = await get_effective_ai_model_configuration_for_workflow(
|
||||
user_id=workflow_owner.id,
|
||||
organization_id=workflow.organization_id,
|
||||
workflow_configurations=workflow.workflow_configurations,
|
||||
)
|
||||
|
||||
if DEPLOYMENT_MODE != "oss":
|
||||
hosted_result, hosted_enforced = await _authorize_hosted_workflow_run_start(
|
||||
workflow_owner=workflow_owner,
|
||||
organization_id=workflow.organization_id,
|
||||
workflow_id=workflow.id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
user_config=user_config,
|
||||
)
|
||||
if hosted_enforced or not hosted_result.has_quota:
|
||||
return hosted_result
|
||||
|
||||
dograh_api_keys = _dograh_api_keys(user_config)
|
||||
if not dograh_api_keys:
|
||||
return QuotaCheckResult(has_quota=True)
|
||||
|
||||
legacy_result = await _authorize_legacy_dograh_keys(
|
||||
dograh_api_keys=dograh_api_keys,
|
||||
organization_id=(
|
||||
None if DEPLOYMENT_MODE == "oss" else workflow.organization_id
|
||||
),
|
||||
workflow_owner=workflow_owner,
|
||||
)
|
||||
if not legacy_result.has_quota:
|
||||
return legacy_result
|
||||
|
||||
if DEPLOYMENT_MODE == "oss":
|
||||
return await _authorize_oss_managed_v2_correlation(
|
||||
workflow_id=workflow.id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
user_config=user_config,
|
||||
)
|
||||
|
||||
return QuotaCheckResult(has_quota=True)
|
||||
|
||||
|
|
@ -143,30 +407,3 @@ async def check_dograh_quota(
|
|||
logger.error(f"Error during quota check: {str(e)}")
|
||||
# On unexpected error, allow the call to proceed
|
||||
return QuotaCheckResult(has_quota=True)
|
||||
|
||||
|
||||
async def check_dograh_quota_by_user_id(
|
||||
user_id: int, workflow_id: int | None = None
|
||||
) -> QuotaCheckResult:
|
||||
"""Check Dograh quota by user ID.
|
||||
|
||||
Convenience function that fetches the user and then checks quota. When
|
||||
``workflow_id`` is provided, the workflow's ``model_overrides`` are
|
||||
applied so the quota check evaluates the credentials that will actually
|
||||
be used for the call.
|
||||
|
||||
Args:
|
||||
user_id: The ID of the user to check quota for
|
||||
workflow_id: Optional workflow whose per-workflow overrides should
|
||||
be applied to the user's config before checking quota.
|
||||
|
||||
Returns:
|
||||
QuotaCheckResult with quota status
|
||||
"""
|
||||
user = await db_client.get_user_by_id(user_id)
|
||||
if not user:
|
||||
return QuotaCheckResult(
|
||||
has_quota=False,
|
||||
error_message="User not found",
|
||||
)
|
||||
return await check_dograh_quota(user, workflow_id=workflow_id)
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ from loguru import logger
|
|||
from api.constants import REDIS_URL
|
||||
from api.db import db_client
|
||||
from api.enums import CallType, WorkflowRunMode
|
||||
from api.services.quota_service import check_dograh_quota_by_user_id
|
||||
from api.services.quota_service import authorize_workflow_run_start
|
||||
from api.services.telephony.call_transfer_manager import get_call_transfer_manager
|
||||
from api.services.telephony.transfer_event_protocol import (
|
||||
TransferEvent,
|
||||
|
|
@ -564,19 +564,7 @@ class ARIConnection:
|
|||
|
||||
user_id = workflow.user_id
|
||||
|
||||
# 3. Check quota (apply per-workflow model_overrides).
|
||||
quota_result = await check_dograh_quota_by_user_id(
|
||||
user_id, workflow_id=inbound_workflow_id
|
||||
)
|
||||
if not quota_result.has_quota:
|
||||
logger.warning(
|
||||
f"[ARI org={self.organization_id}] Quota exceeded for user {user_id} "
|
||||
f"— hanging up inbound call {channel_id}"
|
||||
)
|
||||
await self._delete_channel(channel_id)
|
||||
return
|
||||
|
||||
# 4. Create workflow run
|
||||
# 3. Create workflow run
|
||||
call_id = channel_id
|
||||
workflow_run = await db_client.create_workflow_run(
|
||||
name=f"ARI Inbound {caller_number}",
|
||||
|
|
@ -602,6 +590,20 @@ class ARIConnection:
|
|||
f"(caller={caller_number}, called={called_number})"
|
||||
)
|
||||
|
||||
# 4. Check quota after the run exists so hosted v2 can mint and
|
||||
# store the MPS correlation id before the pipeline starts.
|
||||
quota_result = await authorize_workflow_run_start(
|
||||
workflow_id=inbound_workflow_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
)
|
||||
if not quota_result.has_quota:
|
||||
logger.warning(
|
||||
f"[ARI org={self.organization_id}] Quota exceeded for user {user_id} "
|
||||
f"— hanging up inbound call {channel_id}"
|
||||
)
|
||||
await self._delete_channel(channel_id)
|
||||
return
|
||||
|
||||
# 5. Answer the inbound channel
|
||||
await self._answer_channel(channel_id)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue