Merge remote-tracking branch 'origin/main' into feat/user-onboarding

This commit is contained in:
Abhishek Kumar 2026-06-12 18:54:48 +05:30
commit 093e888ce4
148 changed files with 10908 additions and 2815 deletions

View file

@ -22,7 +22,7 @@ from starlette.websockets import WebSocketDisconnect
from api.db import db_client
from api.enums import CallType, WorkflowRunState
from api.services.quota_service import check_dograh_quota_by_user_id
from api.services.quota_service import authorize_workflow_run_start
from api.services.telephony import registry as telephony_registry
router = APIRouter(prefix="/agent-stream")
@ -67,19 +67,6 @@ async def agent_stream_websocket(
await websocket.close(code=1008, reason="Workflow not found")
return
quota_result = await check_dograh_quota_by_user_id(
workflow.user_id, workflow_id=workflow.id
)
if not quota_result.has_quota:
logger.warning(
f"agent-stream quota exceeded for user {workflow.user_id}: "
f"{quota_result.error_message}"
)
await websocket.close(
code=1008, reason=quota_result.error_message or "Quota exceeded"
)
return
numeric_suffix = int(str(uuid.uuid4()).replace("-", "")[:8], 16) % 100000000
workflow_run_name = f"WR-AGS-{numeric_suffix:08d}"
call_id = params.get("callId") or params.get("CallSid")
@ -108,6 +95,20 @@ async def agent_stream_websocket(
set_current_run_id(workflow_run.id)
set_current_org_id(workflow.organization_id)
quota_result = await authorize_workflow_run_start(
workflow_id=workflow.id,
workflow_run_id=workflow_run.id,
)
if not quota_result.has_quota:
logger.warning(
f"agent-stream quota exceeded for user {workflow.user_id}: "
f"{quota_result.error_message}"
)
await websocket.close(
code=1008, reason=quota_result.error_message or "Quota exceeded"
)
return
await db_client.update_workflow_run(
run_id=workflow_run.id, state=WorkflowRunState.RUNNING.value
)

View file

@ -3,9 +3,12 @@ from loguru import logger
from api.db import db_client
from api.db.models import UserModel
from api.enums import PostHogEvent
from api.enums import OrganizationConfigurationKey, PostHogEvent
from api.schemas.auth import AuthResponse, LoginRequest, SignupRequest, UserResponse
from api.services.auth.depends import create_user_configuration_with_mps_key, get_user
from api.services.configuration.ai_model_configuration import (
convert_legacy_ai_model_configuration_to_v2,
)
from api.services.posthog_client import capture_event
from api.utils.auth import create_jwt_token, hash_password, verify_password
@ -47,6 +50,12 @@ async def signup(request: SignupRequest):
)
if mps_config:
await db_client.update_user_configuration(user.id, mps_config)
model_config_v2 = convert_legacy_ai_model_configuration_to_v2(mps_config)
await db_client.upsert_configuration(
organization.id,
OrganizationConfigurationKey.MODEL_CONFIGURATION_V2.value,
model_config_v2.model_dump(mode="json", exclude_none=True),
)
except Exception:
logger.warning(
"Failed to create default configuration for OSS user", exc_info=True

View file

@ -18,7 +18,7 @@ from api.services.auth.depends import get_user
from api.services.campaign.runner import campaign_runner_service
from api.services.campaign.source_sync import CampaignSourceSyncService
from api.services.campaign.source_sync_factory import get_sync_service
from api.services.quota_service import check_dograh_quota
from api.services.quota_service import authorize_workflow_run_start
from api.services.reports import generate_campaign_report_csv
from api.services.storage import storage_fs
@ -550,7 +550,10 @@ async def start_campaign(
# Check Dograh quota before starting campaign (apply per-workflow
# model_overrides so we evaluate the keys this campaign will use).
quota_result = await check_dograh_quota(user, workflow_id=campaign.workflow_id)
quota_result = await authorize_workflow_run_start(
workflow_id=campaign.workflow_id,
actor_user=user,
)
if not quota_result.has_quota:
raise HTTPException(status_code=402, detail=quota_result.error_message)
@ -872,7 +875,10 @@ async def resume_campaign(
# Check Dograh quota before resuming campaign (apply per-workflow
# model_overrides so we evaluate the keys this campaign will use).
quota_result = await check_dograh_quota(user, workflow_id=campaign.workflow_id)
quota_result = await authorize_workflow_run_start(
workflow_id=campaign.workflow_id,
actor_user=user,
)
if not quota_result.has_quota:
raise HTTPException(status_code=402, detail=quota_result.error_message)

View file

@ -369,6 +369,10 @@ async def search_chunks(
try:
# Import here to avoid circular dependency
from api.services.configuration.ai_model_configuration import (
apply_managed_embeddings_base_url,
get_resolved_ai_model_configuration,
)
from api.services.configuration.registry import ServiceProviders
from api.services.gen_ai import (
AzureOpenAIEmbeddingService,
@ -376,20 +380,29 @@ async def search_chunks(
)
# Try to get user's embeddings configuration
user_config = await db_client.get_user_configurations(user.id)
resolved_config = await get_resolved_ai_model_configuration(
user_id=user.id,
organization_id=user.selected_organization_id,
)
effective_config = resolved_config.effective
embeddings_api_key = None
embeddings_model = None
embeddings_provider = None
embeddings_base_url = None
embeddings_endpoint = None
embeddings_api_version = None
if user_config.embeddings:
embeddings_api_key = user_config.embeddings.api_key
embeddings_model = user_config.embeddings.model
embeddings_provider = getattr(user_config.embeddings, "provider", None)
embeddings_endpoint = getattr(user_config.embeddings, "endpoint", None)
if effective_config.embeddings:
embeddings_api_key = effective_config.embeddings.api_key
embeddings_model = effective_config.embeddings.model
embeddings_provider = getattr(effective_config.embeddings, "provider", None)
embeddings_endpoint = getattr(effective_config.embeddings, "endpoint", None)
embeddings_base_url = apply_managed_embeddings_base_url(
provider=embeddings_provider,
base_url=getattr(effective_config.embeddings, "base_url", None),
)
embeddings_api_version = getattr(
user_config.embeddings, "api_version", None
effective_config.embeddings, "api_version", None
)
# Initialize embedding service based on provider
@ -406,9 +419,7 @@ async def search_chunks(
db_client=db_client,
api_key=embeddings_api_key,
model_id=embeddings_model or "text-embedding-3-small",
base_url=getattr(user_config.embeddings, "base_url", None)
if user_config.embeddings
else None,
base_url=embeddings_base_url,
)
# Perform search

View file

@ -1,15 +1,27 @@
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends, HTTPException, Query
from loguru import logger
from pydantic import BaseModel
from sqlalchemy.exc import IntegrityError
from api.constants import DEFAULT_CAMPAIGN_RETRY_CONFIG, DEFAULT_ORG_CONCURRENCY_LIMIT
from api.constants import (
DEFAULT_CAMPAIGN_RETRY_CONFIG,
DEFAULT_ORG_CONCURRENCY_LIMIT,
DEPLOYMENT_MODE,
)
from api.db import db_client
from api.db.models import UserModel
from api.db.telephony_configuration_client import TelephonyConfigurationInUseError
from api.enums import OrganizationConfigurationKey, PostHogEvent
from api.schemas.ai_model_configuration import (
DOGRAH_DEFAULT_LANGUAGE,
DOGRAH_DEFAULT_VOICE,
DOGRAH_SPEED_OPTIONS,
OrganizationAIModelConfigurationResponse,
OrganizationAIModelConfigurationV2,
)
from api.schemas.organization_preferences import OrganizationPreferences
from api.schemas.telephony_config import (
TelephonyConfigRequest,
TelephonyConfigurationCreateRequest,
@ -26,8 +38,36 @@ from api.schemas.telephony_phone_number import (
PhoneNumberUpdateRequest,
ProviderSyncStatus,
)
from api.services.auth.depends import get_user
from api.services.configuration.masking import is_mask_of, mask_key
from api.services.auth.depends import get_user, get_user_with_selected_organization
from api.services.configuration.ai_model_configuration import (
check_for_masked_keys_in_ai_model_configuration_v2,
compile_ai_model_configuration_v2,
convert_legacy_ai_model_configuration_to_v2,
get_organization_ai_model_configuration_v2,
get_resolved_ai_model_configuration,
mask_ai_model_configuration_v2,
merge_ai_model_configuration_v2_secrets,
migrate_workflow_model_configurations_to_v2,
upsert_organization_ai_model_configuration_v2,
)
from api.services.configuration.check_validity import UserConfigurationValidator
from api.services.configuration.defaults import DEFAULT_SERVICE_PROVIDERS
from api.services.configuration.masking import is_mask_of, mask_key, mask_user_config
from api.services.configuration.registry import (
DOGRAH_STT_LANGUAGES,
REGISTRY,
ServiceProviders,
ServiceType,
)
from api.services.mps_billing import ensure_hosted_mps_billing_account_v2
from api.services.organization_context import (
OrganizationContextResponse,
get_organization_context,
)
from api.services.organization_preferences import (
get_organization_preferences,
upsert_organization_preferences,
)
from api.services.posthog_client import capture_event
from api.services.telephony import registry as telephony_registry
from api.services.telephony.factory import get_telephony_provider_by_id
@ -98,6 +138,12 @@ class TelephonyConfigWarningsResponse(BaseModel):
telnyx_missing_webhook_public_key_count: int
@router.get("/context", response_model=OrganizationContextResponse)
async def get_current_organization_context(user: UserModel = Depends(get_user)):
"""Return organization-scoped configuration signals owned by Dograh."""
return await get_organization_context(user)
@router.get(
"/telephony-providers/metadata",
response_model=TelephonyProvidersMetadataResponse,
@ -159,6 +205,239 @@ async def get_telephony_config_warnings(user: UserModel = Depends(get_user)):
)
# ---------------------------------------------------------------------------
# AI model configurations v2
# ---------------------------------------------------------------------------
def _byok_provider_schemas(service_type: ServiceType) -> dict[str, dict]:
return {
provider: model_cls.model_json_schema()
for provider, model_cls in REGISTRY[service_type].items()
if provider != ServiceProviders.DOGRAH.value
}
async def _model_configuration_v2_response(
*,
user: UserModel,
configuration: OrganizationAIModelConfigurationV2 | None = None,
) -> OrganizationAIModelConfigurationResponse:
resolved = await get_resolved_ai_model_configuration(
user_id=user.id,
organization_id=user.selected_organization_id,
)
raw_configuration = (
configuration
if configuration is not None
else resolved.organization_configuration
)
return OrganizationAIModelConfigurationResponse(
configuration=mask_ai_model_configuration_v2(raw_configuration),
effective_configuration=mask_user_config(resolved.effective),
source=resolved.source,
)
@router.get("/model-configurations/v2/defaults")
async def get_model_configuration_v2_defaults(
user: UserModel = Depends(get_user_with_selected_organization),
):
byok_default_providers = {
service: provider
for service, provider in DEFAULT_SERVICE_PROVIDERS.items()
if provider != ServiceProviders.DOGRAH.value
}
return {
"dograh": {
"voices": [DOGRAH_DEFAULT_VOICE],
"speeds": list(DOGRAH_SPEED_OPTIONS),
"languages": DOGRAH_STT_LANGUAGES,
"defaults": {
"voice": DOGRAH_DEFAULT_VOICE,
"speed": 1.0,
"language": DOGRAH_DEFAULT_LANGUAGE,
},
},
"byok": {
"pipeline": {
"llm": _byok_provider_schemas(ServiceType.LLM),
"tts": _byok_provider_schemas(ServiceType.TTS),
"stt": _byok_provider_schemas(ServiceType.STT),
"embeddings": _byok_provider_schemas(ServiceType.EMBEDDINGS),
"default_providers": byok_default_providers,
},
"realtime": {
"realtime": _byok_provider_schemas(ServiceType.REALTIME),
"llm": _byok_provider_schemas(ServiceType.LLM),
"embeddings": _byok_provider_schemas(ServiceType.EMBEDDINGS),
"default_providers": byok_default_providers,
},
},
}
@router.get(
"/model-configurations/v2",
response_model=OrganizationAIModelConfigurationResponse,
)
async def get_model_configuration_v2(
user: UserModel = Depends(get_user_with_selected_organization),
):
return await _model_configuration_v2_response(user=user)
@router.put(
"/model-configurations/v2",
response_model=OrganizationAIModelConfigurationResponse,
)
async def save_model_configuration_v2(
request: OrganizationAIModelConfigurationV2,
user: UserModel = Depends(get_user_with_selected_organization),
):
organization_id = user.selected_organization_id
existing = await get_organization_ai_model_configuration_v2(organization_id)
configuration = merge_ai_model_configuration_v2_secrets(request, existing)
try:
check_for_masked_keys_in_ai_model_configuration_v2(configuration)
effective = compile_ai_model_configuration_v2(configuration)
await UserConfigurationValidator().validate(
effective,
organization_id=organization_id,
created_by=user.provider_id,
)
except ValueError as exc:
raise HTTPException(status_code=422, detail=exc.args[0])
await upsert_organization_ai_model_configuration_v2(
organization_id,
configuration,
)
return await _model_configuration_v2_response(
user=user,
configuration=configuration,
)
@router.get("/model-configurations/v2/migration-preview")
async def preview_model_configuration_v2_migration(
user: UserModel = Depends(get_user_with_selected_organization),
):
legacy = await db_client.get_user_configurations(user.id)
try:
configuration = convert_legacy_ai_model_configuration_to_v2(legacy)
except ValueError as exc:
raise HTTPException(status_code=422, detail=str(exc))
return {
"configuration": mask_ai_model_configuration_v2(configuration),
"effective_configuration": mask_user_config(
compile_ai_model_configuration_v2(configuration)
),
}
@router.post(
"/model-configurations/v2/migrate",
response_model=OrganizationAIModelConfigurationResponse,
)
async def migrate_model_configuration_v2(
force: bool = Query(default=False),
user: UserModel = Depends(get_user_with_selected_organization),
):
organization_id = user.selected_organization_id
existing = await get_organization_ai_model_configuration_v2(organization_id)
if existing is not None and not force:
raise HTTPException(
status_code=409,
detail="Organization already has a v2 model configuration",
)
legacy = await db_client.get_user_configurations(user.id)
try:
configuration = convert_legacy_ai_model_configuration_to_v2(legacy)
effective = compile_ai_model_configuration_v2(configuration)
await UserConfigurationValidator().validate(
effective,
organization_id=organization_id,
created_by=user.provider_id,
)
except ValueError as exc:
raise HTTPException(status_code=422, detail=exc.args[0])
if DEPLOYMENT_MODE != "oss":
try:
await ensure_hosted_mps_billing_account_v2(
organization_id,
created_by=str(user.provider_id),
)
except Exception as exc:
logger.error(
"Failed to initialize MPS billing v2 account for organization {}: {}",
organization_id,
exc,
)
raise HTTPException(
status_code=502,
detail="Failed to initialize MPS billing v2 account",
)
await upsert_organization_ai_model_configuration_v2(
organization_id,
configuration,
)
await migrate_workflow_model_configurations_to_v2(
organization_id=organization_id,
fallback_user_config=legacy,
)
return await _model_configuration_v2_response(
user=user,
configuration=configuration,
)
@router.get("/preferences", response_model=OrganizationPreferences)
async def get_preferences(
user: UserModel = Depends(get_user_with_selected_organization),
):
organization_id = user.selected_organization_id
return await get_organization_preferences(organization_id)
@router.put("/preferences", response_model=OrganizationPreferences)
async def save_preferences(
request: OrganizationPreferences,
user: UserModel = Depends(get_user_with_selected_organization),
):
organization_id = user.selected_organization_id
return await upsert_organization_preferences(
organization_id,
request,
)
@router.get(
"/model-configurations/preferences",
response_model=OrganizationPreferences,
include_in_schema=False,
)
async def get_model_configuration_preferences_legacy(
user: UserModel = Depends(get_user_with_selected_organization),
):
return await get_preferences(user=user)
@router.put(
"/model-configurations/preferences",
response_model=OrganizationPreferences,
include_in_schema=False,
)
async def save_model_configuration_preferences_legacy(
request: OrganizationPreferences,
user: UserModel = Depends(get_user_with_selected_organization),
):
return await save_preferences(request=request, user=user)
def preserve_masked_fields(provider: str, request_dict: dict, existing: dict):
"""If the client re-submitted a masked sensitive field, restore the original."""
for field_name in _sensitive_fields(provider):

View file

@ -1,16 +1,16 @@
import json
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Literal, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import StreamingResponse
from loguru import logger
from pydantic import BaseModel, Field
from api.constants import DEPLOYMENT_MODE
from api.constants import DEPLOYMENT_MODE, UI_APP_URL
from api.db import db_client
from api.db.models import UserModel
from api.services.auth.depends import get_user
from api.services.auth.depends import get_user, get_user_with_selected_organization
from api.services.mps_service_key_client import mps_service_key_client
from api.services.reports import generate_usage_runs_report_csv
from api.utils.artifacts import artifact_url
@ -22,14 +22,8 @@ class CurrentUsageResponse(BaseModel):
period_start: str
period_end: str
used_dograh_tokens: float
quota_dograh_tokens: int
percentage_used: float
next_refresh_date: str
quota_enabled: bool
total_duration_seconds: int
# New USD fields
used_amount_usd: Optional[float] = None
quota_amount_usd: Optional[float] = None
currency: Optional[str] = None
price_per_second_usd: Optional[float] = None
@ -40,6 +34,61 @@ class MPSCreditsResponse(BaseModel):
total_quota: float
class MPSCreditPurchaseUrlResponse(BaseModel):
checkout_url: str
class MPSBillingAccountResponse(BaseModel):
id: int
organization_id: int
billing_mode: str
cached_balance_credits: float
currency: str
class MPSCreditLedgerEntryResponse(BaseModel):
id: int
entry_type: str
origin: Optional[str] = None
credits_delta: float
balance_after: float
amount_minor: Optional[int] = None
amount_currency: Optional[str] = None
payment_order_id: Optional[int] = None
metric_code: Optional[str] = None
correlation_id: Optional[str] = None
aggregation_key: Optional[str] = None
usage_event_id: Optional[int] = None
workflow_run_id: Optional[int] = None
workflow_id: Optional[int] = None
billable_quantity: Optional[float] = None
quantity_unit: Optional[str] = None
metadata: Dict[str, Any] = Field(default_factory=dict)
created_at: str
class MPSBillingCreditsResponse(BaseModel):
billing_version: Literal["legacy", "v2"]
total_credits_used: float = 0.0
remaining_credits: float = 0.0
total_quota: float = 0.0
account: Optional[MPSBillingAccountResponse] = None
ledger_entries: List[MPSCreditLedgerEntryResponse] = Field(default_factory=list)
total_count: int = 0
page: int = 1
limit: int = 50
total_pages: int = 0
def _optional_int(value: Any) -> Optional[int]:
if value is None:
return None
try:
return int(value)
except (TypeError, ValueError):
return None
class WorkflowRunUsageResponse(BaseModel):
id: int
workflow_id: int
@ -97,7 +146,7 @@ class DailyUsageBreakdownResponse(BaseModel):
@router.get("/usage/current-period", response_model=CurrentUsageResponse)
async def get_current_period_usage(user: UserModel = Depends(get_user)):
"""Get current billing period usage for the user's organization."""
"""Get current reporting-period usage for the user's organization."""
if not user.selected_organization_id:
raise HTTPException(status_code=400, detail="No organization selected")
@ -142,6 +191,202 @@ async def get_mps_credits(user: UserModel = Depends(get_user)):
raise HTTPException(status_code=500, detail=str(e))
async def _get_mps_billing_account_status(
user: UserModel, organization_id: int
) -> Optional[dict]:
return await mps_service_key_client.get_billing_account_status(
organization_id=organization_id,
created_by=str(user.provider_id),
)
def _is_mps_billing_v2(account: Optional[dict]) -> bool:
return bool(account and account.get("billing_mode") == "v2")
async def _legacy_mps_credits_response(user: UserModel) -> MPSBillingCreditsResponse:
if DEPLOYMENT_MODE == "oss":
usage = await mps_service_key_client.get_usage_by_created_by(
str(user.provider_id)
)
else:
if not user.selected_organization_id:
raise HTTPException(status_code=400, detail="No organization selected")
usage = await mps_service_key_client.get_usage_by_organization(
user.selected_organization_id
)
total_used = float(usage.get("total_credits_used", 0.0))
total_remaining = float(usage.get("remaining_credits", 0.0))
return MPSBillingCreditsResponse(
billing_version="legacy",
total_credits_used=total_used,
remaining_credits=total_remaining,
total_quota=total_used + total_remaining,
)
@router.get("/billing/credits", response_model=MPSBillingCreditsResponse)
async def get_billing_credits(
page: int = Query(1, ge=1),
limit: int = Query(50, ge=1, le=100),
user: UserModel = Depends(get_user),
):
"""Return legacy MPS credits or paginated v2 billing ledger details for the org."""
try:
if DEPLOYMENT_MODE == "oss" or not user.selected_organization_id:
return await _legacy_mps_credits_response(user)
organization_id = user.selected_organization_id
account_status = await _get_mps_billing_account_status(user, organization_id)
if not _is_mps_billing_v2(account_status):
return await _legacy_mps_credits_response(user)
ledger = await mps_service_key_client.get_credit_ledger(
organization_id=organization_id,
page=page,
limit=limit,
created_by=str(user.provider_id),
)
account = ledger.get("account") or {}
ledger_entries = ledger.get("ledger_entries") or []
total_count = int(ledger.get("total_count") or len(ledger_entries))
response_limit = int(ledger.get("limit") or limit)
total_pages = int(
ledger.get("total_pages")
or ((total_count + response_limit - 1) // response_limit)
)
workflow_ids_by_run_id: dict[int, int] = {}
workflow_run_ids = {
workflow_run_id
for entry in ledger_entries
if (workflow_run_id := _optional_int(entry.get("workflow_run_id")))
is not None
}
for workflow_run_id in workflow_run_ids:
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
if (
workflow_run
and workflow_run.workflow
and workflow_run.workflow.organization_id == organization_id
):
workflow_ids_by_run_id[workflow_run_id] = workflow_run.workflow_id
balance = float(account.get("cached_balance_credits") or 0.0)
total_debits = sum(
abs(float(entry.get("credits_delta") or 0.0))
for entry in ledger_entries
if float(entry.get("credits_delta") or 0.0) < 0
)
if ledger.get("total_debits_credits") is not None:
total_debits = float(ledger["total_debits_credits"])
return MPSBillingCreditsResponse(
billing_version="v2",
total_credits_used=total_debits,
remaining_credits=balance,
total_quota=balance + total_debits,
account=MPSBillingAccountResponse(
id=int(account["id"]),
organization_id=int(account["organization_id"]),
billing_mode=str(account["billing_mode"]),
cached_balance_credits=balance,
currency=str(account.get("currency") or "USD"),
),
ledger_entries=[
MPSCreditLedgerEntryResponse(
id=int(entry["id"]),
entry_type=str(entry["entry_type"]),
origin=entry.get("origin"),
credits_delta=float(entry.get("credits_delta") or 0.0),
balance_after=float(entry.get("balance_after") or 0.0),
amount_minor=entry.get("amount_minor"),
amount_currency=entry.get("amount_currency"),
payment_order_id=entry.get("payment_order_id"),
metric_code=entry.get("metric_code"),
correlation_id=entry.get("correlation_id"),
aggregation_key=entry.get("aggregation_key"),
usage_event_id=_optional_int(entry.get("usage_event_id")),
workflow_run_id=_optional_int(entry.get("workflow_run_id")),
workflow_id=workflow_ids_by_run_id.get(
_optional_int(entry.get("workflow_run_id"))
)
if entry.get("workflow_run_id") is not None
else None,
billable_quantity=float(entry["billable_quantity"])
if entry.get("billable_quantity") is not None
else None,
quantity_unit=entry.get("quantity_unit"),
metadata=entry.get("metadata") or {},
created_at=str(entry["created_at"]),
)
for entry in ledger_entries
],
total_count=total_count,
page=int(ledger.get("page") or page),
limit=response_limit,
total_pages=total_pages,
)
except HTTPException:
raise
except Exception as exc:
logger.error(f"Failed to fetch billing credits: {exc}")
raise HTTPException(status_code=500, detail=str(exc))
@router.post(
"/usage/mps-credits/purchase-url",
response_model=MPSCreditPurchaseUrlResponse,
)
async def create_mps_credit_purchase_url(
user: UserModel = Depends(get_user_with_selected_organization),
):
"""Create a checkout URL for organizations using Dograh-managed MPS v2."""
if DEPLOYMENT_MODE == "oss":
raise HTTPException(
status_code=404,
detail="Credit purchases are not available in OSS mode",
)
organization_id = user.selected_organization_id
assert organization_id is not None
account_status = await _get_mps_billing_account_status(user, organization_id)
if not _is_mps_billing_v2(account_status):
raise HTTPException(
status_code=403,
detail=(
"Credit purchases are available only for organizations using billing v2"
),
)
try:
session = await mps_service_key_client.create_credit_purchase_url(
organization_id=organization_id,
created_by=str(user.provider_id),
return_url=f"{UI_APP_URL.rstrip('/')}/billing",
billing_details={
"source": "dograh_billing",
"dograh_user_id": str(user.id),
"dograh_provider_id": str(user.provider_id),
},
)
except Exception as exc:
logger.error(f"Failed to create MPS credit purchase URL: {exc}")
raise HTTPException(
status_code=502,
detail="Failed to create credit purchase URL",
)
checkout_url = session.get("checkout_url")
if not checkout_url:
logger.error(f"MPS checkout session response missing checkout_url: {session}")
raise HTTPException(
status_code=502,
detail="MPS checkout session response missing checkout_url",
)
return MPSCreditPurchaseUrlResponse(checkout_url=checkout_url)
FILTERS_DESCRIPTION = """\
JSON-encoded array of filter objects. Each object has the shape:

View file

@ -14,7 +14,7 @@ from pydantic import BaseModel
from api.db import db_client
from api.enums import TriggerState, WorkflowStatus
from api.services.quota_service import check_dograh_quota_by_user_id
from api.services.quota_service import authorize_workflow_run_start
from api.services.telephony.factory import (
get_default_telephony_provider,
get_telephony_provider_by_id,
@ -179,14 +179,6 @@ async def _execute_resolved_target(
"""Shared execution path once the target workflow has been resolved."""
execution_user_id = _get_execution_user_id(target.workflow)
# Check Dograh quota using the workflow owner's config and model overrides.
quota_result = await check_dograh_quota_by_user_id(
execution_user_id,
workflow_id=target.workflow.id,
)
if not quota_result.has_quota:
raise HTTPException(status_code=402, detail=quota_result.error_message)
# Get telephony provider — either the caller-specified config (validated
# against the workflow's org) or the org's default config.
if request.telephony_configuration_id is not None:
@ -268,6 +260,15 @@ async def _execute_resolved_target(
f"to phone number {request.phone_number}"
)
# Check Dograh quota after the run exists so hosted v2 can mint and store
# the MPS correlation id before the provider starts the call.
quota_result = await authorize_workflow_run_start(
workflow_id=target.workflow.id,
workflow_run_id=workflow_run.id,
)
if not quota_result.has_quota:
raise HTTPException(status_code=402, detail=quota_result.error_message)
# 9. Construct webhook URL for telephony provider callback
backend_endpoint, _ = await get_backend_endpoints()
webhook_endpoint = provider.WEBHOOK_ENDPOINT

View file

@ -7,6 +7,7 @@ They handle CORS, domain validation, and session management for embedded workflo
import secrets
from datetime import UTC, datetime, timedelta
from typing import Optional
from urllib.parse import urlsplit
from fastapi import (
APIRouter,
@ -16,6 +17,8 @@ from fastapi import (
)
from loguru import logger
from pydantic import BaseModel
from starlette.datastructures import Headers
from starlette.types import ASGIApp, Receive, Scope, Send
from api.db import db_client
from api.enums import WorkflowRunMode
@ -27,6 +30,9 @@ from api.routes.turn_credentials import (
router = APIRouter(prefix="/public/embed")
EMBED_CORS_ALLOW_HEADERS = "Content-Type, Origin"
EMBED_CORS_MAX_AGE = "86400"
class InitEmbedRequest(BaseModel):
"""Request model for initializing an embed session"""
@ -70,11 +76,9 @@ def validate_origin(origin: str, allowed_domains: list) -> bool:
# If no domains specified, allow all origins
return True
# Extract domain from origin (remove protocol)
if "://" in origin:
domain = origin.split("://")[1].split("/")[0].split(":")[0]
else:
domain = origin
domain, origin_port = _parse_origin_host_port(origin)
if not domain:
return False
# Normalize domain for www matching
def normalize_www(d: str) -> tuple[str, str]:
@ -87,16 +91,23 @@ def validate_origin(origin: str, allowed_domains: list) -> bool:
domain_variants = normalize_www(domain)
for allowed in allowed_domains:
allowed = str(allowed).strip().lower()
if allowed == "*":
return True
elif allowed.startswith("*."):
allowed_domain, allowed_port = _parse_origin_host_port(allowed)
if not allowed_domain:
continue
if allowed_port is not None and allowed_port != origin_port:
continue
if allowed_domain.startswith("*."):
# Wildcard subdomain matching
base_domain = allowed[2:]
base_domain = allowed_domain[2:]
if domain == base_domain or domain.endswith("." + base_domain):
return True
else:
# Check both www and non-www versions
allowed_variants = normalize_www(allowed)
allowed_variants = normalize_www(allowed_domain)
# If any variant of domain matches any variant of allowed, it's valid
if any(
dv in allowed_variants or av in domain_variants
@ -108,6 +119,24 @@ def validate_origin(origin: str, allowed_domains: list) -> bool:
return False
def _parse_origin_host_port(value: str) -> tuple[str, str | None]:
candidate = value.strip().lower()
if not candidate:
return "", None
if "://" not in candidate and not candidate.startswith("//"):
candidate = f"//{candidate}"
parsed = urlsplit(candidate)
try:
parsed_port = parsed.port
except ValueError:
parsed_port = None
port = str(parsed_port) if parsed_port is not None else None
return (parsed.hostname or "").rstrip("."), port
def generate_session_token() -> str:
"""Generate a cryptographically secure session token"""
return f"emb_session_{secrets.token_urlsafe(32)}"
@ -121,8 +150,120 @@ def get_request_origin(request: Request) -> str:
return origin
def _cors_response(origin: str, methods: str) -> Response:
return Response(
headers={
"Access-Control-Allow-Origin": origin,
"Access-Control-Allow-Methods": methods,
"Access-Control-Allow-Headers": EMBED_CORS_ALLOW_HEADERS,
"Access-Control-Max-Age": EMBED_CORS_MAX_AGE,
"Vary": "Origin",
}
)
def _allow_embed_origin(response: Response, origin: str) -> None:
response.headers["Access-Control-Allow-Origin"] = origin
vary = response.headers.get("Vary")
if not vary:
response.headers["Vary"] = "Origin"
return
vary_values = {value.strip().lower() for value in vary.split(",")}
if "origin" not in vary_values:
response.headers["Vary"] = f"{vary}, Origin"
async def _config_preflight_response(token: str, origin: str) -> Response:
embed_token = await db_client.get_embed_token_by_token(token)
if not embed_token or not embed_token.is_active:
return Response(status_code=403)
if not validate_origin(origin, embed_token.allowed_domains or []):
return Response(status_code=403)
return _cors_response(origin, "GET, OPTIONS")
async def _turn_credentials_preflight_response(
session_token: str, origin: str
) -> Response:
embed_session = await db_client.get_embed_session_by_token(session_token)
if not embed_session:
return Response(status_code=403)
if embed_session.expires_at and embed_session.expires_at < datetime.now(UTC):
return Response(status_code=403)
embed_token = await db_client.get_embed_token_by_id(embed_session.embed_token_id)
if not embed_token:
return Response(status_code=403)
if not validate_origin(origin, embed_token.allowed_domains or []):
return Response(status_code=403)
return _cors_response(origin, "GET, OPTIONS")
async def build_public_embed_preflight_response(
path: str, origin: str, requested_method: str, api_prefix: str = "/api/v1"
) -> Response | None:
"""Handle embed preflights before global CORSMiddleware rejects external sites."""
public_embed_prefix = f"{api_prefix.rstrip('/')}/public/embed"
if path == f"{public_embed_prefix}/init":
if requested_method.upper() != "POST":
return Response(status_code=405)
return _cors_response(origin, "POST, OPTIONS")
config_prefix = f"{public_embed_prefix}/config/"
if path.startswith(config_prefix):
if requested_method.upper() != "GET":
return Response(status_code=405)
token = path[len(config_prefix) :].split("/", 1)[0]
return await _config_preflight_response(token, origin)
turn_credentials_prefix = f"{public_embed_prefix}/turn-credentials/"
if path.startswith(turn_credentials_prefix):
if requested_method.upper() != "GET":
return Response(status_code=405)
session_token = path[len(turn_credentials_prefix) :].split("/", 1)[0]
return await _turn_credentials_preflight_response(session_token, origin)
return None
class PublicEmbedCORSMiddleware:
"""Allow token-gated embed CORS before global SaaS CORS rejects preflights."""
def __init__(self, app: ASGIApp, api_prefix: str = "/api/v1"):
self.app = app
self.api_prefix = api_prefix
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http" or scope.get("method") != "OPTIONS":
await self.app(scope, receive, send)
return
headers = Headers(scope=scope)
origin = headers.get("origin")
requested_method = headers.get("access-control-request-method")
if origin and requested_method:
response = await build_public_embed_preflight_response(
scope.get("path", ""), origin, requested_method, self.api_prefix
)
if response is not None:
await response(scope, receive, send)
return
await self.app(scope, receive, send)
@router.post("/init", response_model=InitEmbedResponse)
async def initialize_embed_session(request: Request, init_request: InitEmbedRequest):
async def initialize_embed_session(
request: Request, init_request: InitEmbedRequest, response: Response
):
"""Initialize an embed session with token validation and domain checking.
This endpoint:
@ -158,6 +299,9 @@ async def initialize_embed_session(request: Request, init_request: InitEmbedRequ
)
raise HTTPException(status_code=403, detail=f"Domain not allowed: {origin}")
if origin:
_allow_embed_origin(response, origin)
# Create workflow run
try:
workflow_run = await db_client.create_workflow_run(
@ -204,8 +348,19 @@ async def initialize_embed_session(request: Request, init_request: InitEmbedRequ
)
@router.options("/config/{token}")
async def options_embed_config(token: str, request: Request):
"""Fallback OPTIONS handler for the embed config endpoint.
Browser preflights include Access-Control-Request-Method and are handled by
PublicEmbedCORSMiddleware before global CORS. This keeps non-conformant
OPTIONS requests on the same validation path.
"""
return await _config_preflight_response(token, request.headers.get("origin", ""))
@router.get("/config/{token}", response_model=EmbedConfigResponse)
async def get_embed_config(token: str, request: Request):
async def get_embed_config(token: str, request: Request, response: Response):
"""Get embed configuration without creating a session.
This endpoint is used to fetch widget configuration for display purposes
@ -226,6 +381,11 @@ async def get_embed_config(token: str, request: Request):
if not validate_origin(origin, embed_token.allowed_domains or []):
raise HTTPException(status_code=403, detail=f"Domain not allowed: {origin}")
# Set CORS header explicitly; the global CORSMiddleware covers only
# first-party origins; this endpoint is fetched by external embed sites.
if origin:
_allow_embed_origin(response, origin)
# Extract settings with defaults
settings = embed_token.settings or {}
@ -243,24 +403,20 @@ async def get_embed_config(token: str, request: Request):
@router.options("/init")
async def options_init(request: Request):
"""Handle CORS preflight for init endpoint"""
"""Fallback OPTIONS handler for init endpoint."""
# Browser preflights are handled by PublicEmbedCORSMiddleware before global CORS.
# For init endpoint, we need to check the token in the request body
# But OPTIONS requests don't have body, so we'll be permissive
# The actual validation happens in the POST request
origin = request.headers.get("origin", "*")
return Response(
headers={
"Access-Control-Allow-Origin": origin,
"Access-Control-Allow-Methods": "POST, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type, Origin",
"Access-Control-Max-Age": "86400",
}
)
return _cors_response(origin, "POST, OPTIONS")
@router.get("/turn-credentials/{session_token}", response_model=TurnCredentialsResponse)
async def get_public_turn_credentials(session_token: str, request: Request):
async def get_public_turn_credentials(
session_token: str, request: Request, response: Response
):
"""Get TURN credentials for an embed session.
This endpoint allows embedded widgets to obtain TURN server credentials
@ -295,6 +451,9 @@ async def get_public_turn_credentials(session_token: str, request: Request):
)
raise HTTPException(status_code=403, detail=f"Domain not allowed: {origin}")
if origin:
_allow_embed_origin(response, origin)
# Check if TURN is configured
if not TURN_SECRET:
raise HTTPException(
@ -316,63 +475,8 @@ async def get_public_turn_credentials(session_token: str, request: Request):
@router.options("/turn-credentials/{session_token}")
async def options_turn_credentials(request: Request, session_token: str):
"""Handle CORS preflight for TURN credentials endpoint"""
origin = request.headers.get("origin", "*")
# Try to validate the session token and get allowed domains
allowed_origin = origin
try:
embed_session = await db_client.get_embed_session_by_token(session_token)
if embed_session:
embed_token = await db_client.get_embed_token_by_id(
embed_session.embed_token_id
)
if embed_token:
# Check if origin is in allowed domains (empty means allow all)
if validate_origin(origin, embed_token.allowed_domains or []):
allowed_origin = origin
else:
allowed_origin = ""
except Exception:
# On error, be permissive for OPTIONS
pass
return Response(
headers={
"Access-Control-Allow-Origin": allowed_origin,
"Access-Control-Allow-Methods": "GET, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type",
"Access-Control-Max-Age": "86400",
}
)
@router.options("/config/{token}")
async def options_config(request: Request, token: str):
"""Handle CORS preflight for config endpoint"""
# Get origin header
origin = request.headers.get("origin", "*")
# Try to validate the token and get allowed domains
allowed_origin = origin
try:
embed_token = await db_client.get_embed_token_by_token(token)
if embed_token and embed_token.is_active:
# Check if origin is in allowed domains
if validate_origin(origin, embed_token.allowed_domains or []):
allowed_origin = origin
else:
# If not allowed, don't include the origin
allowed_origin = ""
except Exception:
# On error, be permissive for OPTIONS
pass
return Response(
headers={
"Access-Control-Allow-Origin": allowed_origin,
"Access-Control-Allow-Methods": "GET, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type",
"Access-Control-Max-Age": "86400",
}
"""Fallback OPTIONS handler for TURN credentials endpoint."""
# Browser preflights are handled by PublicEmbedCORSMiddleware before global CORS.
return await _turn_credentials_preflight_response(
session_token, request.headers.get("origin", "")
)

View file

@ -25,7 +25,7 @@ from api.enums import CallType, WorkflowRunState
from api.errors.telephony_errors import TelephonyError
from api.sdk_expose import sdk_expose
from api.services.auth.depends import get_user
from api.services.quota_service import check_dograh_quota_by_user_id
from api.services.quota_service import authorize_workflow_run_start
from api.services.telephony.call_transfer_manager import get_call_transfer_manager
from api.services.telephony.factory import (
get_all_telephony_providers,
@ -53,7 +53,7 @@ class InitiateCallRequest(BaseModel):
workflow_run_id: int | None = None
phone_number: str | None = None
# Optional explicit telephony config to use for the test call. If omitted,
# falls back to the user's per-user default (when set), then the org default.
# falls back to the org default.
telephony_configuration_id: int | None = None
# Optional caller-ID phone number to dial out from. Must belong to the
# resolved telephony configuration; otherwise the provider picks one.
@ -82,7 +82,12 @@ async def initiate_call(
"""Initiate a call using the configured telephony provider from web browser. This is
supposed to be a test call method for the draft version of the agent."""
user_configuration = await db_client.get_user_configurations(user.id)
from api.services.organization_preferences import get_organization_preferences
preferences = await get_organization_preferences(
user.selected_organization_id,
db=db_client,
)
# Resolve which telephony config to use: explicit request value, otherwise
# the org's default outbound config.
@ -116,13 +121,12 @@ async def initiate_call(
detail="telephony_not_configured",
)
phone_number = request.phone_number or user_configuration.test_phone_number
phone_number = request.phone_number or preferences.test_phone_number
if not phone_number:
raise HTTPException(
status_code=400,
detail="Phone number must be provided in request or set in user "
"configuration",
detail="Phone number must be provided in request or set in organization preferences",
)
workflow = await db_client.get_workflow(
@ -132,14 +136,6 @@ async def initiate_call(
raise HTTPException(status_code=404, detail="Workflow not found")
execution_user_id = _get_execution_user_id(workflow)
# Check Dograh quota before initiating the call (apply per-workflow
# model_overrides so the keys we will actually use are the ones checked).
quota_result = await check_dograh_quota_by_user_id(
execution_user_id, workflow_id=workflow.id
)
if not quota_result.has_quota:
raise HTTPException(status_code=402, detail=quota_result.error_message)
# Determine the workflow run mode based on provider type
workflow_run_mode = provider.PROVIDER_NAME
@ -182,6 +178,16 @@ async def initiate_call(
)
workflow_run_name = workflow_run.name
# Check Dograh quota after the run exists so hosted v2 can mint and store
# the MPS correlation id before initiating the call.
quota_result = await authorize_workflow_run_start(
workflow_id=workflow.id,
workflow_run_id=workflow_run_id,
actor_user=user,
)
if not quota_result.has_quota:
raise HTTPException(status_code=402, detail=quota_result.error_message)
# Construct webhook URL based on provider type
backend_endpoint, _ = await get_backend_endpoints()
@ -735,19 +741,8 @@ async def handle_inbound_run(request: Request):
TelephonyError.SIGNATURE_VALIDATION_FAILED
)
# 4. Quota check (use the workflow's model_overrides if set).
quota_result = await check_dograh_quota_by_user_id(
user_id, workflow_id=workflow_id
)
if not quota_result.has_quota:
logger.warning(
f"User {user_id} has exceeded quota: {quota_result.error_message}"
)
return provider_class.generate_validation_error_response(
TelephonyError.QUOTA_EXCEEDED
)
# 5. Create workflow run + return provider-shaped response.
# 5. Create workflow run + authorize quota before returning provider
# stream instructions.
workflow_run_id = await _create_inbound_workflow_run(
workflow_id,
user_id,
@ -756,6 +751,17 @@ async def handle_inbound_run(request: Request):
telephony_configuration_id=telephony_configuration_id,
from_phone_number_id=phone_row.id,
)
quota_result = await authorize_workflow_run_start(
workflow_id=workflow_id,
workflow_run_id=workflow_run_id,
)
if not quota_result.has_quota:
logger.warning(
f"User {user_id} has exceeded quota: {quota_result.error_message}"
)
return provider_class.generate_validation_error_response(
TelephonyError.QUOTA_EXCEEDED
)
backend_endpoint, wss_backend_endpoint = await get_backend_endpoints()
websocket_url = (
@ -870,20 +876,8 @@ async def handle_inbound_telephony(
logger.error(f"Request validation failed: {error_type}")
return provider_class.generate_validation_error_response(error_type)
# Check quota before processing (apply per-workflow model_overrides).
# Create workflow run.
user_id = workflow_context["user_id"]
quota_result = await check_dograh_quota_by_user_id(
user_id, workflow_id=workflow_id
)
if not quota_result.has_quota:
logger.warning(
f"User {user_id} has exceeded quota for inbound calls: {quota_result.error_message}"
)
return provider_class.generate_validation_error_response(
TelephonyError.QUOTA_EXCEEDED
)
# Create workflow run
workflow_run_id = await _create_inbound_workflow_run(
workflow_id,
workflow_context["user_id"],
@ -892,6 +886,17 @@ async def handle_inbound_telephony(
telephony_configuration_id=workflow_context["telephony_configuration_id"],
from_phone_number_id=workflow_context.get("from_phone_number_id"),
)
quota_result = await authorize_workflow_run_start(
workflow_id=workflow_id,
workflow_run_id=workflow_run_id,
)
if not quota_result.has_quota:
logger.warning(
f"User {user_id} has exceeded quota for inbound calls: {quota_result.error_message}"
)
return provider_class.generate_validation_error_response(
TelephonyError.QUOTA_EXCEEDED
)
# Generate response URLs
backend_endpoint, wss_backend_endpoint = await get_backend_endpoints()

View file

@ -10,6 +10,9 @@ from api.db.models import (
UserModel,
)
from api.services.auth.depends import get_user
from api.services.configuration.ai_model_configuration import (
get_resolved_ai_model_configuration,
)
from api.services.configuration.check_validity import (
APIKeyStatusResponse,
UserConfigurationValidator,
@ -19,6 +22,10 @@ from api.services.configuration.masking import check_for_masked_keys, mask_user_
from api.services.configuration.merge import merge_user_configurations
from api.services.configuration.registry import REGISTRY, ServiceType
from api.services.mps_service_key_client import mps_service_key_client
from api.services.organization_preferences import (
get_organization_preferences,
upsert_organization_preferences,
)
router = APIRouter(prefix="/user")
@ -94,8 +101,17 @@ class UserConfigurationRequestResponseSchema(BaseModel):
async def get_user_configurations(
user: UserModel = Depends(get_user),
) -> UserConfigurationRequestResponseSchema:
user_configurations = await db_client.get_user_configurations(user.id)
masked_config = mask_user_config(user_configurations)
resolved_config = await get_resolved_ai_model_configuration(
user_id=user.id,
organization_id=user.selected_organization_id,
)
masked_config = mask_user_config(resolved_config.effective)
if user.selected_organization_id:
preferences = await get_organization_preferences(user.selected_organization_id)
if preferences.test_phone_number is not None:
masked_config["test_phone_number"] = preferences.test_phone_number
if preferences.timezone is not None:
masked_config["timezone"] = preferences.timezone
# Add organization pricing info if available
if user.selected_organization_id:
@ -121,34 +137,61 @@ async def update_user_configurations(
# Remove organization_pricing from incoming dict as it's read-only
incoming_dict.pop("organization_pricing", None)
preferences_update = {
key: incoming_dict.pop(key)
for key in ("test_phone_number", "timezone")
if key in incoming_dict
}
# Merge via helper
try:
user_configurations = merge_user_configurations(existing_config, incoming_dict)
except ValidationError as e:
raise HTTPException(status_code=422, detail=str(e))
if incoming_dict:
# Merge via helper
try:
user_configurations = merge_user_configurations(
existing_config, incoming_dict
)
except ValidationError as e:
raise HTTPException(status_code=422, detail=str(e))
try:
check_for_masked_keys(user_configurations)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
try:
check_for_masked_keys(user_configurations)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
try:
validator = UserConfigurationValidator()
await validator.validate(
user_configurations,
organization_id=user.selected_organization_id,
created_by=user.provider_id,
try:
validator = UserConfigurationValidator()
await validator.validate(
user_configurations,
organization_id=user.selected_organization_id,
created_by=user.provider_id,
)
except ValueError as e:
raise HTTPException(status_code=422, detail=e.args[0])
user_configurations = await db_client.update_user_configuration(
user.id, user_configurations
)
except ValueError as e:
raise HTTPException(status_code=422, detail=e.args[0])
else:
user_configurations = existing_config
user_configurations = await db_client.update_user_configuration(
user.id, user_configurations
)
if user.selected_organization_id and preferences_update:
preferences = await get_organization_preferences(user.selected_organization_id)
if "test_phone_number" in preferences_update:
preferences.test_phone_number = preferences_update["test_phone_number"]
if "timezone" in preferences_update:
preferences.timezone = preferences_update["timezone"]
await upsert_organization_preferences(
user.selected_organization_id,
preferences,
)
# Return masked version of updated config
masked_config = mask_user_config(user_configurations)
if user.selected_organization_id:
preferences = await get_organization_preferences(user.selected_organization_id)
if preferences.test_phone_number is not None:
masked_config["test_phone_number"] = preferences.test_phone_number
if preferences.timezone is not None:
masked_config["timezone"] = preferences.timezone
# Add organization pricing info if available
if user.selected_organization_id:
@ -168,7 +211,11 @@ async def validate_user_configurations(
validity_ttl_seconds: int = Query(default=60, ge=0, le=86400),
user: UserModel = Depends(get_user),
) -> APIKeyStatusResponse:
configurations = await db_client.get_user_configurations(user.id)
resolved_config = await get_resolved_ai_model_configuration(
user_id=user.id,
organization_id=user.selected_organization_id,
)
configurations = resolved_config.effective
if (
configurations.last_validated_at

View file

@ -45,7 +45,7 @@ from api.services.pipecat.ws_sender_registry import (
register_ws_sender,
unregister_ws_sender,
)
from api.services.quota_service import check_dograh_quota
from api.services.quota_service import authorize_workflow_run_start
router = APIRouter(prefix="/ws")
@ -329,7 +329,11 @@ class SignalingManager:
# Check Dograh quota before initiating the call (apply per-workflow
# model_overrides so we evaluate the keys this workflow will use).
quota_result = await check_dograh_quota(user, workflow_id=workflow_id)
quota_result = await authorize_workflow_run_start(
workflow_id=workflow_id,
workflow_run_id=workflow_run_id,
actor_user=user,
)
if not quota_result.has_quota:
# Send error response for quota issues
await ws.send_json(

View file

@ -16,9 +16,18 @@ from api.db.agent_trigger_client import TriggerPathConflictError
from api.db.models import UserModel
from api.db.workflow_template_client import WorkflowTemplateClient
from api.enums import CallType, PostHogEvent, StorageBackend
from api.schemas.ai_model_configuration import OrganizationAIModelConfigurationV2
from api.schemas.workflow import WorkflowRunResponseSchema
from api.sdk_expose import sdk_expose
from api.services.auth.depends import get_user
from api.services.configuration.ai_model_configuration import (
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY,
check_for_masked_keys_in_ai_model_configuration_v2,
compile_ai_model_configuration_v2,
convert_legacy_ai_model_configuration_to_v2,
get_resolved_ai_model_configuration,
merge_ai_model_configuration_v2_secrets,
)
from api.services.configuration.check_validity import UserConfigurationValidator
from api.services.configuration.masking import (
mask_workflow_configurations,
@ -32,12 +41,15 @@ from api.services.configuration.resolve import (
)
from api.services.mps_service_key_client import mps_service_key_client
from api.services.posthog_client import capture_event
from api.services.pricing.run_usage_response import format_public_usage_info
from api.services.reports import generate_workflow_report_csv
from api.services.storage import storage_fs
from api.services.workflow.dto import ReactFlowDTO, sanitize_workflow_definition
from api.services.workflow.duplicate import duplicate_workflow
from api.services.workflow.errors import ItemKind, WorkflowError
from api.services.workflow.run_usage_response import (
format_public_cost_info,
format_public_usage_info,
)
from api.services.workflow.trigger_paths import (
TriggerPathIssue,
ensure_trigger_paths,
@ -955,12 +967,74 @@ async def update_workflow(
existing_def,
)
# Validate model_overrides: resolve onto global config, then
# run the same validator used by the user-configurations endpoint.
# Also stamp the current global API key into the override so the override
# remains functional if the global config later switches to a different provider.
# Validate model overrides. v2 uses a complete workflow-level model
# configuration; legacy v1 uses partial service overlays.
workflow_configurations = request.workflow_configurations
if workflow_configurations and workflow_configurations.get("model_overrides"):
if workflow_configurations and workflow_configurations.get(
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY
):
existing_workflow = await db_client.get_workflow(
workflow_id, organization_id=user.selected_organization_id
)
if existing_workflow is None:
raise HTTPException(
status_code=404, detail=f"Workflow with id {workflow_id} not found"
)
existing_draft = await db_client.get_draft_version(workflow_id)
existing_configs = (
existing_draft.workflow_configurations
if existing_draft
else existing_workflow.released_definition.workflow_configurations
)
existing_v2_override = (existing_configs or {}).get(
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY
)
try:
incoming_v2_override = (
OrganizationAIModelConfigurationV2.model_validate(
workflow_configurations[
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY
]
)
)
existing_v2_override_config = (
OrganizationAIModelConfigurationV2.model_validate(
existing_v2_override
)
if existing_v2_override
else None
)
v2_override = merge_ai_model_configuration_v2_secrets(
incoming_v2_override,
existing_v2_override_config,
)
if existing_v2_override_config is None:
resolved_config = await get_resolved_ai_model_configuration(
user_id=user.id,
organization_id=user.selected_organization_id,
)
v2_override = merge_ai_model_configuration_v2_secrets(
v2_override,
resolved_config.organization_configuration,
)
check_for_masked_keys_in_ai_model_configuration_v2(v2_override)
effective = compile_ai_model_configuration_v2(v2_override)
await UserConfigurationValidator().validate(
effective,
organization_id=user.selected_organization_id,
created_by=user.provider_id,
)
except (ValidationError, ValueError) as e:
raise HTTPException(status_code=422, detail=str(e))
workflow_configurations = {
**workflow_configurations,
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY: v2_override.model_dump(
mode="json",
exclude_none=True,
),
}
workflow_configurations.pop("model_overrides", None)
elif workflow_configurations and workflow_configurations.get("model_overrides"):
existing_workflow = await db_client.get_workflow(
workflow_id, organization_id=user.selected_organization_id
)
@ -978,24 +1052,48 @@ async def update_workflow(
workflow_configurations,
existing_configs,
)
user_config = await db_client.get_user_configurations(user.id)
resolved_config = await get_resolved_ai_model_configuration(
user_id=user.id,
organization_id=user.selected_organization_id,
)
effective_config = resolved_config.effective
try:
enriched_overrides = enrich_overrides_with_api_keys(
workflow_configurations["model_overrides"],
user_config,
effective_config,
)
effective = resolve_effective_config(user_config, enriched_overrides)
await UserConfigurationValidator().validate(
effective,
organization_id=user.selected_organization_id,
created_by=user.provider_id,
effective = resolve_effective_config(
effective_config, enriched_overrides
)
if resolved_config.source == "organization_v2":
v2_override = convert_legacy_ai_model_configuration_to_v2(effective)
await UserConfigurationValidator().validate(
compile_ai_model_configuration_v2(v2_override),
organization_id=user.selected_organization_id,
created_by=user.provider_id,
)
else:
await UserConfigurationValidator().validate(
effective,
organization_id=user.selected_organization_id,
created_by=user.provider_id,
)
except ValueError as e:
raise HTTPException(status_code=422, detail=str(e))
workflow_configurations = {
**workflow_configurations,
"model_overrides": enriched_overrides,
}
if resolved_config.source == "organization_v2":
workflow_configurations = {
**workflow_configurations,
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY: v2_override.model_dump(
mode="json",
exclude_none=True,
),
}
workflow_configurations.pop("model_overrides", None)
else:
workflow_configurations = {
**workflow_configurations,
"model_overrides": enriched_overrides,
}
# Reject upfront if any new trigger path collides with another
# workflow's trigger — keeps the workflow record from
@ -1171,22 +1269,7 @@ async def get_workflow_run(
"transcript_public_url": artifact_url(public_access_token, "transcript"),
"recording_public_url": artifact_url(public_access_token, "recording"),
"public_access_token": public_access_token,
"cost_info": {
"dograh_token_usage": (
run.cost_info.get("dograh_token_usage")
if run.cost_info and "dograh_token_usage" in run.cost_info
else round(float(run.cost_info.get("total_cost_usd", 0)) * 100, 2)
if run.cost_info and "total_cost_usd" in run.cost_info
else 0
),
"call_duration_seconds": int(
round(run.cost_info.get("call_duration_seconds"))
)
if run.cost_info and run.cost_info.get("call_duration_seconds") is not None
else None,
}
if run.cost_info
else None,
"cost_info": format_public_cost_info(run.cost_info, run.usage_info),
"usage_info": format_public_usage_info(run.usage_info),
"created_at": run.created_at,
"definition_id": run.definition_id,

View file

@ -9,8 +9,8 @@ from pydantic import BaseModel, Field
from api.db import db_client
from api.db.models import UserModel, WorkflowRunTextSessionModel
from api.enums import WorkflowRunMode
from api.services.auth.depends import get_user
from api.services.quota_service import check_dograh_quota
from api.services.auth.depends import get_user_with_selected_organization
from api.services.quota_service import authorize_workflow_run_start
from api.services.workflow.text_chat_session_service import (
TextChatPendingTurnLostError,
TextChatSessionExecutionError,
@ -96,14 +96,16 @@ def _revision_conflict_detail(e: Any) -> dict[str, Any]:
}
def _require_selected_organization_id(user: UserModel) -> int:
if user.selected_organization_id is None:
raise HTTPException(status_code=403, detail="Organization context is required")
return user.selected_organization_id
async def _ensure_text_chat_quota(user: UserModel, workflow_id: int) -> None:
quota_result = await check_dograh_quota(user, workflow_id=workflow_id)
async def _ensure_text_chat_quota(
user: UserModel,
workflow_id: int,
workflow_run_id: int,
) -> None:
quota_result = await authorize_workflow_run_start(
workflow_id=workflow_id,
workflow_run_id=workflow_run_id,
actor_user=user,
)
if not quota_result.has_quota:
raise HTTPException(status_code=402, detail=quota_result.error_message)
@ -114,9 +116,8 @@ async def _load_text_session_or_404(
user: UserModel,
) -> WorkflowRunTextSessionModel:
set_current_run_id(run_id)
organization_id = _require_selected_organization_id(user)
text_session = await db_client.get_workflow_run_text_session(
run_id, organization_id=organization_id
run_id, organization_id=user.selected_organization_id
)
if not text_session or not text_session.workflow_run:
raise HTTPException(status_code=404, detail="Text chat session not found")
@ -158,11 +159,8 @@ async def _execute_pending_turn_response(
async def create_text_chat_session(
workflow_id: int,
request: CreateTextChatSessionRequest,
user: UserModel = Depends(get_user),
user: UserModel = Depends(get_user_with_selected_organization),
) -> WorkflowRunTextSessionResponse:
organization_id = _require_selected_organization_id(user)
await _ensure_text_chat_quota(user, workflow_id)
session_name = request.name or f"WR-TEXT-{uuid4().hex[:6].upper()}"
try:
workflow_run = await db_client.create_workflow_run(
@ -172,12 +170,13 @@ async def create_text_chat_session(
user_id=user.id,
initial_context=request.initial_context,
use_draft=True,
organization_id=organization_id,
organization_id=user.selected_organization_id,
)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
set_current_run_id(workflow_run.id)
await _ensure_text_chat_quota(user, workflow_id, workflow_run.id)
annotations = {
"tester": {
@ -220,7 +219,7 @@ async def create_text_chat_session(
async def get_text_chat_session(
workflow_id: int,
run_id: int,
user: UserModel = Depends(get_user),
user: UserModel = Depends(get_user_with_selected_organization),
) -> WorkflowRunTextSessionResponse:
text_session = await _load_text_session_or_404(workflow_id, run_id, user)
return _build_response(text_session)
@ -234,10 +233,10 @@ async def append_text_chat_message(
workflow_id: int,
run_id: int,
request: AppendTextChatMessageRequest,
user: UserModel = Depends(get_user),
user: UserModel = Depends(get_user_with_selected_organization),
) -> WorkflowRunTextSessionResponse:
text_session = await _load_text_session_or_404(workflow_id, run_id, user)
await _ensure_text_chat_quota(user, workflow_id)
await _ensure_text_chat_quota(user, workflow_id, run_id)
try:
text_session = await append_text_chat_user_message(
@ -264,7 +263,7 @@ async def rewind_text_chat_session(
workflow_id: int,
run_id: int,
request: RewindTextChatSessionRequest,
user: UserModel = Depends(get_user),
user: UserModel = Depends(get_user_with_selected_organization),
) -> WorkflowRunTextSessionResponse:
text_session = await _load_text_session_or_404(workflow_id, run_id, user)
try: