mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-19 08:28:10 +02:00
Merge remote-tracking branch 'origin/main' into feat/user-onboarding
This commit is contained in:
commit
093e888ce4
148 changed files with 10908 additions and 2815 deletions
|
|
@ -22,7 +22,7 @@ from starlette.websockets import WebSocketDisconnect
|
|||
|
||||
from api.db import db_client
|
||||
from api.enums import CallType, WorkflowRunState
|
||||
from api.services.quota_service import check_dograh_quota_by_user_id
|
||||
from api.services.quota_service import authorize_workflow_run_start
|
||||
from api.services.telephony import registry as telephony_registry
|
||||
|
||||
router = APIRouter(prefix="/agent-stream")
|
||||
|
|
@ -67,19 +67,6 @@ async def agent_stream_websocket(
|
|||
await websocket.close(code=1008, reason="Workflow not found")
|
||||
return
|
||||
|
||||
quota_result = await check_dograh_quota_by_user_id(
|
||||
workflow.user_id, workflow_id=workflow.id
|
||||
)
|
||||
if not quota_result.has_quota:
|
||||
logger.warning(
|
||||
f"agent-stream quota exceeded for user {workflow.user_id}: "
|
||||
f"{quota_result.error_message}"
|
||||
)
|
||||
await websocket.close(
|
||||
code=1008, reason=quota_result.error_message or "Quota exceeded"
|
||||
)
|
||||
return
|
||||
|
||||
numeric_suffix = int(str(uuid.uuid4()).replace("-", "")[:8], 16) % 100000000
|
||||
workflow_run_name = f"WR-AGS-{numeric_suffix:08d}"
|
||||
call_id = params.get("callId") or params.get("CallSid")
|
||||
|
|
@ -108,6 +95,20 @@ async def agent_stream_websocket(
|
|||
set_current_run_id(workflow_run.id)
|
||||
set_current_org_id(workflow.organization_id)
|
||||
|
||||
quota_result = await authorize_workflow_run_start(
|
||||
workflow_id=workflow.id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
)
|
||||
if not quota_result.has_quota:
|
||||
logger.warning(
|
||||
f"agent-stream quota exceeded for user {workflow.user_id}: "
|
||||
f"{quota_result.error_message}"
|
||||
)
|
||||
await websocket.close(
|
||||
code=1008, reason=quota_result.error_message or "Quota exceeded"
|
||||
)
|
||||
return
|
||||
|
||||
await db_client.update_workflow_run(
|
||||
run_id=workflow_run.id, state=WorkflowRunState.RUNNING.value
|
||||
)
|
||||
|
|
|
|||
|
|
@ -3,9 +3,12 @@ from loguru import logger
|
|||
|
||||
from api.db import db_client
|
||||
from api.db.models import UserModel
|
||||
from api.enums import PostHogEvent
|
||||
from api.enums import OrganizationConfigurationKey, PostHogEvent
|
||||
from api.schemas.auth import AuthResponse, LoginRequest, SignupRequest, UserResponse
|
||||
from api.services.auth.depends import create_user_configuration_with_mps_key, get_user
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
convert_legacy_ai_model_configuration_to_v2,
|
||||
)
|
||||
from api.services.posthog_client import capture_event
|
||||
from api.utils.auth import create_jwt_token, hash_password, verify_password
|
||||
|
||||
|
|
@ -47,6 +50,12 @@ async def signup(request: SignupRequest):
|
|||
)
|
||||
if mps_config:
|
||||
await db_client.update_user_configuration(user.id, mps_config)
|
||||
model_config_v2 = convert_legacy_ai_model_configuration_to_v2(mps_config)
|
||||
await db_client.upsert_configuration(
|
||||
organization.id,
|
||||
OrganizationConfigurationKey.MODEL_CONFIGURATION_V2.value,
|
||||
model_config_v2.model_dump(mode="json", exclude_none=True),
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to create default configuration for OSS user", exc_info=True
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ from api.services.auth.depends import get_user
|
|||
from api.services.campaign.runner import campaign_runner_service
|
||||
from api.services.campaign.source_sync import CampaignSourceSyncService
|
||||
from api.services.campaign.source_sync_factory import get_sync_service
|
||||
from api.services.quota_service import check_dograh_quota
|
||||
from api.services.quota_service import authorize_workflow_run_start
|
||||
from api.services.reports import generate_campaign_report_csv
|
||||
from api.services.storage import storage_fs
|
||||
|
||||
|
|
@ -550,7 +550,10 @@ async def start_campaign(
|
|||
|
||||
# Check Dograh quota before starting campaign (apply per-workflow
|
||||
# model_overrides so we evaluate the keys this campaign will use).
|
||||
quota_result = await check_dograh_quota(user, workflow_id=campaign.workflow_id)
|
||||
quota_result = await authorize_workflow_run_start(
|
||||
workflow_id=campaign.workflow_id,
|
||||
actor_user=user,
|
||||
)
|
||||
if not quota_result.has_quota:
|
||||
raise HTTPException(status_code=402, detail=quota_result.error_message)
|
||||
|
||||
|
|
@ -872,7 +875,10 @@ async def resume_campaign(
|
|||
|
||||
# Check Dograh quota before resuming campaign (apply per-workflow
|
||||
# model_overrides so we evaluate the keys this campaign will use).
|
||||
quota_result = await check_dograh_quota(user, workflow_id=campaign.workflow_id)
|
||||
quota_result = await authorize_workflow_run_start(
|
||||
workflow_id=campaign.workflow_id,
|
||||
actor_user=user,
|
||||
)
|
||||
if not quota_result.has_quota:
|
||||
raise HTTPException(status_code=402, detail=quota_result.error_message)
|
||||
|
||||
|
|
|
|||
|
|
@ -369,6 +369,10 @@ async def search_chunks(
|
|||
|
||||
try:
|
||||
# Import here to avoid circular dependency
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
apply_managed_embeddings_base_url,
|
||||
get_resolved_ai_model_configuration,
|
||||
)
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from api.services.gen_ai import (
|
||||
AzureOpenAIEmbeddingService,
|
||||
|
|
@ -376,20 +380,29 @@ async def search_chunks(
|
|||
)
|
||||
|
||||
# Try to get user's embeddings configuration
|
||||
user_config = await db_client.get_user_configurations(user.id)
|
||||
resolved_config = await get_resolved_ai_model_configuration(
|
||||
user_id=user.id,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
effective_config = resolved_config.effective
|
||||
embeddings_api_key = None
|
||||
embeddings_model = None
|
||||
embeddings_provider = None
|
||||
embeddings_base_url = None
|
||||
embeddings_endpoint = None
|
||||
embeddings_api_version = None
|
||||
|
||||
if user_config.embeddings:
|
||||
embeddings_api_key = user_config.embeddings.api_key
|
||||
embeddings_model = user_config.embeddings.model
|
||||
embeddings_provider = getattr(user_config.embeddings, "provider", None)
|
||||
embeddings_endpoint = getattr(user_config.embeddings, "endpoint", None)
|
||||
if effective_config.embeddings:
|
||||
embeddings_api_key = effective_config.embeddings.api_key
|
||||
embeddings_model = effective_config.embeddings.model
|
||||
embeddings_provider = getattr(effective_config.embeddings, "provider", None)
|
||||
embeddings_endpoint = getattr(effective_config.embeddings, "endpoint", None)
|
||||
embeddings_base_url = apply_managed_embeddings_base_url(
|
||||
provider=embeddings_provider,
|
||||
base_url=getattr(effective_config.embeddings, "base_url", None),
|
||||
)
|
||||
embeddings_api_version = getattr(
|
||||
user_config.embeddings, "api_version", None
|
||||
effective_config.embeddings, "api_version", None
|
||||
)
|
||||
|
||||
# Initialize embedding service based on provider
|
||||
|
|
@ -406,9 +419,7 @@ async def search_chunks(
|
|||
db_client=db_client,
|
||||
api_key=embeddings_api_key,
|
||||
model_id=embeddings_model or "text-embedding-3-small",
|
||||
base_url=getattr(user_config.embeddings, "base_url", None)
|
||||
if user_config.embeddings
|
||||
else None,
|
||||
base_url=embeddings_base_url,
|
||||
)
|
||||
|
||||
# Perform search
|
||||
|
|
|
|||
|
|
@ -1,15 +1,27 @@
|
|||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from api.constants import DEFAULT_CAMPAIGN_RETRY_CONFIG, DEFAULT_ORG_CONCURRENCY_LIMIT
|
||||
from api.constants import (
|
||||
DEFAULT_CAMPAIGN_RETRY_CONFIG,
|
||||
DEFAULT_ORG_CONCURRENCY_LIMIT,
|
||||
DEPLOYMENT_MODE,
|
||||
)
|
||||
from api.db import db_client
|
||||
from api.db.models import UserModel
|
||||
from api.db.telephony_configuration_client import TelephonyConfigurationInUseError
|
||||
from api.enums import OrganizationConfigurationKey, PostHogEvent
|
||||
from api.schemas.ai_model_configuration import (
|
||||
DOGRAH_DEFAULT_LANGUAGE,
|
||||
DOGRAH_DEFAULT_VOICE,
|
||||
DOGRAH_SPEED_OPTIONS,
|
||||
OrganizationAIModelConfigurationResponse,
|
||||
OrganizationAIModelConfigurationV2,
|
||||
)
|
||||
from api.schemas.organization_preferences import OrganizationPreferences
|
||||
from api.schemas.telephony_config import (
|
||||
TelephonyConfigRequest,
|
||||
TelephonyConfigurationCreateRequest,
|
||||
|
|
@ -26,8 +38,36 @@ from api.schemas.telephony_phone_number import (
|
|||
PhoneNumberUpdateRequest,
|
||||
ProviderSyncStatus,
|
||||
)
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.configuration.masking import is_mask_of, mask_key
|
||||
from api.services.auth.depends import get_user, get_user_with_selected_organization
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
check_for_masked_keys_in_ai_model_configuration_v2,
|
||||
compile_ai_model_configuration_v2,
|
||||
convert_legacy_ai_model_configuration_to_v2,
|
||||
get_organization_ai_model_configuration_v2,
|
||||
get_resolved_ai_model_configuration,
|
||||
mask_ai_model_configuration_v2,
|
||||
merge_ai_model_configuration_v2_secrets,
|
||||
migrate_workflow_model_configurations_to_v2,
|
||||
upsert_organization_ai_model_configuration_v2,
|
||||
)
|
||||
from api.services.configuration.check_validity import UserConfigurationValidator
|
||||
from api.services.configuration.defaults import DEFAULT_SERVICE_PROVIDERS
|
||||
from api.services.configuration.masking import is_mask_of, mask_key, mask_user_config
|
||||
from api.services.configuration.registry import (
|
||||
DOGRAH_STT_LANGUAGES,
|
||||
REGISTRY,
|
||||
ServiceProviders,
|
||||
ServiceType,
|
||||
)
|
||||
from api.services.mps_billing import ensure_hosted_mps_billing_account_v2
|
||||
from api.services.organization_context import (
|
||||
OrganizationContextResponse,
|
||||
get_organization_context,
|
||||
)
|
||||
from api.services.organization_preferences import (
|
||||
get_organization_preferences,
|
||||
upsert_organization_preferences,
|
||||
)
|
||||
from api.services.posthog_client import capture_event
|
||||
from api.services.telephony import registry as telephony_registry
|
||||
from api.services.telephony.factory import get_telephony_provider_by_id
|
||||
|
|
@ -98,6 +138,12 @@ class TelephonyConfigWarningsResponse(BaseModel):
|
|||
telnyx_missing_webhook_public_key_count: int
|
||||
|
||||
|
||||
@router.get("/context", response_model=OrganizationContextResponse)
|
||||
async def get_current_organization_context(user: UserModel = Depends(get_user)):
|
||||
"""Return organization-scoped configuration signals owned by Dograh."""
|
||||
return await get_organization_context(user)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/telephony-providers/metadata",
|
||||
response_model=TelephonyProvidersMetadataResponse,
|
||||
|
|
@ -159,6 +205,239 @@ async def get_telephony_config_warnings(user: UserModel = Depends(get_user)):
|
|||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AI model configurations v2
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _byok_provider_schemas(service_type: ServiceType) -> dict[str, dict]:
|
||||
return {
|
||||
provider: model_cls.model_json_schema()
|
||||
for provider, model_cls in REGISTRY[service_type].items()
|
||||
if provider != ServiceProviders.DOGRAH.value
|
||||
}
|
||||
|
||||
|
||||
async def _model_configuration_v2_response(
|
||||
*,
|
||||
user: UserModel,
|
||||
configuration: OrganizationAIModelConfigurationV2 | None = None,
|
||||
) -> OrganizationAIModelConfigurationResponse:
|
||||
resolved = await get_resolved_ai_model_configuration(
|
||||
user_id=user.id,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
raw_configuration = (
|
||||
configuration
|
||||
if configuration is not None
|
||||
else resolved.organization_configuration
|
||||
)
|
||||
return OrganizationAIModelConfigurationResponse(
|
||||
configuration=mask_ai_model_configuration_v2(raw_configuration),
|
||||
effective_configuration=mask_user_config(resolved.effective),
|
||||
source=resolved.source,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/model-configurations/v2/defaults")
|
||||
async def get_model_configuration_v2_defaults(
|
||||
user: UserModel = Depends(get_user_with_selected_organization),
|
||||
):
|
||||
byok_default_providers = {
|
||||
service: provider
|
||||
for service, provider in DEFAULT_SERVICE_PROVIDERS.items()
|
||||
if provider != ServiceProviders.DOGRAH.value
|
||||
}
|
||||
return {
|
||||
"dograh": {
|
||||
"voices": [DOGRAH_DEFAULT_VOICE],
|
||||
"speeds": list(DOGRAH_SPEED_OPTIONS),
|
||||
"languages": DOGRAH_STT_LANGUAGES,
|
||||
"defaults": {
|
||||
"voice": DOGRAH_DEFAULT_VOICE,
|
||||
"speed": 1.0,
|
||||
"language": DOGRAH_DEFAULT_LANGUAGE,
|
||||
},
|
||||
},
|
||||
"byok": {
|
||||
"pipeline": {
|
||||
"llm": _byok_provider_schemas(ServiceType.LLM),
|
||||
"tts": _byok_provider_schemas(ServiceType.TTS),
|
||||
"stt": _byok_provider_schemas(ServiceType.STT),
|
||||
"embeddings": _byok_provider_schemas(ServiceType.EMBEDDINGS),
|
||||
"default_providers": byok_default_providers,
|
||||
},
|
||||
"realtime": {
|
||||
"realtime": _byok_provider_schemas(ServiceType.REALTIME),
|
||||
"llm": _byok_provider_schemas(ServiceType.LLM),
|
||||
"embeddings": _byok_provider_schemas(ServiceType.EMBEDDINGS),
|
||||
"default_providers": byok_default_providers,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@router.get(
|
||||
"/model-configurations/v2",
|
||||
response_model=OrganizationAIModelConfigurationResponse,
|
||||
)
|
||||
async def get_model_configuration_v2(
|
||||
user: UserModel = Depends(get_user_with_selected_organization),
|
||||
):
|
||||
return await _model_configuration_v2_response(user=user)
|
||||
|
||||
|
||||
@router.put(
|
||||
"/model-configurations/v2",
|
||||
response_model=OrganizationAIModelConfigurationResponse,
|
||||
)
|
||||
async def save_model_configuration_v2(
|
||||
request: OrganizationAIModelConfigurationV2,
|
||||
user: UserModel = Depends(get_user_with_selected_organization),
|
||||
):
|
||||
organization_id = user.selected_organization_id
|
||||
existing = await get_organization_ai_model_configuration_v2(organization_id)
|
||||
configuration = merge_ai_model_configuration_v2_secrets(request, existing)
|
||||
try:
|
||||
check_for_masked_keys_in_ai_model_configuration_v2(configuration)
|
||||
effective = compile_ai_model_configuration_v2(configuration)
|
||||
await UserConfigurationValidator().validate(
|
||||
effective,
|
||||
organization_id=organization_id,
|
||||
created_by=user.provider_id,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=422, detail=exc.args[0])
|
||||
|
||||
await upsert_organization_ai_model_configuration_v2(
|
||||
organization_id,
|
||||
configuration,
|
||||
)
|
||||
return await _model_configuration_v2_response(
|
||||
user=user,
|
||||
configuration=configuration,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/model-configurations/v2/migration-preview")
|
||||
async def preview_model_configuration_v2_migration(
|
||||
user: UserModel = Depends(get_user_with_selected_organization),
|
||||
):
|
||||
legacy = await db_client.get_user_configurations(user.id)
|
||||
try:
|
||||
configuration = convert_legacy_ai_model_configuration_to_v2(legacy)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=422, detail=str(exc))
|
||||
return {
|
||||
"configuration": mask_ai_model_configuration_v2(configuration),
|
||||
"effective_configuration": mask_user_config(
|
||||
compile_ai_model_configuration_v2(configuration)
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@router.post(
|
||||
"/model-configurations/v2/migrate",
|
||||
response_model=OrganizationAIModelConfigurationResponse,
|
||||
)
|
||||
async def migrate_model_configuration_v2(
|
||||
force: bool = Query(default=False),
|
||||
user: UserModel = Depends(get_user_with_selected_organization),
|
||||
):
|
||||
organization_id = user.selected_organization_id
|
||||
existing = await get_organization_ai_model_configuration_v2(organization_id)
|
||||
if existing is not None and not force:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="Organization already has a v2 model configuration",
|
||||
)
|
||||
|
||||
legacy = await db_client.get_user_configurations(user.id)
|
||||
try:
|
||||
configuration = convert_legacy_ai_model_configuration_to_v2(legacy)
|
||||
effective = compile_ai_model_configuration_v2(configuration)
|
||||
await UserConfigurationValidator().validate(
|
||||
effective,
|
||||
organization_id=organization_id,
|
||||
created_by=user.provider_id,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=422, detail=exc.args[0])
|
||||
|
||||
if DEPLOYMENT_MODE != "oss":
|
||||
try:
|
||||
await ensure_hosted_mps_billing_account_v2(
|
||||
organization_id,
|
||||
created_by=str(user.provider_id),
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"Failed to initialize MPS billing v2 account for organization {}: {}",
|
||||
organization_id,
|
||||
exc,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="Failed to initialize MPS billing v2 account",
|
||||
)
|
||||
|
||||
await upsert_organization_ai_model_configuration_v2(
|
||||
organization_id,
|
||||
configuration,
|
||||
)
|
||||
await migrate_workflow_model_configurations_to_v2(
|
||||
organization_id=organization_id,
|
||||
fallback_user_config=legacy,
|
||||
)
|
||||
return await _model_configuration_v2_response(
|
||||
user=user,
|
||||
configuration=configuration,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/preferences", response_model=OrganizationPreferences)
|
||||
async def get_preferences(
|
||||
user: UserModel = Depends(get_user_with_selected_organization),
|
||||
):
|
||||
organization_id = user.selected_organization_id
|
||||
return await get_organization_preferences(organization_id)
|
||||
|
||||
|
||||
@router.put("/preferences", response_model=OrganizationPreferences)
|
||||
async def save_preferences(
|
||||
request: OrganizationPreferences,
|
||||
user: UserModel = Depends(get_user_with_selected_organization),
|
||||
):
|
||||
organization_id = user.selected_organization_id
|
||||
return await upsert_organization_preferences(
|
||||
organization_id,
|
||||
request,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/model-configurations/preferences",
|
||||
response_model=OrganizationPreferences,
|
||||
include_in_schema=False,
|
||||
)
|
||||
async def get_model_configuration_preferences_legacy(
|
||||
user: UserModel = Depends(get_user_with_selected_organization),
|
||||
):
|
||||
return await get_preferences(user=user)
|
||||
|
||||
|
||||
@router.put(
|
||||
"/model-configurations/preferences",
|
||||
response_model=OrganizationPreferences,
|
||||
include_in_schema=False,
|
||||
)
|
||||
async def save_model_configuration_preferences_legacy(
|
||||
request: OrganizationPreferences,
|
||||
user: UserModel = Depends(get_user_with_selected_organization),
|
||||
):
|
||||
return await save_preferences(request=request, user=user)
|
||||
|
||||
|
||||
def preserve_masked_fields(provider: str, request_dict: dict, existing: dict):
|
||||
"""If the client re-submitted a masked sensitive field, restore the original."""
|
||||
for field_name in _sensitive_fields(provider):
|
||||
|
|
|
|||
|
|
@ -1,16 +1,16 @@
|
|||
import json
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from api.constants import DEPLOYMENT_MODE
|
||||
from api.constants import DEPLOYMENT_MODE, UI_APP_URL
|
||||
from api.db import db_client
|
||||
from api.db.models import UserModel
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.auth.depends import get_user, get_user_with_selected_organization
|
||||
from api.services.mps_service_key_client import mps_service_key_client
|
||||
from api.services.reports import generate_usage_runs_report_csv
|
||||
from api.utils.artifacts import artifact_url
|
||||
|
|
@ -22,14 +22,8 @@ class CurrentUsageResponse(BaseModel):
|
|||
period_start: str
|
||||
period_end: str
|
||||
used_dograh_tokens: float
|
||||
quota_dograh_tokens: int
|
||||
percentage_used: float
|
||||
next_refresh_date: str
|
||||
quota_enabled: bool
|
||||
total_duration_seconds: int
|
||||
# New USD fields
|
||||
used_amount_usd: Optional[float] = None
|
||||
quota_amount_usd: Optional[float] = None
|
||||
currency: Optional[str] = None
|
||||
price_per_second_usd: Optional[float] = None
|
||||
|
||||
|
|
@ -40,6 +34,61 @@ class MPSCreditsResponse(BaseModel):
|
|||
total_quota: float
|
||||
|
||||
|
||||
class MPSCreditPurchaseUrlResponse(BaseModel):
|
||||
checkout_url: str
|
||||
|
||||
|
||||
class MPSBillingAccountResponse(BaseModel):
|
||||
id: int
|
||||
organization_id: int
|
||||
billing_mode: str
|
||||
cached_balance_credits: float
|
||||
currency: str
|
||||
|
||||
|
||||
class MPSCreditLedgerEntryResponse(BaseModel):
|
||||
id: int
|
||||
entry_type: str
|
||||
origin: Optional[str] = None
|
||||
credits_delta: float
|
||||
balance_after: float
|
||||
amount_minor: Optional[int] = None
|
||||
amount_currency: Optional[str] = None
|
||||
payment_order_id: Optional[int] = None
|
||||
metric_code: Optional[str] = None
|
||||
correlation_id: Optional[str] = None
|
||||
aggregation_key: Optional[str] = None
|
||||
usage_event_id: Optional[int] = None
|
||||
workflow_run_id: Optional[int] = None
|
||||
workflow_id: Optional[int] = None
|
||||
billable_quantity: Optional[float] = None
|
||||
quantity_unit: Optional[str] = None
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
created_at: str
|
||||
|
||||
|
||||
class MPSBillingCreditsResponse(BaseModel):
|
||||
billing_version: Literal["legacy", "v2"]
|
||||
total_credits_used: float = 0.0
|
||||
remaining_credits: float = 0.0
|
||||
total_quota: float = 0.0
|
||||
account: Optional[MPSBillingAccountResponse] = None
|
||||
ledger_entries: List[MPSCreditLedgerEntryResponse] = Field(default_factory=list)
|
||||
total_count: int = 0
|
||||
page: int = 1
|
||||
limit: int = 50
|
||||
total_pages: int = 0
|
||||
|
||||
|
||||
def _optional_int(value: Any) -> Optional[int]:
|
||||
if value is None:
|
||||
return None
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
|
||||
class WorkflowRunUsageResponse(BaseModel):
|
||||
id: int
|
||||
workflow_id: int
|
||||
|
|
@ -97,7 +146,7 @@ class DailyUsageBreakdownResponse(BaseModel):
|
|||
|
||||
@router.get("/usage/current-period", response_model=CurrentUsageResponse)
|
||||
async def get_current_period_usage(user: UserModel = Depends(get_user)):
|
||||
"""Get current billing period usage for the user's organization."""
|
||||
"""Get current reporting-period usage for the user's organization."""
|
||||
if not user.selected_organization_id:
|
||||
raise HTTPException(status_code=400, detail="No organization selected")
|
||||
|
||||
|
|
@ -142,6 +191,202 @@ async def get_mps_credits(user: UserModel = Depends(get_user)):
|
|||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
async def _get_mps_billing_account_status(
|
||||
user: UserModel, organization_id: int
|
||||
) -> Optional[dict]:
|
||||
return await mps_service_key_client.get_billing_account_status(
|
||||
organization_id=organization_id,
|
||||
created_by=str(user.provider_id),
|
||||
)
|
||||
|
||||
|
||||
def _is_mps_billing_v2(account: Optional[dict]) -> bool:
|
||||
return bool(account and account.get("billing_mode") == "v2")
|
||||
|
||||
|
||||
async def _legacy_mps_credits_response(user: UserModel) -> MPSBillingCreditsResponse:
|
||||
if DEPLOYMENT_MODE == "oss":
|
||||
usage = await mps_service_key_client.get_usage_by_created_by(
|
||||
str(user.provider_id)
|
||||
)
|
||||
else:
|
||||
if not user.selected_organization_id:
|
||||
raise HTTPException(status_code=400, detail="No organization selected")
|
||||
usage = await mps_service_key_client.get_usage_by_organization(
|
||||
user.selected_organization_id
|
||||
)
|
||||
|
||||
total_used = float(usage.get("total_credits_used", 0.0))
|
||||
total_remaining = float(usage.get("remaining_credits", 0.0))
|
||||
return MPSBillingCreditsResponse(
|
||||
billing_version="legacy",
|
||||
total_credits_used=total_used,
|
||||
remaining_credits=total_remaining,
|
||||
total_quota=total_used + total_remaining,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/billing/credits", response_model=MPSBillingCreditsResponse)
|
||||
async def get_billing_credits(
|
||||
page: int = Query(1, ge=1),
|
||||
limit: int = Query(50, ge=1, le=100),
|
||||
user: UserModel = Depends(get_user),
|
||||
):
|
||||
"""Return legacy MPS credits or paginated v2 billing ledger details for the org."""
|
||||
try:
|
||||
if DEPLOYMENT_MODE == "oss" or not user.selected_organization_id:
|
||||
return await _legacy_mps_credits_response(user)
|
||||
|
||||
organization_id = user.selected_organization_id
|
||||
account_status = await _get_mps_billing_account_status(user, organization_id)
|
||||
if not _is_mps_billing_v2(account_status):
|
||||
return await _legacy_mps_credits_response(user)
|
||||
|
||||
ledger = await mps_service_key_client.get_credit_ledger(
|
||||
organization_id=organization_id,
|
||||
page=page,
|
||||
limit=limit,
|
||||
created_by=str(user.provider_id),
|
||||
)
|
||||
account = ledger.get("account") or {}
|
||||
ledger_entries = ledger.get("ledger_entries") or []
|
||||
total_count = int(ledger.get("total_count") or len(ledger_entries))
|
||||
response_limit = int(ledger.get("limit") or limit)
|
||||
total_pages = int(
|
||||
ledger.get("total_pages")
|
||||
or ((total_count + response_limit - 1) // response_limit)
|
||||
)
|
||||
workflow_ids_by_run_id: dict[int, int] = {}
|
||||
workflow_run_ids = {
|
||||
workflow_run_id
|
||||
for entry in ledger_entries
|
||||
if (workflow_run_id := _optional_int(entry.get("workflow_run_id")))
|
||||
is not None
|
||||
}
|
||||
for workflow_run_id in workflow_run_ids:
|
||||
workflow_run = await db_client.get_workflow_run_by_id(workflow_run_id)
|
||||
if (
|
||||
workflow_run
|
||||
and workflow_run.workflow
|
||||
and workflow_run.workflow.organization_id == organization_id
|
||||
):
|
||||
workflow_ids_by_run_id[workflow_run_id] = workflow_run.workflow_id
|
||||
|
||||
balance = float(account.get("cached_balance_credits") or 0.0)
|
||||
total_debits = sum(
|
||||
abs(float(entry.get("credits_delta") or 0.0))
|
||||
for entry in ledger_entries
|
||||
if float(entry.get("credits_delta") or 0.0) < 0
|
||||
)
|
||||
if ledger.get("total_debits_credits") is not None:
|
||||
total_debits = float(ledger["total_debits_credits"])
|
||||
|
||||
return MPSBillingCreditsResponse(
|
||||
billing_version="v2",
|
||||
total_credits_used=total_debits,
|
||||
remaining_credits=balance,
|
||||
total_quota=balance + total_debits,
|
||||
account=MPSBillingAccountResponse(
|
||||
id=int(account["id"]),
|
||||
organization_id=int(account["organization_id"]),
|
||||
billing_mode=str(account["billing_mode"]),
|
||||
cached_balance_credits=balance,
|
||||
currency=str(account.get("currency") or "USD"),
|
||||
),
|
||||
ledger_entries=[
|
||||
MPSCreditLedgerEntryResponse(
|
||||
id=int(entry["id"]),
|
||||
entry_type=str(entry["entry_type"]),
|
||||
origin=entry.get("origin"),
|
||||
credits_delta=float(entry.get("credits_delta") or 0.0),
|
||||
balance_after=float(entry.get("balance_after") or 0.0),
|
||||
amount_minor=entry.get("amount_minor"),
|
||||
amount_currency=entry.get("amount_currency"),
|
||||
payment_order_id=entry.get("payment_order_id"),
|
||||
metric_code=entry.get("metric_code"),
|
||||
correlation_id=entry.get("correlation_id"),
|
||||
aggregation_key=entry.get("aggregation_key"),
|
||||
usage_event_id=_optional_int(entry.get("usage_event_id")),
|
||||
workflow_run_id=_optional_int(entry.get("workflow_run_id")),
|
||||
workflow_id=workflow_ids_by_run_id.get(
|
||||
_optional_int(entry.get("workflow_run_id"))
|
||||
)
|
||||
if entry.get("workflow_run_id") is not None
|
||||
else None,
|
||||
billable_quantity=float(entry["billable_quantity"])
|
||||
if entry.get("billable_quantity") is not None
|
||||
else None,
|
||||
quantity_unit=entry.get("quantity_unit"),
|
||||
metadata=entry.get("metadata") or {},
|
||||
created_at=str(entry["created_at"]),
|
||||
)
|
||||
for entry in ledger_entries
|
||||
],
|
||||
total_count=total_count,
|
||||
page=int(ledger.get("page") or page),
|
||||
limit=response_limit,
|
||||
total_pages=total_pages,
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"Failed to fetch billing credits: {exc}")
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.post(
|
||||
"/usage/mps-credits/purchase-url",
|
||||
response_model=MPSCreditPurchaseUrlResponse,
|
||||
)
|
||||
async def create_mps_credit_purchase_url(
|
||||
user: UserModel = Depends(get_user_with_selected_organization),
|
||||
):
|
||||
"""Create a checkout URL for organizations using Dograh-managed MPS v2."""
|
||||
if DEPLOYMENT_MODE == "oss":
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Credit purchases are not available in OSS mode",
|
||||
)
|
||||
|
||||
organization_id = user.selected_organization_id
|
||||
assert organization_id is not None
|
||||
account_status = await _get_mps_billing_account_status(user, organization_id)
|
||||
if not _is_mps_billing_v2(account_status):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=(
|
||||
"Credit purchases are available only for organizations using billing v2"
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
session = await mps_service_key_client.create_credit_purchase_url(
|
||||
organization_id=organization_id,
|
||||
created_by=str(user.provider_id),
|
||||
return_url=f"{UI_APP_URL.rstrip('/')}/billing",
|
||||
billing_details={
|
||||
"source": "dograh_billing",
|
||||
"dograh_user_id": str(user.id),
|
||||
"dograh_provider_id": str(user.provider_id),
|
||||
},
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"Failed to create MPS credit purchase URL: {exc}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="Failed to create credit purchase URL",
|
||||
)
|
||||
|
||||
checkout_url = session.get("checkout_url")
|
||||
if not checkout_url:
|
||||
logger.error(f"MPS checkout session response missing checkout_url: {session}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="MPS checkout session response missing checkout_url",
|
||||
)
|
||||
return MPSCreditPurchaseUrlResponse(checkout_url=checkout_url)
|
||||
|
||||
|
||||
FILTERS_DESCRIPTION = """\
|
||||
JSON-encoded array of filter objects. Each object has the shape:
|
||||
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from pydantic import BaseModel
|
|||
|
||||
from api.db import db_client
|
||||
from api.enums import TriggerState, WorkflowStatus
|
||||
from api.services.quota_service import check_dograh_quota_by_user_id
|
||||
from api.services.quota_service import authorize_workflow_run_start
|
||||
from api.services.telephony.factory import (
|
||||
get_default_telephony_provider,
|
||||
get_telephony_provider_by_id,
|
||||
|
|
@ -179,14 +179,6 @@ async def _execute_resolved_target(
|
|||
"""Shared execution path once the target workflow has been resolved."""
|
||||
execution_user_id = _get_execution_user_id(target.workflow)
|
||||
|
||||
# Check Dograh quota using the workflow owner's config and model overrides.
|
||||
quota_result = await check_dograh_quota_by_user_id(
|
||||
execution_user_id,
|
||||
workflow_id=target.workflow.id,
|
||||
)
|
||||
if not quota_result.has_quota:
|
||||
raise HTTPException(status_code=402, detail=quota_result.error_message)
|
||||
|
||||
# Get telephony provider — either the caller-specified config (validated
|
||||
# against the workflow's org) or the org's default config.
|
||||
if request.telephony_configuration_id is not None:
|
||||
|
|
@ -268,6 +260,15 @@ async def _execute_resolved_target(
|
|||
f"to phone number {request.phone_number}"
|
||||
)
|
||||
|
||||
# Check Dograh quota after the run exists so hosted v2 can mint and store
|
||||
# the MPS correlation id before the provider starts the call.
|
||||
quota_result = await authorize_workflow_run_start(
|
||||
workflow_id=target.workflow.id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
)
|
||||
if not quota_result.has_quota:
|
||||
raise HTTPException(status_code=402, detail=quota_result.error_message)
|
||||
|
||||
# 9. Construct webhook URL for telephony provider callback
|
||||
backend_endpoint, _ = await get_backend_endpoints()
|
||||
webhook_endpoint = provider.WEBHOOK_ENDPOINT
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ They handle CORS, domain validation, and session management for embedded workflo
|
|||
import secrets
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Optional
|
||||
from urllib.parse import urlsplit
|
||||
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
|
|
@ -16,6 +17,8 @@ from fastapi import (
|
|||
)
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
from starlette.datastructures import Headers
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
from api.db import db_client
|
||||
from api.enums import WorkflowRunMode
|
||||
|
|
@ -27,6 +30,9 @@ from api.routes.turn_credentials import (
|
|||
|
||||
router = APIRouter(prefix="/public/embed")
|
||||
|
||||
EMBED_CORS_ALLOW_HEADERS = "Content-Type, Origin"
|
||||
EMBED_CORS_MAX_AGE = "86400"
|
||||
|
||||
|
||||
class InitEmbedRequest(BaseModel):
|
||||
"""Request model for initializing an embed session"""
|
||||
|
|
@ -70,11 +76,9 @@ def validate_origin(origin: str, allowed_domains: list) -> bool:
|
|||
# If no domains specified, allow all origins
|
||||
return True
|
||||
|
||||
# Extract domain from origin (remove protocol)
|
||||
if "://" in origin:
|
||||
domain = origin.split("://")[1].split("/")[0].split(":")[0]
|
||||
else:
|
||||
domain = origin
|
||||
domain, origin_port = _parse_origin_host_port(origin)
|
||||
if not domain:
|
||||
return False
|
||||
|
||||
# Normalize domain for www matching
|
||||
def normalize_www(d: str) -> tuple[str, str]:
|
||||
|
|
@ -87,16 +91,23 @@ def validate_origin(origin: str, allowed_domains: list) -> bool:
|
|||
domain_variants = normalize_www(domain)
|
||||
|
||||
for allowed in allowed_domains:
|
||||
allowed = str(allowed).strip().lower()
|
||||
if allowed == "*":
|
||||
return True
|
||||
elif allowed.startswith("*."):
|
||||
allowed_domain, allowed_port = _parse_origin_host_port(allowed)
|
||||
if not allowed_domain:
|
||||
continue
|
||||
if allowed_port is not None and allowed_port != origin_port:
|
||||
continue
|
||||
|
||||
if allowed_domain.startswith("*."):
|
||||
# Wildcard subdomain matching
|
||||
base_domain = allowed[2:]
|
||||
base_domain = allowed_domain[2:]
|
||||
if domain == base_domain or domain.endswith("." + base_domain):
|
||||
return True
|
||||
else:
|
||||
# Check both www and non-www versions
|
||||
allowed_variants = normalize_www(allowed)
|
||||
allowed_variants = normalize_www(allowed_domain)
|
||||
# If any variant of domain matches any variant of allowed, it's valid
|
||||
if any(
|
||||
dv in allowed_variants or av in domain_variants
|
||||
|
|
@ -108,6 +119,24 @@ def validate_origin(origin: str, allowed_domains: list) -> bool:
|
|||
return False
|
||||
|
||||
|
||||
def _parse_origin_host_port(value: str) -> tuple[str, str | None]:
|
||||
candidate = value.strip().lower()
|
||||
if not candidate:
|
||||
return "", None
|
||||
|
||||
if "://" not in candidate and not candidate.startswith("//"):
|
||||
candidate = f"//{candidate}"
|
||||
|
||||
parsed = urlsplit(candidate)
|
||||
try:
|
||||
parsed_port = parsed.port
|
||||
except ValueError:
|
||||
parsed_port = None
|
||||
|
||||
port = str(parsed_port) if parsed_port is not None else None
|
||||
return (parsed.hostname or "").rstrip("."), port
|
||||
|
||||
|
||||
def generate_session_token() -> str:
|
||||
"""Generate a cryptographically secure session token"""
|
||||
return f"emb_session_{secrets.token_urlsafe(32)}"
|
||||
|
|
@ -121,8 +150,120 @@ def get_request_origin(request: Request) -> str:
|
|||
return origin
|
||||
|
||||
|
||||
def _cors_response(origin: str, methods: str) -> Response:
|
||||
return Response(
|
||||
headers={
|
||||
"Access-Control-Allow-Origin": origin,
|
||||
"Access-Control-Allow-Methods": methods,
|
||||
"Access-Control-Allow-Headers": EMBED_CORS_ALLOW_HEADERS,
|
||||
"Access-Control-Max-Age": EMBED_CORS_MAX_AGE,
|
||||
"Vary": "Origin",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _allow_embed_origin(response: Response, origin: str) -> None:
|
||||
response.headers["Access-Control-Allow-Origin"] = origin
|
||||
vary = response.headers.get("Vary")
|
||||
if not vary:
|
||||
response.headers["Vary"] = "Origin"
|
||||
return
|
||||
|
||||
vary_values = {value.strip().lower() for value in vary.split(",")}
|
||||
if "origin" not in vary_values:
|
||||
response.headers["Vary"] = f"{vary}, Origin"
|
||||
|
||||
|
||||
async def _config_preflight_response(token: str, origin: str) -> Response:
|
||||
embed_token = await db_client.get_embed_token_by_token(token)
|
||||
if not embed_token or not embed_token.is_active:
|
||||
return Response(status_code=403)
|
||||
|
||||
if not validate_origin(origin, embed_token.allowed_domains or []):
|
||||
return Response(status_code=403)
|
||||
|
||||
return _cors_response(origin, "GET, OPTIONS")
|
||||
|
||||
|
||||
async def _turn_credentials_preflight_response(
|
||||
session_token: str, origin: str
|
||||
) -> Response:
|
||||
embed_session = await db_client.get_embed_session_by_token(session_token)
|
||||
if not embed_session:
|
||||
return Response(status_code=403)
|
||||
|
||||
if embed_session.expires_at and embed_session.expires_at < datetime.now(UTC):
|
||||
return Response(status_code=403)
|
||||
|
||||
embed_token = await db_client.get_embed_token_by_id(embed_session.embed_token_id)
|
||||
if not embed_token:
|
||||
return Response(status_code=403)
|
||||
|
||||
if not validate_origin(origin, embed_token.allowed_domains or []):
|
||||
return Response(status_code=403)
|
||||
|
||||
return _cors_response(origin, "GET, OPTIONS")
|
||||
|
||||
|
||||
async def build_public_embed_preflight_response(
|
||||
path: str, origin: str, requested_method: str, api_prefix: str = "/api/v1"
|
||||
) -> Response | None:
|
||||
"""Handle embed preflights before global CORSMiddleware rejects external sites."""
|
||||
public_embed_prefix = f"{api_prefix.rstrip('/')}/public/embed"
|
||||
|
||||
if path == f"{public_embed_prefix}/init":
|
||||
if requested_method.upper() != "POST":
|
||||
return Response(status_code=405)
|
||||
return _cors_response(origin, "POST, OPTIONS")
|
||||
|
||||
config_prefix = f"{public_embed_prefix}/config/"
|
||||
if path.startswith(config_prefix):
|
||||
if requested_method.upper() != "GET":
|
||||
return Response(status_code=405)
|
||||
token = path[len(config_prefix) :].split("/", 1)[0]
|
||||
return await _config_preflight_response(token, origin)
|
||||
|
||||
turn_credentials_prefix = f"{public_embed_prefix}/turn-credentials/"
|
||||
if path.startswith(turn_credentials_prefix):
|
||||
if requested_method.upper() != "GET":
|
||||
return Response(status_code=405)
|
||||
session_token = path[len(turn_credentials_prefix) :].split("/", 1)[0]
|
||||
return await _turn_credentials_preflight_response(session_token, origin)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class PublicEmbedCORSMiddleware:
|
||||
"""Allow token-gated embed CORS before global SaaS CORS rejects preflights."""
|
||||
|
||||
def __init__(self, app: ASGIApp, api_prefix: str = "/api/v1"):
|
||||
self.app = app
|
||||
self.api_prefix = api_prefix
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] != "http" or scope.get("method") != "OPTIONS":
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
headers = Headers(scope=scope)
|
||||
origin = headers.get("origin")
|
||||
requested_method = headers.get("access-control-request-method")
|
||||
|
||||
if origin and requested_method:
|
||||
response = await build_public_embed_preflight_response(
|
||||
scope.get("path", ""), origin, requested_method, self.api_prefix
|
||||
)
|
||||
if response is not None:
|
||||
await response(scope, receive, send)
|
||||
return
|
||||
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
|
||||
@router.post("/init", response_model=InitEmbedResponse)
|
||||
async def initialize_embed_session(request: Request, init_request: InitEmbedRequest):
|
||||
async def initialize_embed_session(
|
||||
request: Request, init_request: InitEmbedRequest, response: Response
|
||||
):
|
||||
"""Initialize an embed session with token validation and domain checking.
|
||||
|
||||
This endpoint:
|
||||
|
|
@ -158,6 +299,9 @@ async def initialize_embed_session(request: Request, init_request: InitEmbedRequ
|
|||
)
|
||||
raise HTTPException(status_code=403, detail=f"Domain not allowed: {origin}")
|
||||
|
||||
if origin:
|
||||
_allow_embed_origin(response, origin)
|
||||
|
||||
# Create workflow run
|
||||
try:
|
||||
workflow_run = await db_client.create_workflow_run(
|
||||
|
|
@ -204,8 +348,19 @@ async def initialize_embed_session(request: Request, init_request: InitEmbedRequ
|
|||
)
|
||||
|
||||
|
||||
@router.options("/config/{token}")
|
||||
async def options_embed_config(token: str, request: Request):
|
||||
"""Fallback OPTIONS handler for the embed config endpoint.
|
||||
|
||||
Browser preflights include Access-Control-Request-Method and are handled by
|
||||
PublicEmbedCORSMiddleware before global CORS. This keeps non-conformant
|
||||
OPTIONS requests on the same validation path.
|
||||
"""
|
||||
return await _config_preflight_response(token, request.headers.get("origin", ""))
|
||||
|
||||
|
||||
@router.get("/config/{token}", response_model=EmbedConfigResponse)
|
||||
async def get_embed_config(token: str, request: Request):
|
||||
async def get_embed_config(token: str, request: Request, response: Response):
|
||||
"""Get embed configuration without creating a session.
|
||||
|
||||
This endpoint is used to fetch widget configuration for display purposes
|
||||
|
|
@ -226,6 +381,11 @@ async def get_embed_config(token: str, request: Request):
|
|||
if not validate_origin(origin, embed_token.allowed_domains or []):
|
||||
raise HTTPException(status_code=403, detail=f"Domain not allowed: {origin}")
|
||||
|
||||
# Set CORS header explicitly; the global CORSMiddleware covers only
|
||||
# first-party origins; this endpoint is fetched by external embed sites.
|
||||
if origin:
|
||||
_allow_embed_origin(response, origin)
|
||||
|
||||
# Extract settings with defaults
|
||||
settings = embed_token.settings or {}
|
||||
|
||||
|
|
@ -243,24 +403,20 @@ async def get_embed_config(token: str, request: Request):
|
|||
|
||||
@router.options("/init")
|
||||
async def options_init(request: Request):
|
||||
"""Handle CORS preflight for init endpoint"""
|
||||
"""Fallback OPTIONS handler for init endpoint."""
|
||||
# Browser preflights are handled by PublicEmbedCORSMiddleware before global CORS.
|
||||
# For init endpoint, we need to check the token in the request body
|
||||
# But OPTIONS requests don't have body, so we'll be permissive
|
||||
# The actual validation happens in the POST request
|
||||
origin = request.headers.get("origin", "*")
|
||||
|
||||
return Response(
|
||||
headers={
|
||||
"Access-Control-Allow-Origin": origin,
|
||||
"Access-Control-Allow-Methods": "POST, OPTIONS",
|
||||
"Access-Control-Allow-Headers": "Content-Type, Origin",
|
||||
"Access-Control-Max-Age": "86400",
|
||||
}
|
||||
)
|
||||
return _cors_response(origin, "POST, OPTIONS")
|
||||
|
||||
|
||||
@router.get("/turn-credentials/{session_token}", response_model=TurnCredentialsResponse)
|
||||
async def get_public_turn_credentials(session_token: str, request: Request):
|
||||
async def get_public_turn_credentials(
|
||||
session_token: str, request: Request, response: Response
|
||||
):
|
||||
"""Get TURN credentials for an embed session.
|
||||
|
||||
This endpoint allows embedded widgets to obtain TURN server credentials
|
||||
|
|
@ -295,6 +451,9 @@ async def get_public_turn_credentials(session_token: str, request: Request):
|
|||
)
|
||||
raise HTTPException(status_code=403, detail=f"Domain not allowed: {origin}")
|
||||
|
||||
if origin:
|
||||
_allow_embed_origin(response, origin)
|
||||
|
||||
# Check if TURN is configured
|
||||
if not TURN_SECRET:
|
||||
raise HTTPException(
|
||||
|
|
@ -316,63 +475,8 @@ async def get_public_turn_credentials(session_token: str, request: Request):
|
|||
|
||||
@router.options("/turn-credentials/{session_token}")
|
||||
async def options_turn_credentials(request: Request, session_token: str):
|
||||
"""Handle CORS preflight for TURN credentials endpoint"""
|
||||
origin = request.headers.get("origin", "*")
|
||||
|
||||
# Try to validate the session token and get allowed domains
|
||||
allowed_origin = origin
|
||||
try:
|
||||
embed_session = await db_client.get_embed_session_by_token(session_token)
|
||||
if embed_session:
|
||||
embed_token = await db_client.get_embed_token_by_id(
|
||||
embed_session.embed_token_id
|
||||
)
|
||||
if embed_token:
|
||||
# Check if origin is in allowed domains (empty means allow all)
|
||||
if validate_origin(origin, embed_token.allowed_domains or []):
|
||||
allowed_origin = origin
|
||||
else:
|
||||
allowed_origin = ""
|
||||
except Exception:
|
||||
# On error, be permissive for OPTIONS
|
||||
pass
|
||||
|
||||
return Response(
|
||||
headers={
|
||||
"Access-Control-Allow-Origin": allowed_origin,
|
||||
"Access-Control-Allow-Methods": "GET, OPTIONS",
|
||||
"Access-Control-Allow-Headers": "Content-Type",
|
||||
"Access-Control-Max-Age": "86400",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@router.options("/config/{token}")
|
||||
async def options_config(request: Request, token: str):
|
||||
"""Handle CORS preflight for config endpoint"""
|
||||
# Get origin header
|
||||
origin = request.headers.get("origin", "*")
|
||||
|
||||
# Try to validate the token and get allowed domains
|
||||
allowed_origin = origin
|
||||
try:
|
||||
embed_token = await db_client.get_embed_token_by_token(token)
|
||||
if embed_token and embed_token.is_active:
|
||||
# Check if origin is in allowed domains
|
||||
if validate_origin(origin, embed_token.allowed_domains or []):
|
||||
allowed_origin = origin
|
||||
else:
|
||||
# If not allowed, don't include the origin
|
||||
allowed_origin = ""
|
||||
except Exception:
|
||||
# On error, be permissive for OPTIONS
|
||||
pass
|
||||
|
||||
return Response(
|
||||
headers={
|
||||
"Access-Control-Allow-Origin": allowed_origin,
|
||||
"Access-Control-Allow-Methods": "GET, OPTIONS",
|
||||
"Access-Control-Allow-Headers": "Content-Type",
|
||||
"Access-Control-Max-Age": "86400",
|
||||
}
|
||||
"""Fallback OPTIONS handler for TURN credentials endpoint."""
|
||||
# Browser preflights are handled by PublicEmbedCORSMiddleware before global CORS.
|
||||
return await _turn_credentials_preflight_response(
|
||||
session_token, request.headers.get("origin", "")
|
||||
)
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ from api.enums import CallType, WorkflowRunState
|
|||
from api.errors.telephony_errors import TelephonyError
|
||||
from api.sdk_expose import sdk_expose
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.quota_service import check_dograh_quota_by_user_id
|
||||
from api.services.quota_service import authorize_workflow_run_start
|
||||
from api.services.telephony.call_transfer_manager import get_call_transfer_manager
|
||||
from api.services.telephony.factory import (
|
||||
get_all_telephony_providers,
|
||||
|
|
@ -53,7 +53,7 @@ class InitiateCallRequest(BaseModel):
|
|||
workflow_run_id: int | None = None
|
||||
phone_number: str | None = None
|
||||
# Optional explicit telephony config to use for the test call. If omitted,
|
||||
# falls back to the user's per-user default (when set), then the org default.
|
||||
# falls back to the org default.
|
||||
telephony_configuration_id: int | None = None
|
||||
# Optional caller-ID phone number to dial out from. Must belong to the
|
||||
# resolved telephony configuration; otherwise the provider picks one.
|
||||
|
|
@ -82,7 +82,12 @@ async def initiate_call(
|
|||
"""Initiate a call using the configured telephony provider from web browser. This is
|
||||
supposed to be a test call method for the draft version of the agent."""
|
||||
|
||||
user_configuration = await db_client.get_user_configurations(user.id)
|
||||
from api.services.organization_preferences import get_organization_preferences
|
||||
|
||||
preferences = await get_organization_preferences(
|
||||
user.selected_organization_id,
|
||||
db=db_client,
|
||||
)
|
||||
|
||||
# Resolve which telephony config to use: explicit request value, otherwise
|
||||
# the org's default outbound config.
|
||||
|
|
@ -116,13 +121,12 @@ async def initiate_call(
|
|||
detail="telephony_not_configured",
|
||||
)
|
||||
|
||||
phone_number = request.phone_number or user_configuration.test_phone_number
|
||||
phone_number = request.phone_number or preferences.test_phone_number
|
||||
|
||||
if not phone_number:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Phone number must be provided in request or set in user "
|
||||
"configuration",
|
||||
detail="Phone number must be provided in request or set in organization preferences",
|
||||
)
|
||||
|
||||
workflow = await db_client.get_workflow(
|
||||
|
|
@ -132,14 +136,6 @@ async def initiate_call(
|
|||
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||
execution_user_id = _get_execution_user_id(workflow)
|
||||
|
||||
# Check Dograh quota before initiating the call (apply per-workflow
|
||||
# model_overrides so the keys we will actually use are the ones checked).
|
||||
quota_result = await check_dograh_quota_by_user_id(
|
||||
execution_user_id, workflow_id=workflow.id
|
||||
)
|
||||
if not quota_result.has_quota:
|
||||
raise HTTPException(status_code=402, detail=quota_result.error_message)
|
||||
|
||||
# Determine the workflow run mode based on provider type
|
||||
workflow_run_mode = provider.PROVIDER_NAME
|
||||
|
||||
|
|
@ -182,6 +178,16 @@ async def initiate_call(
|
|||
)
|
||||
workflow_run_name = workflow_run.name
|
||||
|
||||
# Check Dograh quota after the run exists so hosted v2 can mint and store
|
||||
# the MPS correlation id before initiating the call.
|
||||
quota_result = await authorize_workflow_run_start(
|
||||
workflow_id=workflow.id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
actor_user=user,
|
||||
)
|
||||
if not quota_result.has_quota:
|
||||
raise HTTPException(status_code=402, detail=quota_result.error_message)
|
||||
|
||||
# Construct webhook URL based on provider type
|
||||
backend_endpoint, _ = await get_backend_endpoints()
|
||||
|
||||
|
|
@ -735,19 +741,8 @@ async def handle_inbound_run(request: Request):
|
|||
TelephonyError.SIGNATURE_VALIDATION_FAILED
|
||||
)
|
||||
|
||||
# 4. Quota check (use the workflow's model_overrides if set).
|
||||
quota_result = await check_dograh_quota_by_user_id(
|
||||
user_id, workflow_id=workflow_id
|
||||
)
|
||||
if not quota_result.has_quota:
|
||||
logger.warning(
|
||||
f"User {user_id} has exceeded quota: {quota_result.error_message}"
|
||||
)
|
||||
return provider_class.generate_validation_error_response(
|
||||
TelephonyError.QUOTA_EXCEEDED
|
||||
)
|
||||
|
||||
# 5. Create workflow run + return provider-shaped response.
|
||||
# 5. Create workflow run + authorize quota before returning provider
|
||||
# stream instructions.
|
||||
workflow_run_id = await _create_inbound_workflow_run(
|
||||
workflow_id,
|
||||
user_id,
|
||||
|
|
@ -756,6 +751,17 @@ async def handle_inbound_run(request: Request):
|
|||
telephony_configuration_id=telephony_configuration_id,
|
||||
from_phone_number_id=phone_row.id,
|
||||
)
|
||||
quota_result = await authorize_workflow_run_start(
|
||||
workflow_id=workflow_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
)
|
||||
if not quota_result.has_quota:
|
||||
logger.warning(
|
||||
f"User {user_id} has exceeded quota: {quota_result.error_message}"
|
||||
)
|
||||
return provider_class.generate_validation_error_response(
|
||||
TelephonyError.QUOTA_EXCEEDED
|
||||
)
|
||||
|
||||
backend_endpoint, wss_backend_endpoint = await get_backend_endpoints()
|
||||
websocket_url = (
|
||||
|
|
@ -870,20 +876,8 @@ async def handle_inbound_telephony(
|
|||
logger.error(f"Request validation failed: {error_type}")
|
||||
return provider_class.generate_validation_error_response(error_type)
|
||||
|
||||
# Check quota before processing (apply per-workflow model_overrides).
|
||||
# Create workflow run.
|
||||
user_id = workflow_context["user_id"]
|
||||
quota_result = await check_dograh_quota_by_user_id(
|
||||
user_id, workflow_id=workflow_id
|
||||
)
|
||||
if not quota_result.has_quota:
|
||||
logger.warning(
|
||||
f"User {user_id} has exceeded quota for inbound calls: {quota_result.error_message}"
|
||||
)
|
||||
return provider_class.generate_validation_error_response(
|
||||
TelephonyError.QUOTA_EXCEEDED
|
||||
)
|
||||
|
||||
# Create workflow run
|
||||
workflow_run_id = await _create_inbound_workflow_run(
|
||||
workflow_id,
|
||||
workflow_context["user_id"],
|
||||
|
|
@ -892,6 +886,17 @@ async def handle_inbound_telephony(
|
|||
telephony_configuration_id=workflow_context["telephony_configuration_id"],
|
||||
from_phone_number_id=workflow_context.get("from_phone_number_id"),
|
||||
)
|
||||
quota_result = await authorize_workflow_run_start(
|
||||
workflow_id=workflow_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
)
|
||||
if not quota_result.has_quota:
|
||||
logger.warning(
|
||||
f"User {user_id} has exceeded quota for inbound calls: {quota_result.error_message}"
|
||||
)
|
||||
return provider_class.generate_validation_error_response(
|
||||
TelephonyError.QUOTA_EXCEEDED
|
||||
)
|
||||
|
||||
# Generate response URLs
|
||||
backend_endpoint, wss_backend_endpoint = await get_backend_endpoints()
|
||||
|
|
|
|||
|
|
@ -10,6 +10,9 @@ from api.db.models import (
|
|||
UserModel,
|
||||
)
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
get_resolved_ai_model_configuration,
|
||||
)
|
||||
from api.services.configuration.check_validity import (
|
||||
APIKeyStatusResponse,
|
||||
UserConfigurationValidator,
|
||||
|
|
@ -19,6 +22,10 @@ from api.services.configuration.masking import check_for_masked_keys, mask_user_
|
|||
from api.services.configuration.merge import merge_user_configurations
|
||||
from api.services.configuration.registry import REGISTRY, ServiceType
|
||||
from api.services.mps_service_key_client import mps_service_key_client
|
||||
from api.services.organization_preferences import (
|
||||
get_organization_preferences,
|
||||
upsert_organization_preferences,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/user")
|
||||
|
||||
|
|
@ -94,8 +101,17 @@ class UserConfigurationRequestResponseSchema(BaseModel):
|
|||
async def get_user_configurations(
|
||||
user: UserModel = Depends(get_user),
|
||||
) -> UserConfigurationRequestResponseSchema:
|
||||
user_configurations = await db_client.get_user_configurations(user.id)
|
||||
masked_config = mask_user_config(user_configurations)
|
||||
resolved_config = await get_resolved_ai_model_configuration(
|
||||
user_id=user.id,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
masked_config = mask_user_config(resolved_config.effective)
|
||||
if user.selected_organization_id:
|
||||
preferences = await get_organization_preferences(user.selected_organization_id)
|
||||
if preferences.test_phone_number is not None:
|
||||
masked_config["test_phone_number"] = preferences.test_phone_number
|
||||
if preferences.timezone is not None:
|
||||
masked_config["timezone"] = preferences.timezone
|
||||
|
||||
# Add organization pricing info if available
|
||||
if user.selected_organization_id:
|
||||
|
|
@ -121,34 +137,61 @@ async def update_user_configurations(
|
|||
|
||||
# Remove organization_pricing from incoming dict as it's read-only
|
||||
incoming_dict.pop("organization_pricing", None)
|
||||
preferences_update = {
|
||||
key: incoming_dict.pop(key)
|
||||
for key in ("test_phone_number", "timezone")
|
||||
if key in incoming_dict
|
||||
}
|
||||
|
||||
# Merge via helper
|
||||
try:
|
||||
user_configurations = merge_user_configurations(existing_config, incoming_dict)
|
||||
except ValidationError as e:
|
||||
raise HTTPException(status_code=422, detail=str(e))
|
||||
if incoming_dict:
|
||||
# Merge via helper
|
||||
try:
|
||||
user_configurations = merge_user_configurations(
|
||||
existing_config, incoming_dict
|
||||
)
|
||||
except ValidationError as e:
|
||||
raise HTTPException(status_code=422, detail=str(e))
|
||||
|
||||
try:
|
||||
check_for_masked_keys(user_configurations)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
try:
|
||||
check_for_masked_keys(user_configurations)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
try:
|
||||
validator = UserConfigurationValidator()
|
||||
await validator.validate(
|
||||
user_configurations,
|
||||
organization_id=user.selected_organization_id,
|
||||
created_by=user.provider_id,
|
||||
try:
|
||||
validator = UserConfigurationValidator()
|
||||
await validator.validate(
|
||||
user_configurations,
|
||||
organization_id=user.selected_organization_id,
|
||||
created_by=user.provider_id,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=422, detail=e.args[0])
|
||||
|
||||
user_configurations = await db_client.update_user_configuration(
|
||||
user.id, user_configurations
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=422, detail=e.args[0])
|
||||
else:
|
||||
user_configurations = existing_config
|
||||
|
||||
user_configurations = await db_client.update_user_configuration(
|
||||
user.id, user_configurations
|
||||
)
|
||||
if user.selected_organization_id and preferences_update:
|
||||
preferences = await get_organization_preferences(user.selected_organization_id)
|
||||
if "test_phone_number" in preferences_update:
|
||||
preferences.test_phone_number = preferences_update["test_phone_number"]
|
||||
if "timezone" in preferences_update:
|
||||
preferences.timezone = preferences_update["timezone"]
|
||||
await upsert_organization_preferences(
|
||||
user.selected_organization_id,
|
||||
preferences,
|
||||
)
|
||||
|
||||
# Return masked version of updated config
|
||||
masked_config = mask_user_config(user_configurations)
|
||||
if user.selected_organization_id:
|
||||
preferences = await get_organization_preferences(user.selected_organization_id)
|
||||
if preferences.test_phone_number is not None:
|
||||
masked_config["test_phone_number"] = preferences.test_phone_number
|
||||
if preferences.timezone is not None:
|
||||
masked_config["timezone"] = preferences.timezone
|
||||
|
||||
# Add organization pricing info if available
|
||||
if user.selected_organization_id:
|
||||
|
|
@ -168,7 +211,11 @@ async def validate_user_configurations(
|
|||
validity_ttl_seconds: int = Query(default=60, ge=0, le=86400),
|
||||
user: UserModel = Depends(get_user),
|
||||
) -> APIKeyStatusResponse:
|
||||
configurations = await db_client.get_user_configurations(user.id)
|
||||
resolved_config = await get_resolved_ai_model_configuration(
|
||||
user_id=user.id,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
configurations = resolved_config.effective
|
||||
|
||||
if (
|
||||
configurations.last_validated_at
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ from api.services.pipecat.ws_sender_registry import (
|
|||
register_ws_sender,
|
||||
unregister_ws_sender,
|
||||
)
|
||||
from api.services.quota_service import check_dograh_quota
|
||||
from api.services.quota_service import authorize_workflow_run_start
|
||||
|
||||
router = APIRouter(prefix="/ws")
|
||||
|
||||
|
|
@ -329,7 +329,11 @@ class SignalingManager:
|
|||
|
||||
# Check Dograh quota before initiating the call (apply per-workflow
|
||||
# model_overrides so we evaluate the keys this workflow will use).
|
||||
quota_result = await check_dograh_quota(user, workflow_id=workflow_id)
|
||||
quota_result = await authorize_workflow_run_start(
|
||||
workflow_id=workflow_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
actor_user=user,
|
||||
)
|
||||
if not quota_result.has_quota:
|
||||
# Send error response for quota issues
|
||||
await ws.send_json(
|
||||
|
|
|
|||
|
|
@ -16,9 +16,18 @@ from api.db.agent_trigger_client import TriggerPathConflictError
|
|||
from api.db.models import UserModel
|
||||
from api.db.workflow_template_client import WorkflowTemplateClient
|
||||
from api.enums import CallType, PostHogEvent, StorageBackend
|
||||
from api.schemas.ai_model_configuration import OrganizationAIModelConfigurationV2
|
||||
from api.schemas.workflow import WorkflowRunResponseSchema
|
||||
from api.sdk_expose import sdk_expose
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY,
|
||||
check_for_masked_keys_in_ai_model_configuration_v2,
|
||||
compile_ai_model_configuration_v2,
|
||||
convert_legacy_ai_model_configuration_to_v2,
|
||||
get_resolved_ai_model_configuration,
|
||||
merge_ai_model_configuration_v2_secrets,
|
||||
)
|
||||
from api.services.configuration.check_validity import UserConfigurationValidator
|
||||
from api.services.configuration.masking import (
|
||||
mask_workflow_configurations,
|
||||
|
|
@ -32,12 +41,15 @@ from api.services.configuration.resolve import (
|
|||
)
|
||||
from api.services.mps_service_key_client import mps_service_key_client
|
||||
from api.services.posthog_client import capture_event
|
||||
from api.services.pricing.run_usage_response import format_public_usage_info
|
||||
from api.services.reports import generate_workflow_report_csv
|
||||
from api.services.storage import storage_fs
|
||||
from api.services.workflow.dto import ReactFlowDTO, sanitize_workflow_definition
|
||||
from api.services.workflow.duplicate import duplicate_workflow
|
||||
from api.services.workflow.errors import ItemKind, WorkflowError
|
||||
from api.services.workflow.run_usage_response import (
|
||||
format_public_cost_info,
|
||||
format_public_usage_info,
|
||||
)
|
||||
from api.services.workflow.trigger_paths import (
|
||||
TriggerPathIssue,
|
||||
ensure_trigger_paths,
|
||||
|
|
@ -955,12 +967,74 @@ async def update_workflow(
|
|||
existing_def,
|
||||
)
|
||||
|
||||
# Validate model_overrides: resolve onto global config, then
|
||||
# run the same validator used by the user-configurations endpoint.
|
||||
# Also stamp the current global API key into the override so the override
|
||||
# remains functional if the global config later switches to a different provider.
|
||||
# Validate model overrides. v2 uses a complete workflow-level model
|
||||
# configuration; legacy v1 uses partial service overlays.
|
||||
workflow_configurations = request.workflow_configurations
|
||||
if workflow_configurations and workflow_configurations.get("model_overrides"):
|
||||
if workflow_configurations and workflow_configurations.get(
|
||||
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY
|
||||
):
|
||||
existing_workflow = await db_client.get_workflow(
|
||||
workflow_id, organization_id=user.selected_organization_id
|
||||
)
|
||||
if existing_workflow is None:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Workflow with id {workflow_id} not found"
|
||||
)
|
||||
existing_draft = await db_client.get_draft_version(workflow_id)
|
||||
existing_configs = (
|
||||
existing_draft.workflow_configurations
|
||||
if existing_draft
|
||||
else existing_workflow.released_definition.workflow_configurations
|
||||
)
|
||||
existing_v2_override = (existing_configs or {}).get(
|
||||
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY
|
||||
)
|
||||
try:
|
||||
incoming_v2_override = (
|
||||
OrganizationAIModelConfigurationV2.model_validate(
|
||||
workflow_configurations[
|
||||
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY
|
||||
]
|
||||
)
|
||||
)
|
||||
existing_v2_override_config = (
|
||||
OrganizationAIModelConfigurationV2.model_validate(
|
||||
existing_v2_override
|
||||
)
|
||||
if existing_v2_override
|
||||
else None
|
||||
)
|
||||
v2_override = merge_ai_model_configuration_v2_secrets(
|
||||
incoming_v2_override,
|
||||
existing_v2_override_config,
|
||||
)
|
||||
if existing_v2_override_config is None:
|
||||
resolved_config = await get_resolved_ai_model_configuration(
|
||||
user_id=user.id,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
v2_override = merge_ai_model_configuration_v2_secrets(
|
||||
v2_override,
|
||||
resolved_config.organization_configuration,
|
||||
)
|
||||
check_for_masked_keys_in_ai_model_configuration_v2(v2_override)
|
||||
effective = compile_ai_model_configuration_v2(v2_override)
|
||||
await UserConfigurationValidator().validate(
|
||||
effective,
|
||||
organization_id=user.selected_organization_id,
|
||||
created_by=user.provider_id,
|
||||
)
|
||||
except (ValidationError, ValueError) as e:
|
||||
raise HTTPException(status_code=422, detail=str(e))
|
||||
workflow_configurations = {
|
||||
**workflow_configurations,
|
||||
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY: v2_override.model_dump(
|
||||
mode="json",
|
||||
exclude_none=True,
|
||||
),
|
||||
}
|
||||
workflow_configurations.pop("model_overrides", None)
|
||||
elif workflow_configurations and workflow_configurations.get("model_overrides"):
|
||||
existing_workflow = await db_client.get_workflow(
|
||||
workflow_id, organization_id=user.selected_organization_id
|
||||
)
|
||||
|
|
@ -978,24 +1052,48 @@ async def update_workflow(
|
|||
workflow_configurations,
|
||||
existing_configs,
|
||||
)
|
||||
user_config = await db_client.get_user_configurations(user.id)
|
||||
resolved_config = await get_resolved_ai_model_configuration(
|
||||
user_id=user.id,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
effective_config = resolved_config.effective
|
||||
try:
|
||||
enriched_overrides = enrich_overrides_with_api_keys(
|
||||
workflow_configurations["model_overrides"],
|
||||
user_config,
|
||||
effective_config,
|
||||
)
|
||||
effective = resolve_effective_config(user_config, enriched_overrides)
|
||||
await UserConfigurationValidator().validate(
|
||||
effective,
|
||||
organization_id=user.selected_organization_id,
|
||||
created_by=user.provider_id,
|
||||
effective = resolve_effective_config(
|
||||
effective_config, enriched_overrides
|
||||
)
|
||||
if resolved_config.source == "organization_v2":
|
||||
v2_override = convert_legacy_ai_model_configuration_to_v2(effective)
|
||||
await UserConfigurationValidator().validate(
|
||||
compile_ai_model_configuration_v2(v2_override),
|
||||
organization_id=user.selected_organization_id,
|
||||
created_by=user.provider_id,
|
||||
)
|
||||
else:
|
||||
await UserConfigurationValidator().validate(
|
||||
effective,
|
||||
organization_id=user.selected_organization_id,
|
||||
created_by=user.provider_id,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=422, detail=str(e))
|
||||
workflow_configurations = {
|
||||
**workflow_configurations,
|
||||
"model_overrides": enriched_overrides,
|
||||
}
|
||||
if resolved_config.source == "organization_v2":
|
||||
workflow_configurations = {
|
||||
**workflow_configurations,
|
||||
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY: v2_override.model_dump(
|
||||
mode="json",
|
||||
exclude_none=True,
|
||||
),
|
||||
}
|
||||
workflow_configurations.pop("model_overrides", None)
|
||||
else:
|
||||
workflow_configurations = {
|
||||
**workflow_configurations,
|
||||
"model_overrides": enriched_overrides,
|
||||
}
|
||||
|
||||
# Reject upfront if any new trigger path collides with another
|
||||
# workflow's trigger — keeps the workflow record from
|
||||
|
|
@ -1171,22 +1269,7 @@ async def get_workflow_run(
|
|||
"transcript_public_url": artifact_url(public_access_token, "transcript"),
|
||||
"recording_public_url": artifact_url(public_access_token, "recording"),
|
||||
"public_access_token": public_access_token,
|
||||
"cost_info": {
|
||||
"dograh_token_usage": (
|
||||
run.cost_info.get("dograh_token_usage")
|
||||
if run.cost_info and "dograh_token_usage" in run.cost_info
|
||||
else round(float(run.cost_info.get("total_cost_usd", 0)) * 100, 2)
|
||||
if run.cost_info and "total_cost_usd" in run.cost_info
|
||||
else 0
|
||||
),
|
||||
"call_duration_seconds": int(
|
||||
round(run.cost_info.get("call_duration_seconds"))
|
||||
)
|
||||
if run.cost_info and run.cost_info.get("call_duration_seconds") is not None
|
||||
else None,
|
||||
}
|
||||
if run.cost_info
|
||||
else None,
|
||||
"cost_info": format_public_cost_info(run.cost_info, run.usage_info),
|
||||
"usage_info": format_public_usage_info(run.usage_info),
|
||||
"created_at": run.created_at,
|
||||
"definition_id": run.definition_id,
|
||||
|
|
|
|||
|
|
@ -9,8 +9,8 @@ from pydantic import BaseModel, Field
|
|||
from api.db import db_client
|
||||
from api.db.models import UserModel, WorkflowRunTextSessionModel
|
||||
from api.enums import WorkflowRunMode
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.quota_service import check_dograh_quota
|
||||
from api.services.auth.depends import get_user_with_selected_organization
|
||||
from api.services.quota_service import authorize_workflow_run_start
|
||||
from api.services.workflow.text_chat_session_service import (
|
||||
TextChatPendingTurnLostError,
|
||||
TextChatSessionExecutionError,
|
||||
|
|
@ -96,14 +96,16 @@ def _revision_conflict_detail(e: Any) -> dict[str, Any]:
|
|||
}
|
||||
|
||||
|
||||
def _require_selected_organization_id(user: UserModel) -> int:
|
||||
if user.selected_organization_id is None:
|
||||
raise HTTPException(status_code=403, detail="Organization context is required")
|
||||
return user.selected_organization_id
|
||||
|
||||
|
||||
async def _ensure_text_chat_quota(user: UserModel, workflow_id: int) -> None:
|
||||
quota_result = await check_dograh_quota(user, workflow_id=workflow_id)
|
||||
async def _ensure_text_chat_quota(
|
||||
user: UserModel,
|
||||
workflow_id: int,
|
||||
workflow_run_id: int,
|
||||
) -> None:
|
||||
quota_result = await authorize_workflow_run_start(
|
||||
workflow_id=workflow_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
actor_user=user,
|
||||
)
|
||||
if not quota_result.has_quota:
|
||||
raise HTTPException(status_code=402, detail=quota_result.error_message)
|
||||
|
||||
|
|
@ -114,9 +116,8 @@ async def _load_text_session_or_404(
|
|||
user: UserModel,
|
||||
) -> WorkflowRunTextSessionModel:
|
||||
set_current_run_id(run_id)
|
||||
organization_id = _require_selected_organization_id(user)
|
||||
text_session = await db_client.get_workflow_run_text_session(
|
||||
run_id, organization_id=organization_id
|
||||
run_id, organization_id=user.selected_organization_id
|
||||
)
|
||||
if not text_session or not text_session.workflow_run:
|
||||
raise HTTPException(status_code=404, detail="Text chat session not found")
|
||||
|
|
@ -158,11 +159,8 @@ async def _execute_pending_turn_response(
|
|||
async def create_text_chat_session(
|
||||
workflow_id: int,
|
||||
request: CreateTextChatSessionRequest,
|
||||
user: UserModel = Depends(get_user),
|
||||
user: UserModel = Depends(get_user_with_selected_organization),
|
||||
) -> WorkflowRunTextSessionResponse:
|
||||
organization_id = _require_selected_organization_id(user)
|
||||
await _ensure_text_chat_quota(user, workflow_id)
|
||||
|
||||
session_name = request.name or f"WR-TEXT-{uuid4().hex[:6].upper()}"
|
||||
try:
|
||||
workflow_run = await db_client.create_workflow_run(
|
||||
|
|
@ -172,12 +170,13 @@ async def create_text_chat_session(
|
|||
user_id=user.id,
|
||||
initial_context=request.initial_context,
|
||||
use_draft=True,
|
||||
organization_id=organization_id,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
set_current_run_id(workflow_run.id)
|
||||
await _ensure_text_chat_quota(user, workflow_id, workflow_run.id)
|
||||
|
||||
annotations = {
|
||||
"tester": {
|
||||
|
|
@ -220,7 +219,7 @@ async def create_text_chat_session(
|
|||
async def get_text_chat_session(
|
||||
workflow_id: int,
|
||||
run_id: int,
|
||||
user: UserModel = Depends(get_user),
|
||||
user: UserModel = Depends(get_user_with_selected_organization),
|
||||
) -> WorkflowRunTextSessionResponse:
|
||||
text_session = await _load_text_session_or_404(workflow_id, run_id, user)
|
||||
return _build_response(text_session)
|
||||
|
|
@ -234,10 +233,10 @@ async def append_text_chat_message(
|
|||
workflow_id: int,
|
||||
run_id: int,
|
||||
request: AppendTextChatMessageRequest,
|
||||
user: UserModel = Depends(get_user),
|
||||
user: UserModel = Depends(get_user_with_selected_organization),
|
||||
) -> WorkflowRunTextSessionResponse:
|
||||
text_session = await _load_text_session_or_404(workflow_id, run_id, user)
|
||||
await _ensure_text_chat_quota(user, workflow_id)
|
||||
await _ensure_text_chat_quota(user, workflow_id, run_id)
|
||||
|
||||
try:
|
||||
text_session = await append_text_chat_user_message(
|
||||
|
|
@ -264,7 +263,7 @@ async def rewind_text_chat_session(
|
|||
workflow_id: int,
|
||||
run_id: int,
|
||||
request: RewindTextChatSessionRequest,
|
||||
user: UserModel = Depends(get_user),
|
||||
user: UserModel = Depends(get_user_with_selected_organization),
|
||||
) -> WorkflowRunTextSessionResponse:
|
||||
text_session = await _load_text_session_or_404(workflow_id, run_id, user)
|
||||
try:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue