feat: add config v2 to simplify billing (#428)

* feat: add model config v2

* chore: centralize user org selection

* chore: move preferences to platform settings

* fix: decouple org preference and ai model preferences
This commit is contained in:
Abhishek 2026-06-09 16:10:26 +05:30 committed by GitHub
parent 49e68b49d5
commit cdbd06c8d9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
42 changed files with 5135 additions and 264 deletions

View file

@ -10,6 +10,7 @@ from sqlalchemy.orm import joinedload
from api.db.base_client import BaseDBClient
from api.db.filters import apply_workflow_run_filters
from api.db.models import (
OrganizationConfigurationModel,
OrganizationModel,
OrganizationUsageCycleModel,
UserConfigurationModel,
@ -17,6 +18,7 @@ from api.db.models import (
WorkflowModel,
WorkflowRunModel,
)
from api.enums import OrganizationConfigurationKey
from api.schemas.user_configuration import UserConfiguration
@ -440,8 +442,29 @@ class OrganizationUsageClient(BaseDBClient):
"""Get daily usage breakdown for an organization with pricing."""
async with self.async_session() as session:
# Get user timezone if user_id is provided
# Get org timezone preference first, then fall back to legacy user config.
user_timezone = "UTC" # Default timezone
pref_result = await session.execute(
select(OrganizationConfigurationModel).where(
OrganizationConfigurationModel.organization_id == organization_id,
OrganizationConfigurationModel.key.in_(
[
OrganizationConfigurationKey.ORGANIZATION_PREFERENCES.value,
OrganizationConfigurationKey.MODEL_CONFIGURATION_PREFERENCES.value,
]
),
)
)
pref_rows = pref_result.scalars().all()
pref_by_key = {pref.key: pref for pref in pref_rows}
pref_obj = pref_by_key.get(
OrganizationConfigurationKey.ORGANIZATION_PREFERENCES.value
) or pref_by_key.get(
OrganizationConfigurationKey.MODEL_CONFIGURATION_PREFERENCES.value
)
if pref_obj and pref_obj.value:
user_timezone = pref_obj.value.get("timezone") or user_timezone
if user_id:
config_result = await session.execute(
select(UserConfigurationModel).where(
@ -453,7 +476,7 @@ class OrganizationUsageClient(BaseDBClient):
user_config = UserConfiguration.model_validate(
config_obj.configuration
)
if user_config.timezone:
if user_config.timezone and user_timezone == "UTC":
user_timezone = user_config.timezone
# Validate timezone string

View file

@ -89,6 +89,11 @@ class OrganizationConfigurationKey(Enum):
LANGFUSE_CREDENTIALS = (
"LANGFUSE_CREDENTIALS" # Org-level Langfuse tracing credentials
)
MODEL_CONFIGURATION_V2 = (
"MODEL_CONFIGURATION_V2" # Org-level v2 AI model configuration
)
ORGANIZATION_PREFERENCES = "ORGANIZATION_PREFERENCES" # Org-level defaults such as timezone/test call number
MODEL_CONFIGURATION_PREFERENCES = "MODEL_CONFIGURATION_PREFERENCES" # Deprecated; read fallback for old org preferences
class WorkflowStatus(Enum):

View file

@ -3,9 +3,12 @@ from loguru import logger
from api.db import db_client
from api.db.models import UserModel
from api.enums import PostHogEvent
from api.enums import OrganizationConfigurationKey, PostHogEvent
from api.schemas.auth import AuthResponse, LoginRequest, SignupRequest, UserResponse
from api.services.auth.depends import create_user_configuration_with_mps_key, get_user
from api.services.configuration.ai_model_configuration import (
convert_legacy_ai_model_configuration_to_v2,
)
from api.services.posthog_client import capture_event
from api.utils.auth import create_jwt_token, hash_password, verify_password
@ -47,6 +50,12 @@ async def signup(request: SignupRequest):
)
if mps_config:
await db_client.update_user_configuration(user.id, mps_config)
model_config_v2 = convert_legacy_ai_model_configuration_to_v2(mps_config)
await db_client.upsert_configuration(
organization.id,
OrganizationConfigurationKey.MODEL_CONFIGURATION_V2.value,
model_config_v2.model_dump(mode="json", exclude_none=True),
)
except Exception:
logger.warning(
"Failed to create default configuration for OSS user", exc_info=True

View file

@ -369,6 +369,10 @@ async def search_chunks(
try:
# Import here to avoid circular dependency
from api.services.configuration.ai_model_configuration import (
apply_managed_embeddings_base_url,
get_resolved_ai_model_configuration,
)
from api.services.configuration.registry import ServiceProviders
from api.services.gen_ai import (
AzureOpenAIEmbeddingService,
@ -376,10 +380,15 @@ async def search_chunks(
)
# Try to get user's embeddings configuration
user_config = await db_client.get_user_configurations(user.id)
resolved_config = await get_resolved_ai_model_configuration(
user_id=user.id,
organization_id=user.selected_organization_id,
)
user_config = resolved_config.effective
embeddings_api_key = None
embeddings_model = None
embeddings_provider = None
embeddings_base_url = None
embeddings_endpoint = None
embeddings_api_version = None
@ -388,6 +397,10 @@ async def search_chunks(
embeddings_model = user_config.embeddings.model
embeddings_provider = getattr(user_config.embeddings, "provider", None)
embeddings_endpoint = getattr(user_config.embeddings, "endpoint", None)
embeddings_base_url = apply_managed_embeddings_base_url(
provider=embeddings_provider,
base_url=getattr(user_config.embeddings, "base_url", None),
)
embeddings_api_version = getattr(
user_config.embeddings, "api_version", None
)
@ -406,9 +419,7 @@ async def search_chunks(
db_client=db_client,
api_key=embeddings_api_key,
model_id=embeddings_model or "text-embedding-3-small",
base_url=getattr(user_config.embeddings, "base_url", None)
if user_config.embeddings
else None,
base_url=embeddings_base_url,
)
# Perform search

View file

@ -1,6 +1,6 @@
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends, HTTPException, Query
from loguru import logger
from pydantic import BaseModel
from sqlalchemy.exc import IntegrityError
@ -10,6 +10,14 @@ from api.db import db_client
from api.db.models import UserModel
from api.db.telephony_configuration_client import TelephonyConfigurationInUseError
from api.enums import OrganizationConfigurationKey, PostHogEvent
from api.schemas.ai_model_configuration import (
DOGRAH_DEFAULT_LANGUAGE,
DOGRAH_DEFAULT_VOICE,
DOGRAH_SPEED_OPTIONS,
OrganizationAIModelConfigurationResponse,
OrganizationAIModelConfigurationV2,
)
from api.schemas.organization_preferences import OrganizationPreferences
from api.schemas.telephony_config import (
TelephonyConfigRequest,
TelephonyConfigurationCreateRequest,
@ -26,8 +34,31 @@ from api.schemas.telephony_phone_number import (
PhoneNumberUpdateRequest,
ProviderSyncStatus,
)
from api.services.auth.depends import get_user
from api.services.configuration.masking import is_mask_of, mask_key
from api.services.auth.depends import get_user, get_user_with_selected_organization
from api.services.configuration.ai_model_configuration import (
check_for_masked_keys_in_ai_model_configuration_v2,
compile_ai_model_configuration_v2,
convert_legacy_ai_model_configuration_to_v2,
get_organization_ai_model_configuration_v2,
get_resolved_ai_model_configuration,
mask_ai_model_configuration_v2,
merge_ai_model_configuration_v2_secrets,
migrate_workflow_model_configurations_to_v2,
upsert_organization_ai_model_configuration_v2,
)
from api.services.configuration.check_validity import UserConfigurationValidator
from api.services.configuration.defaults import DEFAULT_SERVICE_PROVIDERS
from api.services.configuration.masking import is_mask_of, mask_key, mask_user_config
from api.services.configuration.registry import (
DOGRAH_STT_LANGUAGES,
REGISTRY,
ServiceProviders,
ServiceType,
)
from api.services.organization_preferences import (
get_organization_preferences,
upsert_organization_preferences,
)
from api.services.posthog_client import capture_event
from api.services.telephony import registry as telephony_registry
from api.services.telephony.factory import get_telephony_provider_by_id
@ -159,6 +190,222 @@ async def get_telephony_config_warnings(user: UserModel = Depends(get_user)):
)
# ---------------------------------------------------------------------------
# AI model configurations v2
# ---------------------------------------------------------------------------
def _byok_provider_schemas(service_type: ServiceType) -> dict[str, dict]:
return {
provider: model_cls.model_json_schema()
for provider, model_cls in REGISTRY[service_type].items()
if provider != ServiceProviders.DOGRAH.value
}
async def _model_configuration_v2_response(
*,
user: UserModel,
configuration: OrganizationAIModelConfigurationV2 | None = None,
) -> OrganizationAIModelConfigurationResponse:
resolved = await get_resolved_ai_model_configuration(
user_id=user.id,
organization_id=user.selected_organization_id,
)
raw_configuration = (
configuration
if configuration is not None
else resolved.organization_configuration
)
return OrganizationAIModelConfigurationResponse(
configuration=mask_ai_model_configuration_v2(raw_configuration),
effective_configuration=mask_user_config(resolved.effective),
source=resolved.source,
)
@router.get("/model-configurations/v2/defaults")
async def get_model_configuration_v2_defaults(
user: UserModel = Depends(get_user_with_selected_organization),
):
byok_default_providers = {
service: provider
for service, provider in DEFAULT_SERVICE_PROVIDERS.items()
if provider != ServiceProviders.DOGRAH.value
}
return {
"dograh": {
"voices": [DOGRAH_DEFAULT_VOICE],
"speeds": list(DOGRAH_SPEED_OPTIONS),
"languages": DOGRAH_STT_LANGUAGES,
"defaults": {
"voice": DOGRAH_DEFAULT_VOICE,
"speed": 1.0,
"language": DOGRAH_DEFAULT_LANGUAGE,
},
},
"byok": {
"pipeline": {
"llm": _byok_provider_schemas(ServiceType.LLM),
"tts": _byok_provider_schemas(ServiceType.TTS),
"stt": _byok_provider_schemas(ServiceType.STT),
"embeddings": _byok_provider_schemas(ServiceType.EMBEDDINGS),
"default_providers": byok_default_providers,
},
"realtime": {
"realtime": _byok_provider_schemas(ServiceType.REALTIME),
"llm": _byok_provider_schemas(ServiceType.LLM),
"embeddings": _byok_provider_schemas(ServiceType.EMBEDDINGS),
"default_providers": byok_default_providers,
},
},
}
@router.get(
"/model-configurations/v2",
response_model=OrganizationAIModelConfigurationResponse,
)
async def get_model_configuration_v2(
user: UserModel = Depends(get_user_with_selected_organization),
):
return await _model_configuration_v2_response(user=user)
@router.put(
"/model-configurations/v2",
response_model=OrganizationAIModelConfigurationResponse,
)
async def save_model_configuration_v2(
request: OrganizationAIModelConfigurationV2,
user: UserModel = Depends(get_user_with_selected_organization),
):
organization_id = user.selected_organization_id
existing = await get_organization_ai_model_configuration_v2(organization_id)
configuration = merge_ai_model_configuration_v2_secrets(request, existing)
try:
check_for_masked_keys_in_ai_model_configuration_v2(configuration)
effective = compile_ai_model_configuration_v2(configuration)
await UserConfigurationValidator().validate(
effective,
organization_id=organization_id,
created_by=user.provider_id,
)
except ValueError as exc:
raise HTTPException(status_code=422, detail=exc.args[0])
await upsert_organization_ai_model_configuration_v2(
organization_id,
configuration,
)
return await _model_configuration_v2_response(
user=user,
configuration=configuration,
)
@router.get("/model-configurations/v2/migration-preview")
async def preview_model_configuration_v2_migration(
user: UserModel = Depends(get_user_with_selected_organization),
):
legacy = await db_client.get_user_configurations(user.id)
try:
configuration = convert_legacy_ai_model_configuration_to_v2(legacy)
except ValueError as exc:
raise HTTPException(status_code=422, detail=str(exc))
return {
"configuration": mask_ai_model_configuration_v2(configuration),
"effective_configuration": mask_user_config(
compile_ai_model_configuration_v2(configuration)
),
}
@router.post(
"/model-configurations/v2/migrate",
response_model=OrganizationAIModelConfigurationResponse,
)
async def migrate_model_configuration_v2(
force: bool = Query(default=False),
user: UserModel = Depends(get_user_with_selected_organization),
):
organization_id = user.selected_organization_id
existing = await get_organization_ai_model_configuration_v2(organization_id)
if existing is not None and not force:
raise HTTPException(
status_code=409,
detail="Organization already has a v2 model configuration",
)
legacy = await db_client.get_user_configurations(user.id)
try:
configuration = convert_legacy_ai_model_configuration_to_v2(legacy)
effective = compile_ai_model_configuration_v2(configuration)
await UserConfigurationValidator().validate(
effective,
organization_id=organization_id,
created_by=user.provider_id,
)
except ValueError as exc:
raise HTTPException(status_code=422, detail=exc.args[0])
await upsert_organization_ai_model_configuration_v2(
organization_id,
configuration,
)
await migrate_workflow_model_configurations_to_v2(
organization_id=organization_id,
fallback_user_config=legacy,
)
return await _model_configuration_v2_response(
user=user,
configuration=configuration,
)
@router.get("/preferences", response_model=OrganizationPreferences)
async def get_preferences(
user: UserModel = Depends(get_user_with_selected_organization),
):
organization_id = user.selected_organization_id
return await get_organization_preferences(organization_id)
@router.put("/preferences", response_model=OrganizationPreferences)
async def save_preferences(
request: OrganizationPreferences,
user: UserModel = Depends(get_user_with_selected_organization),
):
organization_id = user.selected_organization_id
return await upsert_organization_preferences(
organization_id,
request,
)
@router.get(
"/model-configurations/preferences",
response_model=OrganizationPreferences,
include_in_schema=False,
)
async def get_model_configuration_preferences_legacy(
user: UserModel = Depends(get_user_with_selected_organization),
):
return await get_preferences(user=user)
@router.put(
"/model-configurations/preferences",
response_model=OrganizationPreferences,
include_in_schema=False,
)
async def save_model_configuration_preferences_legacy(
request: OrganizationPreferences,
user: UserModel = Depends(get_user_with_selected_organization),
):
return await save_preferences(request=request, user=user)
def preserve_masked_fields(provider: str, request_dict: dict, existing: dict):
"""If the client re-submitted a masked sensitive field, restore the original."""
for field_name in _sensitive_fields(provider):

View file

@ -53,7 +53,7 @@ class InitiateCallRequest(BaseModel):
workflow_run_id: int | None = None
phone_number: str | None = None
# Optional explicit telephony config to use for the test call. If omitted,
# falls back to the user's per-user default (when set), then the org default.
# falls back to the org default.
telephony_configuration_id: int | None = None
# Optional caller-ID phone number to dial out from. Must belong to the
# resolved telephony configuration; otherwise the provider picks one.
@ -82,7 +82,12 @@ async def initiate_call(
"""Initiate a call using the configured telephony provider from web browser. This is
supposed to be a test call method for the draft version of the agent."""
user_configuration = await db_client.get_user_configurations(user.id)
from api.services.organization_preferences import get_organization_preferences
preferences = await get_organization_preferences(
user.selected_organization_id,
db=db_client,
)
# Resolve which telephony config to use: explicit request value, otherwise
# the org's default outbound config.
@ -116,13 +121,12 @@ async def initiate_call(
detail="telephony_not_configured",
)
phone_number = request.phone_number or user_configuration.test_phone_number
phone_number = request.phone_number or preferences.test_phone_number
if not phone_number:
raise HTTPException(
status_code=400,
detail="Phone number must be provided in request or set in user "
"configuration",
detail="Phone number must be provided in request or set in organization preferences",
)
workflow = await db_client.get_workflow(

View file

@ -10,6 +10,9 @@ from api.db.models import (
UserModel,
)
from api.services.auth.depends import get_user
from api.services.configuration.ai_model_configuration import (
get_resolved_ai_model_configuration,
)
from api.services.configuration.check_validity import (
APIKeyStatusResponse,
UserConfigurationValidator,
@ -19,6 +22,10 @@ from api.services.configuration.masking import check_for_masked_keys, mask_user_
from api.services.configuration.merge import merge_user_configurations
from api.services.configuration.registry import REGISTRY, ServiceType
from api.services.mps_service_key_client import mps_service_key_client
from api.services.organization_preferences import (
get_organization_preferences,
upsert_organization_preferences,
)
router = APIRouter(prefix="/user")
@ -91,8 +98,17 @@ class UserConfigurationRequestResponseSchema(BaseModel):
async def get_user_configurations(
user: UserModel = Depends(get_user),
) -> UserConfigurationRequestResponseSchema:
user_configurations = await db_client.get_user_configurations(user.id)
masked_config = mask_user_config(user_configurations)
resolved_config = await get_resolved_ai_model_configuration(
user_id=user.id,
organization_id=user.selected_organization_id,
)
masked_config = mask_user_config(resolved_config.effective)
if user.selected_organization_id:
preferences = await get_organization_preferences(user.selected_organization_id)
if preferences.test_phone_number is not None:
masked_config["test_phone_number"] = preferences.test_phone_number
if preferences.timezone is not None:
masked_config["timezone"] = preferences.timezone
# Add organization pricing info if available
if user.selected_organization_id:
@ -118,34 +134,61 @@ async def update_user_configurations(
# Remove organization_pricing from incoming dict as it's read-only
incoming_dict.pop("organization_pricing", None)
preferences_update = {
key: incoming_dict.pop(key)
for key in ("test_phone_number", "timezone")
if key in incoming_dict
}
# Merge via helper
try:
user_configurations = merge_user_configurations(existing_config, incoming_dict)
except ValidationError as e:
raise HTTPException(status_code=422, detail=str(e))
if incoming_dict:
# Merge via helper
try:
user_configurations = merge_user_configurations(
existing_config, incoming_dict
)
except ValidationError as e:
raise HTTPException(status_code=422, detail=str(e))
try:
check_for_masked_keys(user_configurations)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
try:
check_for_masked_keys(user_configurations)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
try:
validator = UserConfigurationValidator()
await validator.validate(
user_configurations,
organization_id=user.selected_organization_id,
created_by=user.provider_id,
try:
validator = UserConfigurationValidator()
await validator.validate(
user_configurations,
organization_id=user.selected_organization_id,
created_by=user.provider_id,
)
except ValueError as e:
raise HTTPException(status_code=422, detail=e.args[0])
user_configurations = await db_client.update_user_configuration(
user.id, user_configurations
)
except ValueError as e:
raise HTTPException(status_code=422, detail=e.args[0])
else:
user_configurations = existing_config
user_configurations = await db_client.update_user_configuration(
user.id, user_configurations
)
if user.selected_organization_id and preferences_update:
preferences = await get_organization_preferences(user.selected_organization_id)
if "test_phone_number" in preferences_update:
preferences.test_phone_number = preferences_update["test_phone_number"]
if "timezone" in preferences_update:
preferences.timezone = preferences_update["timezone"]
await upsert_organization_preferences(
user.selected_organization_id,
preferences,
)
# Return masked version of updated config
masked_config = mask_user_config(user_configurations)
if user.selected_organization_id:
preferences = await get_organization_preferences(user.selected_organization_id)
if preferences.test_phone_number is not None:
masked_config["test_phone_number"] = preferences.test_phone_number
if preferences.timezone is not None:
masked_config["timezone"] = preferences.timezone
# Add organization pricing info if available
if user.selected_organization_id:
@ -165,7 +208,11 @@ async def validate_user_configurations(
validity_ttl_seconds: int = Query(default=60, ge=0, le=86400),
user: UserModel = Depends(get_user),
) -> APIKeyStatusResponse:
configurations = await db_client.get_user_configurations(user.id)
resolved_config = await get_resolved_ai_model_configuration(
user_id=user.id,
organization_id=user.selected_organization_id,
)
configurations = resolved_config.effective
if (
configurations.last_validated_at

View file

@ -16,9 +16,18 @@ from api.db.agent_trigger_client import TriggerPathConflictError
from api.db.models import UserModel
from api.db.workflow_template_client import WorkflowTemplateClient
from api.enums import CallType, PostHogEvent, StorageBackend
from api.schemas.ai_model_configuration import OrganizationAIModelConfigurationV2
from api.schemas.workflow import WorkflowRunResponseSchema
from api.sdk_expose import sdk_expose
from api.services.auth.depends import get_user
from api.services.configuration.ai_model_configuration import (
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY,
check_for_masked_keys_in_ai_model_configuration_v2,
compile_ai_model_configuration_v2,
convert_legacy_ai_model_configuration_to_v2,
get_resolved_ai_model_configuration,
merge_ai_model_configuration_v2_secrets,
)
from api.services.configuration.check_validity import UserConfigurationValidator
from api.services.configuration.masking import (
mask_workflow_configurations,
@ -955,12 +964,74 @@ async def update_workflow(
existing_def,
)
# Validate model_overrides: resolve onto global config, then
# run the same validator used by the user-configurations endpoint.
# Also stamp the current global API key into the override so the override
# remains functional if the global config later switches to a different provider.
# Validate model overrides. v2 uses a complete workflow-level model
# configuration; legacy v1 uses partial service overlays.
workflow_configurations = request.workflow_configurations
if workflow_configurations and workflow_configurations.get("model_overrides"):
if workflow_configurations and workflow_configurations.get(
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY
):
existing_workflow = await db_client.get_workflow(
workflow_id, organization_id=user.selected_organization_id
)
if existing_workflow is None:
raise HTTPException(
status_code=404, detail=f"Workflow with id {workflow_id} not found"
)
existing_draft = await db_client.get_draft_version(workflow_id)
existing_configs = (
existing_draft.workflow_configurations
if existing_draft
else existing_workflow.released_definition.workflow_configurations
)
existing_v2_override = (existing_configs or {}).get(
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY
)
try:
incoming_v2_override = (
OrganizationAIModelConfigurationV2.model_validate(
workflow_configurations[
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY
]
)
)
existing_v2_override_config = (
OrganizationAIModelConfigurationV2.model_validate(
existing_v2_override
)
if existing_v2_override
else None
)
v2_override = merge_ai_model_configuration_v2_secrets(
incoming_v2_override,
existing_v2_override_config,
)
if existing_v2_override_config is None:
resolved_config = await get_resolved_ai_model_configuration(
user_id=user.id,
organization_id=user.selected_organization_id,
)
v2_override = merge_ai_model_configuration_v2_secrets(
v2_override,
resolved_config.organization_configuration,
)
check_for_masked_keys_in_ai_model_configuration_v2(v2_override)
effective = compile_ai_model_configuration_v2(v2_override)
await UserConfigurationValidator().validate(
effective,
organization_id=user.selected_organization_id,
created_by=user.provider_id,
)
except (ValidationError, ValueError) as e:
raise HTTPException(status_code=422, detail=str(e))
workflow_configurations = {
**workflow_configurations,
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY: v2_override.model_dump(
mode="json",
exclude_none=True,
),
}
workflow_configurations.pop("model_overrides", None)
elif workflow_configurations and workflow_configurations.get("model_overrides"):
existing_workflow = await db_client.get_workflow(
workflow_id, organization_id=user.selected_organization_id
)
@ -978,24 +1049,46 @@ async def update_workflow(
workflow_configurations,
existing_configs,
)
user_config = await db_client.get_user_configurations(user.id)
resolved_config = await get_resolved_ai_model_configuration(
user_id=user.id,
organization_id=user.selected_organization_id,
)
user_config = resolved_config.effective
try:
enriched_overrides = enrich_overrides_with_api_keys(
workflow_configurations["model_overrides"],
user_config,
)
effective = resolve_effective_config(user_config, enriched_overrides)
await UserConfigurationValidator().validate(
effective,
organization_id=user.selected_organization_id,
created_by=user.provider_id,
)
if resolved_config.source == "organization_v2":
v2_override = convert_legacy_ai_model_configuration_to_v2(effective)
await UserConfigurationValidator().validate(
compile_ai_model_configuration_v2(v2_override),
organization_id=user.selected_organization_id,
created_by=user.provider_id,
)
else:
await UserConfigurationValidator().validate(
effective,
organization_id=user.selected_organization_id,
created_by=user.provider_id,
)
except ValueError as e:
raise HTTPException(status_code=422, detail=str(e))
workflow_configurations = {
**workflow_configurations,
"model_overrides": enriched_overrides,
}
if resolved_config.source == "organization_v2":
workflow_configurations = {
**workflow_configurations,
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY: v2_override.model_dump(
mode="json",
exclude_none=True,
),
}
workflow_configurations.pop("model_overrides", None)
else:
workflow_configurations = {
**workflow_configurations,
"model_overrides": enriched_overrides,
}
# Reject upfront if any new trigger path collides with another
# workflow's trigger — keeps the workflow record from

View file

@ -9,7 +9,7 @@ from pydantic import BaseModel, Field
from api.db import db_client
from api.db.models import UserModel, WorkflowRunTextSessionModel
from api.enums import WorkflowRunMode
from api.services.auth.depends import get_user
from api.services.auth.depends import get_user_with_selected_organization
from api.services.quota_service import check_dograh_quota
from api.services.workflow.text_chat_session_service import (
TextChatPendingTurnLostError,
@ -96,12 +96,6 @@ def _revision_conflict_detail(e: Any) -> dict[str, Any]:
}
def _require_selected_organization_id(user: UserModel) -> int:
if user.selected_organization_id is None:
raise HTTPException(status_code=403, detail="Organization context is required")
return user.selected_organization_id
async def _ensure_text_chat_quota(user: UserModel, workflow_id: int) -> None:
quota_result = await check_dograh_quota(user, workflow_id=workflow_id)
if not quota_result.has_quota:
@ -114,9 +108,8 @@ async def _load_text_session_or_404(
user: UserModel,
) -> WorkflowRunTextSessionModel:
set_current_run_id(run_id)
organization_id = _require_selected_organization_id(user)
text_session = await db_client.get_workflow_run_text_session(
run_id, organization_id=organization_id
run_id, organization_id=user.selected_organization_id
)
if not text_session or not text_session.workflow_run:
raise HTTPException(status_code=404, detail="Text chat session not found")
@ -158,9 +151,8 @@ async def _execute_pending_turn_response(
async def create_text_chat_session(
workflow_id: int,
request: CreateTextChatSessionRequest,
user: UserModel = Depends(get_user),
user: UserModel = Depends(get_user_with_selected_organization),
) -> WorkflowRunTextSessionResponse:
organization_id = _require_selected_organization_id(user)
await _ensure_text_chat_quota(user, workflow_id)
session_name = request.name or f"WR-TEXT-{uuid4().hex[:6].upper()}"
@ -172,7 +164,7 @@ async def create_text_chat_session(
user_id=user.id,
initial_context=request.initial_context,
use_draft=True,
organization_id=organization_id,
organization_id=user.selected_organization_id,
)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
@ -220,7 +212,7 @@ async def create_text_chat_session(
async def get_text_chat_session(
workflow_id: int,
run_id: int,
user: UserModel = Depends(get_user),
user: UserModel = Depends(get_user_with_selected_organization),
) -> WorkflowRunTextSessionResponse:
text_session = await _load_text_session_or_404(workflow_id, run_id, user)
return _build_response(text_session)
@ -234,7 +226,7 @@ async def append_text_chat_message(
workflow_id: int,
run_id: int,
request: AppendTextChatMessageRequest,
user: UserModel = Depends(get_user),
user: UserModel = Depends(get_user_with_selected_organization),
) -> WorkflowRunTextSessionResponse:
text_session = await _load_text_session_or_404(workflow_id, run_id, user)
await _ensure_text_chat_quota(user, workflow_id)
@ -264,7 +256,7 @@ async def rewind_text_chat_session(
workflow_id: int,
run_id: int,
request: RewindTextChatSessionRequest,
user: UserModel = Depends(get_user),
user: UserModel = Depends(get_user_with_selected_organization),
) -> WorkflowRunTextSessionResponse:
text_session = await _load_text_session_or_404(workflow_id, run_id, user)
try:

View file

@ -0,0 +1,170 @@
from __future__ import annotations
from typing import Literal
from pydantic import BaseModel, Field, model_validator
from api.schemas.user_configuration import EffectiveAIModelConfiguration
from api.services.configuration.registry import (
DograhEmbeddingsConfiguration,
DograhLLMService,
DograhSTTService,
DograhTTSService,
EmbeddingsConfig,
LLMConfig,
RealtimeConfig,
ServiceProviders,
STTConfig,
TTSConfig,
)
DOGRAH_SPEED_OPTIONS: tuple[float, ...] = (0.8, 1.0, 1.2)
DOGRAH_DEFAULT_VOICE = "default"
DOGRAH_DEFAULT_LANGUAGE = "multi"
class DograhManagedAIModelConfiguration(BaseModel):
api_key: str
voice: str = DOGRAH_DEFAULT_VOICE
speed: float = Field(default=1.0)
language: str = DOGRAH_DEFAULT_LANGUAGE
@model_validator(mode="after")
def validate_speed(self):
if self.speed not in DOGRAH_SPEED_OPTIONS:
allowed = ", ".join(str(speed) for speed in DOGRAH_SPEED_OPTIONS)
raise ValueError(f"Dograh speed must be one of: {allowed}")
return self
class BYOKPipelineAIModelConfiguration(BaseModel):
llm: LLMConfig
tts: TTSConfig
stt: STTConfig
embeddings: EmbeddingsConfig | None = None
@model_validator(mode="after")
def reject_dograh_providers(self):
_reject_dograh_provider("llm", self.llm)
_reject_dograh_provider("tts", self.tts)
_reject_dograh_provider("stt", self.stt)
_reject_dograh_provider("embeddings", self.embeddings)
return self
class BYOKRealtimeAIModelConfiguration(BaseModel):
realtime: RealtimeConfig
llm: LLMConfig
embeddings: EmbeddingsConfig | None = None
@model_validator(mode="after")
def reject_dograh_providers(self):
_reject_dograh_provider("llm", self.llm)
_reject_dograh_provider("embeddings", self.embeddings)
return self
class BYOKAIModelConfiguration(BaseModel):
mode: Literal["pipeline", "realtime"]
pipeline: BYOKPipelineAIModelConfiguration | None = None
realtime: BYOKRealtimeAIModelConfiguration | None = None
@model_validator(mode="after")
def validate_selected_mode(self):
if self.mode == "pipeline" and self.pipeline is None:
raise ValueError("byok.pipeline is required when byok.mode is pipeline")
if self.mode == "realtime" and self.realtime is None:
raise ValueError("byok.realtime is required when byok.mode is realtime")
return self
class OrganizationAIModelConfigurationV2(BaseModel):
version: Literal[2] = 2
mode: Literal["dograh", "byok"]
dograh: DograhManagedAIModelConfiguration | None = None
byok: BYOKAIModelConfiguration | None = None
@model_validator(mode="after")
def validate_selected_mode(self):
if self.mode == "dograh" and self.dograh is None:
raise ValueError("dograh configuration is required when mode is dograh")
if self.mode == "byok" and self.byok is None:
raise ValueError("byok configuration is required when mode is byok")
return self
class OrganizationAIModelConfigurationResponse(BaseModel):
configuration: dict | None
effective_configuration: dict
source: Literal["organization_v2", "legacy_user_v1", "empty"]
def compile_ai_model_configuration_v2(
configuration: OrganizationAIModelConfigurationV2,
) -> EffectiveAIModelConfiguration:
if configuration.mode == "dograh":
if configuration.dograh is None:
raise ValueError("dograh configuration is required")
return _compile_dograh_configuration(configuration.dograh)
if configuration.byok is None:
raise ValueError("byok configuration is required")
if configuration.byok.mode == "pipeline":
if configuration.byok.pipeline is None:
raise ValueError("byok.pipeline is required")
pipeline = configuration.byok.pipeline
return EffectiveAIModelConfiguration(
llm=pipeline.llm,
tts=pipeline.tts,
stt=pipeline.stt,
embeddings=pipeline.embeddings,
is_realtime=False,
)
if configuration.byok.realtime is None:
raise ValueError("byok.realtime is required")
realtime = configuration.byok.realtime
return EffectiveAIModelConfiguration(
llm=realtime.llm,
realtime=realtime.realtime,
embeddings=realtime.embeddings,
is_realtime=True,
)
def _compile_dograh_configuration(
configuration: DograhManagedAIModelConfiguration,
) -> EffectiveAIModelConfiguration:
return EffectiveAIModelConfiguration(
llm=DograhLLMService(
provider=ServiceProviders.DOGRAH,
api_key=configuration.api_key,
model="default",
),
tts=DograhTTSService(
provider=ServiceProviders.DOGRAH,
api_key=configuration.api_key,
model="default",
voice=configuration.voice,
speed=configuration.speed,
),
stt=DograhSTTService(
provider=ServiceProviders.DOGRAH,
api_key=configuration.api_key,
model="default",
language=configuration.language,
),
embeddings=DograhEmbeddingsConfiguration(
provider=ServiceProviders.DOGRAH,
api_key=configuration.api_key,
model="default",
),
is_realtime=False,
)
def _reject_dograh_provider(section: str, service) -> None:
if service is None:
return
if getattr(service, "provider", None) == ServiceProviders.DOGRAH:
raise ValueError(f"BYOK {section} cannot use Dograh provider")

View file

@ -0,0 +1,6 @@
from pydantic import BaseModel
class OrganizationPreferences(BaseModel):
test_phone_number: str | None = None
timezone: str | None = None

View file

@ -11,7 +11,7 @@ from api.services.configuration.registry import (
)
class UserConfiguration(BaseModel):
class EffectiveAIModelConfiguration(BaseModel):
llm: LLMConfig | None = None
stt: STTConfig | None = None
tts: TTSConfig | None = None
@ -31,3 +31,7 @@ class UserConfiguration(BaseModel):
if isinstance(realtime, dict) and not realtime.get("api_key"):
data.pop("realtime", None)
return data
# Backward-compatible alias for legacy persistence and existing call sites.
UserConfiguration = EffectiveAIModelConfiguration

View file

@ -1,7 +1,7 @@
from typing import Annotated, Optional
import httpx
from fastapi import Header, HTTPException, Query, WebSocket
from fastapi import Depends, Header, HTTPException, Query, WebSocket
from loguru import logger
from pydantic import ValidationError
@ -119,6 +119,19 @@ async def get_user(
await db_client.update_user_configuration(
user_model.id, mps_config
)
from api.enums import OrganizationConfigurationKey
from api.services.configuration.ai_model_configuration import (
convert_legacy_ai_model_configuration_to_v2,
)
model_config_v2 = convert_legacy_ai_model_configuration_to_v2(
mps_config
)
await db_client.upsert_configuration(
organization.id,
OrganizationConfigurationKey.MODEL_CONFIGURATION_V2.value,
model_config_v2.model_dump(mode="json", exclude_none=True),
)
except Exception as exc:
raise HTTPException(
@ -129,6 +142,14 @@ async def get_user(
return user_model
async def get_user_with_selected_organization(
user: Annotated[UserModel, Depends(get_user)],
) -> UserModel:
if not user.selected_organization_id:
raise HTTPException(status_code=400, detail="No organization selected")
return user
async def _handle_oss_auth(authorization: str | None) -> UserModel:
"""
Handle authentication for OSS deployment mode.

View file

@ -0,0 +1,484 @@
from __future__ import annotations
import copy
from dataclasses import dataclass
from typing import Literal
from loguru import logger
from pydantic import ValidationError
from sqlalchemy import select, update
from sqlalchemy.orm import selectinload
from api.constants import MPS_API_URL
from api.db import db_client
from api.db.models import WorkflowDefinitionModel, WorkflowModel
from api.enums import OrganizationConfigurationKey
from api.schemas.ai_model_configuration import (
DOGRAH_DEFAULT_LANGUAGE,
DOGRAH_DEFAULT_VOICE,
DOGRAH_SPEED_OPTIONS,
BYOKAIModelConfiguration,
BYOKPipelineAIModelConfiguration,
BYOKRealtimeAIModelConfiguration,
DograhManagedAIModelConfiguration,
OrganizationAIModelConfigurationV2,
compile_ai_model_configuration_v2,
)
from api.schemas.user_configuration import EffectiveAIModelConfiguration
from api.services.configuration.masking import (
SERVICE_SECRET_FIELDS,
contains_masked_key,
mask_key,
resolve_masked_api_keys,
)
from api.services.configuration.registry import ServiceProviders
from api.services.configuration.resolve import resolve_effective_config
AIModelConfigurationSource = Literal["organization_v2", "legacy_user_v1", "empty"]
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY = "model_configuration_v2_override"
@dataclass
class ResolvedAIModelConfiguration:
effective: EffectiveAIModelConfiguration
source: AIModelConfigurationSource
organization_configuration: OrganizationAIModelConfigurationV2 | None = None
@dataclass
class WorkflowAIModelConfigurationMigrationResult:
workflow_count: int = 0
definition_count: int = 0
workflow_ids: list[int] | None = None
async def get_resolved_ai_model_configuration(
*,
user_id: int | None,
organization_id: int | None,
) -> ResolvedAIModelConfiguration:
organization_configuration = await get_organization_ai_model_configuration_v2(
organization_id
)
if organization_configuration is not None:
return ResolvedAIModelConfiguration(
effective=compile_ai_model_configuration_v2(organization_configuration),
source="organization_v2",
organization_configuration=organization_configuration,
)
if user_id is None:
return ResolvedAIModelConfiguration(
effective=EffectiveAIModelConfiguration(),
source="empty",
)
legacy = await db_client.get_user_configurations(user_id)
return ResolvedAIModelConfiguration(
effective=legacy,
source="legacy_user_v1" if _has_model_services(legacy) else "empty",
)
async def get_effective_ai_model_configuration_for_workflow(
*,
user_id: int | None,
organization_id: int | None,
workflow_configurations: dict | None,
) -> EffectiveAIModelConfiguration:
workflow_configurations = workflow_configurations or {}
v2_override = workflow_configurations.get(
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY
)
if v2_override:
return compile_ai_model_configuration_v2(
OrganizationAIModelConfigurationV2.model_validate(v2_override)
)
resolved_config = await get_resolved_ai_model_configuration(
user_id=user_id,
organization_id=organization_id,
)
return resolve_effective_config(
resolved_config.effective,
workflow_configurations.get("model_overrides"),
)
async def get_organization_ai_model_configuration_v2(
organization_id: int | None,
) -> OrganizationAIModelConfigurationV2 | None:
if organization_id is None:
return None
row = await db_client.get_configuration(
organization_id,
OrganizationConfigurationKey.MODEL_CONFIGURATION_V2.value,
)
if row is None or not row.value:
return None
try:
return OrganizationAIModelConfigurationV2.model_validate(row.value)
except ValidationError as exc:
logger.warning(
"Invalid org AI model configuration v2 for organization "
f"{organization_id}: {exc}. Falling back to legacy configuration."
)
return None
async def upsert_organization_ai_model_configuration_v2(
organization_id: int,
configuration: OrganizationAIModelConfigurationV2,
) -> OrganizationAIModelConfigurationV2:
await db_client.upsert_configuration(
organization_id,
OrganizationConfigurationKey.MODEL_CONFIGURATION_V2.value,
configuration.model_dump(mode="json", exclude_none=True),
)
return configuration
async def migrate_workflow_model_configurations_to_v2(
*,
organization_id: int,
fallback_user_config: EffectiveAIModelConfiguration,
) -> WorkflowAIModelConfigurationMigrationResult:
workflows = await _list_workflows_for_model_configuration_migration(organization_id)
owner_configs: dict[int, EffectiveAIModelConfiguration] = {}
workflow_updates: list[tuple[int, dict]] = []
definition_updates: list[tuple[int, dict]] = []
migrated_workflow_ids: set[int] = set()
for workflow in workflows:
base_config = fallback_user_config
if workflow.user_id is not None:
if workflow.user_id not in owner_configs:
owner_configs[
workflow.user_id
] = await db_client.get_user_configurations(workflow.user_id)
base_config = owner_configs[workflow.user_id]
workflow_configs, workflow_changed = (
migrate_workflow_configuration_model_override_to_v2(
workflow.workflow_configurations,
base_config,
)
)
if workflow_changed:
workflow_updates.append((workflow.id, workflow_configs))
migrated_workflow_ids.add(workflow.id)
for definition in workflow.definitions:
definition_configs, definition_changed = (
migrate_workflow_configuration_model_override_to_v2(
definition.workflow_configurations,
base_config,
)
)
if definition_changed:
definition_updates.append((definition.id, definition_configs))
migrated_workflow_ids.add(workflow.id)
if workflow_updates or definition_updates:
async with db_client.async_session() as session:
for workflow_id, workflow_configs in workflow_updates:
await session.execute(
update(WorkflowModel)
.where(WorkflowModel.id == workflow_id)
.values(workflow_configurations=workflow_configs)
)
for definition_id, definition_configs in definition_updates:
await session.execute(
update(WorkflowDefinitionModel)
.where(WorkflowDefinitionModel.id == definition_id)
.values(workflow_configurations=definition_configs)
)
await session.commit()
return WorkflowAIModelConfigurationMigrationResult(
workflow_count=len(migrated_workflow_ids),
definition_count=len(definition_updates),
workflow_ids=sorted(migrated_workflow_ids),
)
def migrate_workflow_configuration_model_override_to_v2(
workflow_configurations: dict | None,
base_config: EffectiveAIModelConfiguration,
) -> tuple[dict, bool]:
if not isinstance(workflow_configurations, dict):
return {}, False
migrated = copy.deepcopy(workflow_configurations)
model_overrides = migrated.get("model_overrides")
existing_v2_override = migrated.get(WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY)
if not isinstance(model_overrides, dict):
if "model_overrides" in migrated:
migrated.pop("model_overrides", None)
return migrated, True
return migrated, False
if not existing_v2_override:
effective = resolve_effective_config(base_config, model_overrides)
v2_override = convert_legacy_ai_model_configuration_to_v2(effective)
migrated[WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY] = v2_override.model_dump(
mode="json", exclude_none=True
)
migrated.pop("model_overrides", None)
return migrated, True
def merge_ai_model_configuration_v2_secrets(
incoming: OrganizationAIModelConfigurationV2,
existing: OrganizationAIModelConfigurationV2 | None,
) -> OrganizationAIModelConfigurationV2:
if existing is None:
return incoming
incoming_dict = incoming.model_dump(mode="json", exclude_none=True)
existing_dict = existing.model_dump(mode="json", exclude_none=True)
if incoming_dict.get("mode") == "dograh" and existing_dict.get("mode") == "dograh":
incoming_dograh = incoming_dict.get("dograh") or {}
existing_dograh = existing_dict.get("dograh") or {}
incoming_key = incoming_dograh.get("api_key")
existing_key = existing_dograh.get("api_key")
if incoming_key and existing_key and contains_masked_key(incoming_key):
incoming_dograh["api_key"] = resolve_masked_api_keys(
incoming_key,
existing_key,
)
if incoming_dict.get("mode") == "byok" and existing_dict.get("mode") == "byok":
_merge_byok_secret_fields(incoming_dict.get("byok"), existing_dict.get("byok"))
return OrganizationAIModelConfigurationV2.model_validate(incoming_dict)
def check_for_masked_keys_in_ai_model_configuration_v2(
configuration: OrganizationAIModelConfigurationV2,
) -> None:
data = configuration.model_dump(mode="json", exclude_none=True)
_raise_if_masked_secret(data)
def mask_ai_model_configuration_v2(
configuration: OrganizationAIModelConfigurationV2 | None,
) -> dict | None:
if configuration is None:
return None
data = configuration.model_dump(mode="json", exclude_none=True)
_mask_secret_fields(data)
return data
def convert_legacy_ai_model_configuration_to_v2(
configuration: EffectiveAIModelConfiguration,
) -> OrganizationAIModelConfigurationV2:
dograh_key = _first_dograh_api_key(configuration)
if dograh_key:
return _convert_any_dograh_legacy_configuration(configuration, dograh_key)
if configuration.is_realtime:
if configuration.realtime is None or configuration.llm is None:
raise ValueError("Realtime legacy configuration is incomplete")
return OrganizationAIModelConfigurationV2(
mode="byok",
byok=BYOKAIModelConfiguration(
mode="realtime",
realtime=BYOKRealtimeAIModelConfiguration(
realtime=configuration.realtime,
llm=configuration.llm,
embeddings=configuration.embeddings,
),
),
)
if (
configuration.llm is None
or configuration.tts is None
or configuration.stt is None
):
raise ValueError("Pipeline legacy configuration is incomplete")
return OrganizationAIModelConfigurationV2(
mode="byok",
byok=BYOKAIModelConfiguration(
mode="pipeline",
pipeline=BYOKPipelineAIModelConfiguration(
llm=configuration.llm,
tts=configuration.tts,
stt=configuration.stt,
embeddings=configuration.embeddings,
),
),
)
def dograh_embeddings_base_url() -> str:
return f"{MPS_API_URL}/api/v1/llm"
def apply_managed_embeddings_base_url(
*,
provider: str | None,
base_url: str | None,
) -> str | None:
if provider == ServiceProviders.DOGRAH.value or provider == ServiceProviders.DOGRAH:
return dograh_embeddings_base_url()
return base_url
def _merge_byok_secret_fields(incoming_byok: dict | None, existing_byok: dict | None):
if not isinstance(incoming_byok, dict) or not isinstance(existing_byok, dict):
return
incoming_mode = incoming_byok.get("mode")
existing_mode = existing_byok.get("mode")
if incoming_mode != existing_mode:
return
section_names = (
("llm", "tts", "stt", "embeddings")
if incoming_mode == "pipeline"
else ("realtime", "llm", "embeddings")
)
incoming_container = incoming_byok.get(incoming_mode)
existing_container = existing_byok.get(existing_mode)
if not isinstance(incoming_container, dict) or not isinstance(
existing_container, dict
):
return
for section_name in section_names:
incoming_section = incoming_container.get(section_name)
existing_section = existing_container.get(section_name)
if isinstance(incoming_section, dict) and isinstance(existing_section, dict):
_merge_service_secret_fields(incoming_section, existing_section)
async def _list_workflows_for_model_configuration_migration(
organization_id: int,
) -> list[WorkflowModel]:
async with db_client.async_session() as session:
result = await session.execute(
select(WorkflowModel)
.options(selectinload(WorkflowModel.definitions))
.where(WorkflowModel.organization_id == organization_id)
)
return list(result.scalars().unique().all())
def _merge_service_secret_fields(incoming: dict, existing: dict):
if (
incoming.get("provider") is not None
and existing.get("provider") is not None
and incoming.get("provider") != existing.get("provider")
):
return
for secret_field in SERVICE_SECRET_FIELDS:
if secret_field not in existing:
continue
incoming_secret = incoming.get(secret_field)
existing_secret = existing[secret_field]
if incoming_secret is None:
incoming[secret_field] = existing_secret
elif contains_masked_key(incoming_secret):
incoming[secret_field] = resolve_masked_api_keys(
incoming_secret,
existing_secret,
)
def _raise_if_masked_secret(value):
if isinstance(value, dict):
for key, nested in value.items():
if key in SERVICE_SECRET_FIELDS and contains_masked_key(nested):
raise ValueError(
f"The {key} appears to be masked. Please provide the actual "
"value, not the masked value."
)
_raise_if_masked_secret(nested)
elif isinstance(value, list):
for item in value:
_raise_if_masked_secret(item)
def _mask_secret_fields(value):
if isinstance(value, dict):
for key, nested in list(value.items()):
if key in SERVICE_SECRET_FIELDS and nested:
value[key] = _mask_secret_value(nested)
else:
_mask_secret_fields(nested)
elif isinstance(value, list):
for item in value:
_mask_secret_fields(item)
def _mask_secret_value(value):
if isinstance(value, list):
return [mask_key(item) for item in value]
return mask_key(value)
def _has_model_services(configuration: EffectiveAIModelConfiguration) -> bool:
return any(
service is not None
for service in (
configuration.llm,
configuration.tts,
configuration.stt,
configuration.embeddings,
configuration.realtime,
)
)
def _convert_any_dograh_legacy_configuration(
configuration: EffectiveAIModelConfiguration,
dograh_key: str,
) -> OrganizationAIModelConfigurationV2:
speed = getattr(configuration.tts, "speed", 1.0)
if speed not in DOGRAH_SPEED_OPTIONS:
speed = 1.0
return OrganizationAIModelConfigurationV2(
mode="dograh",
dograh=DograhManagedAIModelConfiguration(
api_key=dograh_key,
voice=getattr(configuration.tts, "voice", DOGRAH_DEFAULT_VOICE)
or DOGRAH_DEFAULT_VOICE,
speed=speed,
language=getattr(configuration.stt, "language", DOGRAH_DEFAULT_LANGUAGE)
or DOGRAH_DEFAULT_LANGUAGE,
),
)
def _first_dograh_api_key(configuration: EffectiveAIModelConfiguration) -> str | None:
for service in (
configuration.llm,
configuration.tts,
configuration.stt,
configuration.embeddings,
configuration.realtime,
):
if service is None or _provider(service) != ServiceProviders.DOGRAH:
continue
try:
return _single_api_key(service)
except ValueError:
continue
return None
def _provider(service):
return getattr(service, "provider", None)
def _single_api_key(service) -> str:
if hasattr(service, "get_all_api_keys"):
keys = service.get_all_api_keys()
if len(keys) != 1:
raise ValueError("Expected exactly one API key")
return keys[0]
key = getattr(service, "api_key", None)
if not key:
raise ValueError("Expected an API key")
return key

View file

@ -151,21 +151,35 @@ def mask_workflow_configurations(config: Optional[Dict]) -> Optional[Dict]:
masked = copy.deepcopy(config)
model_overrides = masked.get("model_overrides")
if not isinstance(model_overrides, dict):
return masked
if isinstance(model_overrides, dict):
for section in MODEL_OVERRIDE_FIELDS:
override = model_overrides.get(section)
if not isinstance(override, dict):
continue
for secret_field in SERVICE_SECRET_FIELDS:
raw = override.get(secret_field)
if raw:
override[secret_field] = _mask_secret_value(raw)
for section in MODEL_OVERRIDE_FIELDS:
override = model_overrides.get(section)
if not isinstance(override, dict):
continue
for secret_field in SERVICE_SECRET_FIELDS:
raw = override.get(secret_field)
if raw:
override[secret_field] = _mask_secret_value(raw)
v2_override = masked.get("model_configuration_v2_override")
if isinstance(v2_override, dict):
_mask_nested_service_secrets(v2_override)
return masked
def _mask_nested_service_secrets(value):
if isinstance(value, dict):
for key, nested in list(value.items()):
if key in SERVICE_SECRET_FIELDS and nested:
value[key] = _mask_secret_value(nested)
else:
_mask_nested_service_secrets(nested)
elif isinstance(value, list):
for item in value:
_mask_nested_service_secrets(item)
# ---------------------------------------------------------------------------
# Workflow definition helpers mask / merge node API keys
# ---------------------------------------------------------------------------

View file

@ -1472,11 +1472,26 @@ class AzureOpenAIEmbeddingsConfiguration(BaseEmbeddingsConfiguration):
)
DOGRAH_EMBEDDING_MODELS = ["default"]
@register_embeddings
class DograhEmbeddingsConfiguration(BaseEmbeddingsConfiguration):
model_config = DOGRAH_PROVIDER_MODEL_CONFIG
provider: Literal[ServiceProviders.DOGRAH] = ServiceProviders.DOGRAH
model: str = Field(
default="default",
description="Dograh-managed embedding model.",
json_schema_extra={"examples": DOGRAH_EMBEDDING_MODELS},
)
EmbeddingsConfig = Annotated[
Union[
OpenAIEmbeddingsConfiguration,
OpenRouterEmbeddingsConfiguration,
AzureOpenAIEmbeddingsConfiguration,
DograhEmbeddingsConfiguration,
],
Field(discriminator="provider"),
]

View file

@ -0,0 +1,62 @@
from inspect import isawaitable
from loguru import logger
from pydantic import ValidationError
from api.db import db_client
from api.enums import OrganizationConfigurationKey
from api.schemas.organization_preferences import OrganizationPreferences
async def get_organization_preferences(
organization_id: int | None,
db=None,
) -> OrganizationPreferences:
if organization_id is None:
return OrganizationPreferences()
db = db or db_client
row = await _get_configuration(
db,
organization_id,
OrganizationConfigurationKey.ORGANIZATION_PREFERENCES.value,
)
if row is None:
row = await _get_configuration(
db,
organization_id,
OrganizationConfigurationKey.MODEL_CONFIGURATION_PREFERENCES.value,
)
return _parse_preferences(row.value if row is not None else None, organization_id)
async def upsert_organization_preferences(
organization_id: int,
preferences: OrganizationPreferences,
) -> OrganizationPreferences:
await db_client.upsert_configuration(
organization_id,
OrganizationConfigurationKey.ORGANIZATION_PREFERENCES.value,
preferences.model_dump(mode="json", exclude_none=True),
)
return preferences
async def _get_configuration(db, organization_id: int, key: str):
row = db.get_configuration(organization_id, key)
if isawaitable(row):
row = await row
return row
def _parse_preferences(value, organization_id: int) -> OrganizationPreferences:
if not value or not isinstance(value, dict):
return OrganizationPreferences()
try:
return OrganizationPreferences.model_validate(value)
except ValidationError as exc:
logger.warning(
"Invalid organization preferences for organization "
f"{organization_id}: {exc}. Returning defaults."
)
return OrganizationPreferences()

View file

@ -195,14 +195,17 @@ async def run_pipeline_telephony(
# Resolve effective user config here so the transport can tune its
# bot-stopped-speaking fallback based on is_realtime; pass the resolved
# values into _run_pipeline so it doesn't fetch them again.
from api.services.configuration.resolve import resolve_effective_config
from api.services.configuration.ai_model_configuration import (
get_effective_ai_model_configuration_for_workflow,
)
user_config = await db_client.get_user_configurations(user_id)
run_configs = (
(workflow_run.definition.workflow_configurations or {}) if workflow_run else {}
)
user_config = resolve_effective_config(
user_config, run_configs.get("model_overrides")
user_config = await get_effective_ai_model_configuration_for_workflow(
user_id=user_id,
organization_id=workflow.organization_id if workflow else None,
workflow_configurations=run_configs,
)
is_realtime = bool(user_config.is_realtime and user_config.realtime is not None)
@ -272,15 +275,18 @@ async def run_pipeline_smallwebrtc(
# Resolve workflow_run + effective user_config here so the transport can
# tune its bot-stopped-speaking fallback based on is_realtime. _run_pipeline
# reuses these via kwargs so we don't fetch twice.
from api.services.configuration.resolve import resolve_effective_config
from api.services.configuration.ai_model_configuration import (
get_effective_ai_model_configuration_for_workflow,
)
workflow_run = await db_client.get_workflow_run(workflow_run_id, user_id)
user_config = await db_client.get_user_configurations(user_id)
run_configs = (
(workflow_run.definition.workflow_configurations or {}) if workflow_run else {}
)
user_config = resolve_effective_config(
user_config, run_configs.get("model_overrides")
user_config = await get_effective_ai_model_configuration_for_workflow(
user_id=user_id,
organization_id=workflow.organization_id if workflow else None,
workflow_configurations=run_configs,
)
is_realtime = bool(user_config.is_realtime and user_config.realtime is not None)
@ -380,11 +386,14 @@ async def _run_pipeline(
# Resolve model overrides from the version onto global user config (skip
# when the caller already resolved it).
if resolved_user_config is None:
from api.services.configuration.resolve import resolve_effective_config
from api.services.configuration.ai_model_configuration import (
get_effective_ai_model_configuration_for_workflow,
)
user_config = await db_client.get_user_configurations(user_id)
user_config = resolve_effective_config(
user_config, run_configs.get("model_overrides")
user_config = await get_effective_ai_model_configuration_for_workflow(
user_id=user_id,
organization_id=workflow.organization_id,
workflow_configurations=run_configs,
)
else:
user_config = resolved_user_config
@ -508,10 +517,17 @@ async def _run_pipeline(
embeddings_endpoint = None
embeddings_api_version = None
if user_config and user_config.embeddings:
from api.services.configuration.ai_model_configuration import (
apply_managed_embeddings_base_url,
)
embeddings_api_key = user_config.embeddings.api_key
embeddings_model = user_config.embeddings.model
embeddings_provider = getattr(user_config.embeddings, "provider", None)
embeddings_base_url = getattr(user_config.embeddings, "base_url", None)
embeddings_base_url = apply_managed_embeddings_base_url(
provider=embeddings_provider,
base_url=getattr(user_config.embeddings, "base_url", None),
)
embeddings_endpoint = getattr(user_config.embeddings, "endpoint", None)
embeddings_api_version = getattr(user_config.embeddings, "api_version", None)

View file

@ -10,8 +10,10 @@ from loguru import logger
from api.db import db_client
from api.db.models import UserModel
from api.services.configuration.ai_model_configuration import (
get_effective_ai_model_configuration_for_workflow,
)
from api.services.configuration.registry import ServiceProviders
from api.services.configuration.resolve import resolve_effective_config
from api.services.mps_service_key_client import mps_service_key_client
@ -48,17 +50,20 @@ async def check_dograh_quota(
if quota is insufficient.
"""
try:
# Get user configurations
user_config = await db_client.get_user_configurations(user.id)
organization_id = user.selected_organization_id
workflow_configurations = None
if workflow_id is not None:
workflow = await db_client.get_workflow_by_id(workflow_id)
if workflow:
model_overrides = (workflow.workflow_configurations or {}).get(
"model_overrides"
)
if model_overrides:
user_config = resolve_effective_config(user_config, model_overrides)
organization_id = workflow.organization_id
workflow_configurations = workflow.workflow_configurations
user_config = await get_effective_ai_model_configuration_for_workflow(
user_id=user.id,
organization_id=organization_id,
workflow_configurations=workflow_configurations,
)
# Check if user is using any Dograh service
using_dograh = False
@ -76,6 +81,13 @@ async def check_dograh_quota(
using_dograh = True
dograh_api_keys.add(user_config.tts.api_key)
if (
user_config.embeddings
and user_config.embeddings.provider == ServiceProviders.DOGRAH
):
using_dograh = True
dograh_api_keys.add(user_config.embeddings.api_key)
# If not using Dograh, quota check passes
if not using_dograh:
return QuotaCheckResult(has_quota=True)
@ -84,7 +96,9 @@ async def check_dograh_quota(
for api_key in dograh_api_keys:
try:
usage = await mps_service_key_client.check_service_key_usage(
api_key, created_by=user.provider_id
api_key,
organization_id=organization_id,
created_by=user.provider_id,
)
remaining = usage.get("remaining_credits", 0.0)

View file

@ -2,7 +2,6 @@
import random
from api.db import db_client
from api.db.models import WorkflowRunModel
from api.services.workflow.dto import QANodeData
@ -54,7 +53,27 @@ async def resolve_user_llm_config(
llm_config: dict = {}
if user_id:
user_configuration = await db_client.get_user_configurations(user_id)
from api.services.configuration.ai_model_configuration import (
get_effective_ai_model_configuration_for_workflow,
)
workflow_configurations = {}
if workflow_run.definition:
workflow_configurations = (
workflow_run.definition.workflow_configurations or {}
)
elif workflow_run.workflow:
workflow_configurations = (
workflow_run.workflow.workflow_configurations or {}
)
user_configuration = await get_effective_ai_model_configuration_for_workflow(
user_id=user_id,
organization_id=workflow_run.workflow.organization_id
if workflow_run.workflow
else None,
workflow_configurations=workflow_configurations,
)
llm_config = user_configuration.model_dump(exclude_none=True).get("llm", {})
provider = llm_config.get("provider", "openai")

View file

@ -32,7 +32,6 @@ from pipecat.utils.run_context import set_current_org_id
from api.db import db_client
from api.enums import WorkflowRunMode, WorkflowRunState
from api.services.configuration.resolve import resolve_effective_config
from api.services.pipecat.audio_config import create_audio_config
from api.services.pipecat.pipeline_builder import create_pipeline_task
from api.services.pipecat.pipeline_metrics_aggregator import (
@ -410,9 +409,14 @@ async def execute_text_chat_pending_turn(
run_definition = workflow_run.definition
run_configs = run_definition.workflow_configurations or {}
user_config = await db_client.get_user_configurations(workflow_run.workflow.user.id)
user_config = resolve_effective_config(
user_config, run_configs.get("model_overrides")
from api.services.configuration.ai_model_configuration import (
get_effective_ai_model_configuration_for_workflow,
)
user_config = await get_effective_ai_model_configuration_for_workflow(
user_id=workflow_run.workflow.user.id,
organization_id=workflow.organization_id,
workflow_configurations=run_configs,
)
if user_config.llm is None:
raise ValueError("Text chat requires an LLM configuration")
@ -466,9 +470,17 @@ async def execute_text_chat_pending_turn(
embeddings_model = None
embeddings_base_url = None
if user_config.embeddings:
from api.services.configuration.ai_model_configuration import (
apply_managed_embeddings_base_url,
)
embeddings_api_key = user_config.embeddings.api_key
embeddings_model = user_config.embeddings.model
embeddings_base_url = getattr(user_config.embeddings, "base_url", None)
embeddings_provider = getattr(user_config.embeddings, "provider", None)
embeddings_base_url = apply_managed_embeddings_base_url(
provider=embeddings_provider,
base_url=getattr(user_config.embeddings, "base_url", None),
)
has_recordings = await db_client.has_active_recordings(workflow.organization_id)
context_compaction_enabled = (workflow.workflow_configurations or {}).get(

View file

@ -157,12 +157,24 @@ async def process_knowledge_base_document(
embeddings_endpoint = None
embeddings_api_version = None
if document.created_by:
user_config = await db_client.get_user_configurations(document.created_by)
from api.services.configuration.ai_model_configuration import (
apply_managed_embeddings_base_url,
get_resolved_ai_model_configuration,
)
resolved_config = await get_resolved_ai_model_configuration(
user_id=document.created_by,
organization_id=document.organization_id,
)
user_config = resolved_config.effective
if user_config.embeddings:
embeddings_provider = getattr(user_config.embeddings, "provider", None)
embeddings_api_key = user_config.embeddings.api_key
embeddings_model = user_config.embeddings.model
embeddings_base_url = getattr(user_config.embeddings, "base_url", None)
embeddings_base_url = apply_managed_embeddings_base_url(
provider=embeddings_provider,
base_url=getattr(user_config.embeddings, "base_url", None),
)
embeddings_endpoint = getattr(user_config.embeddings, "endpoint", None)
embeddings_api_version = getattr(
user_config.embeddings, "api_version", None

View file

@ -0,0 +1,295 @@
import pytest
from pydantic import ValidationError
from api.schemas.ai_model_configuration import (
DograhManagedAIModelConfiguration,
OrganizationAIModelConfigurationV2,
compile_ai_model_configuration_v2,
)
from api.schemas.user_configuration import UserConfiguration
from api.services.configuration.ai_model_configuration import (
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY,
check_for_masked_keys_in_ai_model_configuration_v2,
convert_legacy_ai_model_configuration_to_v2,
mask_ai_model_configuration_v2,
merge_ai_model_configuration_v2_secrets,
migrate_workflow_configuration_model_override_to_v2,
)
from api.services.configuration.masking import mask_key
from api.services.configuration.registry import (
DeepgramSTTConfiguration,
DograhLLMService,
DograhSTTService,
DograhTTSService,
ElevenlabsTTSConfiguration,
OpenAIEmbeddingsConfiguration,
OpenAILLMService,
)
def test_dograh_v2_compiles_to_effective_managed_pipeline_with_embeddings():
config = OrganizationAIModelConfigurationV2(
mode="dograh",
dograh=DograhManagedAIModelConfiguration(
api_key="mps-secret",
voice="default",
speed=1.2,
language="multi",
),
)
effective = compile_ai_model_configuration_v2(config)
assert effective.is_realtime is False
assert effective.llm.provider == "dograh"
assert effective.llm.model == "default"
assert effective.tts.provider == "dograh"
assert effective.tts.speed == 1.2
assert effective.stt.provider == "dograh"
assert effective.stt.language == "multi"
assert effective.embeddings.provider == "dograh"
assert effective.embeddings.model == "default"
def test_dograh_v2_rejects_non_predefined_speed():
with pytest.raises(ValidationError):
OrganizationAIModelConfigurationV2(
mode="dograh",
dograh=DograhManagedAIModelConfiguration(
api_key="mps-secret",
speed=1.5,
),
)
def test_byok_v2_rejects_dograh_provider():
with pytest.raises(ValidationError):
OrganizationAIModelConfigurationV2.model_validate(
{
"mode": "byok",
"byok": {
"mode": "pipeline",
"pipeline": {
"llm": {
"provider": "dograh",
"api_key": "mps-secret",
"model": "default",
},
"tts": {
"provider": "dograh",
"api_key": "mps-secret",
"model": "default",
"voice": "default",
},
"stt": {
"provider": "dograh",
"api_key": "mps-secret",
"model": "default",
},
},
},
}
)
def test_masked_dograh_key_is_preserved_when_saving_same_mode():
existing = OrganizationAIModelConfigurationV2(
mode="dograh",
dograh=DograhManagedAIModelConfiguration(api_key="mps-real-secret"),
)
incoming = OrganizationAIModelConfigurationV2(
mode="dograh",
dograh=DograhManagedAIModelConfiguration(api_key=mask_key("mps-real-secret")),
)
merged = merge_ai_model_configuration_v2_secrets(incoming, existing)
assert merged.dograh.api_key == "mps-real-secret"
check_for_masked_keys_in_ai_model_configuration_v2(merged)
def test_masked_v2_configuration_masks_nested_service_keys():
config = OrganizationAIModelConfigurationV2(
mode="byok",
byok={
"mode": "pipeline",
"pipeline": {
"llm": {
"provider": "openai",
"api_key": "sk-real-secret",
"model": "gpt-4.1",
},
"tts": {
"provider": "elevenlabs",
"api_key": "el-real-secret",
"model": "eleven_flash_v2_5",
"voice": "Rachel",
},
"stt": {
"provider": "deepgram",
"api_key": "dg-real-secret",
"model": "nova-3-general",
},
},
},
)
masked = mask_ai_model_configuration_v2(config)
assert masked["byok"]["pipeline"]["llm"]["api_key"] == mask_key("sk-real-secret")
assert masked["byok"]["pipeline"]["tts"]["api_key"] == mask_key("el-real-secret")
assert masked["byok"]["pipeline"]["stt"]["api_key"] == mask_key("dg-real-secret")
def test_legacy_all_dograh_pipeline_converts_to_dograh_v2():
legacy = UserConfiguration(
llm=DograhLLMService(
provider="dograh",
api_key=["mps-secret"],
model="default",
),
tts=DograhTTSService(
provider="dograh",
api_key=["mps-secret"],
model="default",
voice="default",
speed=1.0,
),
stt=DograhSTTService(
provider="dograh",
api_key=["mps-secret"],
model="default",
language="multi",
),
)
config = convert_legacy_ai_model_configuration_to_v2(legacy)
assert config.mode == "dograh"
assert config.dograh.api_key == "mps-secret"
def test_legacy_mixed_dograh_pipeline_converts_to_dograh_v2():
legacy = UserConfiguration(
llm=OpenAILLMService(
provider="openai",
api_key="sk-llm",
model="gpt-4.1",
),
tts=DograhTTSService(
provider="dograh",
api_key="mps-tts",
model="default",
voice="default",
),
stt=DograhSTTService(
provider="dograh",
api_key="mps-stt",
model="default",
),
embeddings=OpenAIEmbeddingsConfiguration(
provider="openai",
api_key="sk-emb",
model="text-embedding-3-small",
),
)
config = convert_legacy_ai_model_configuration_to_v2(legacy)
assert config.mode == "dograh"
assert config.dograh.api_key == "mps-tts"
assert config.dograh.voice == "default"
def test_legacy_byok_pipeline_converts_to_byok_v2():
legacy = UserConfiguration(
llm=OpenAILLMService(
provider="openai",
api_key="sk-llm",
model="gpt-4.1",
),
tts=ElevenlabsTTSConfiguration(
provider="elevenlabs",
api_key="el-tts",
model="eleven_flash_v2_5",
voice="Rachel",
),
stt=DeepgramSTTConfiguration(
provider="deepgram",
api_key="dg-stt",
model="nova-3-general",
),
embeddings=OpenAIEmbeddingsConfiguration(
provider="openai",
api_key="sk-emb",
model="text-embedding-3-small",
),
)
config = convert_legacy_ai_model_configuration_to_v2(legacy)
assert config.mode == "byok"
assert config.byok.mode == "pipeline"
assert config.byok.pipeline.llm.provider == "openai"
assert config.byok.pipeline.tts.provider == "elevenlabs"
def test_workflow_model_override_migration_removes_v1_override_and_sets_v2():
base = UserConfiguration(
llm=OpenAILLMService(
provider="openai",
api_key="sk-llm",
model="gpt-4.1",
),
tts=ElevenlabsTTSConfiguration(
provider="elevenlabs",
api_key="el-tts",
model="eleven_flash_v2_5",
voice="Rachel",
),
stt=DeepgramSTTConfiguration(
provider="deepgram",
api_key="dg-stt",
model="nova-3-general",
),
)
workflow_configurations = {
"ambient_noise_configuration": {"enabled": False},
"model_overrides": {
"tts": {
"provider": "dograh",
"api_key": "mps-workflow",
"model": "default",
"voice": "default",
}
},
}
migrated, changed = migrate_workflow_configuration_model_override_to_v2(
workflow_configurations,
base,
)
assert changed is True
assert "model_overrides" not in migrated
assert migrated["ambient_noise_configuration"] == {"enabled": False}
v2_override = migrated[WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY]
assert v2_override["mode"] == "dograh"
assert v2_override["dograh"]["api_key"] == "mps-workflow"
def test_workflow_model_override_migration_removes_invalid_v1_override_marker():
base = UserConfiguration()
workflow_configurations = {
"ambient_noise_configuration": {"enabled": False},
"model_overrides": None,
}
migrated, changed = migrate_workflow_configuration_model_override_to_v2(
workflow_configurations,
base,
)
assert changed is True
assert "model_overrides" not in migrated
assert migrated["ambient_noise_configuration"] == {"enabled": False}

View file

@ -1,3 +1,4 @@
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
from fastapi import FastAPI
@ -14,14 +15,14 @@ from api.services.configuration.registry import (
)
def _make_test_app():
def _make_test_app(selected_organization_id=None):
app = FastAPI()
app.include_router(router)
mock_user = MagicMock()
mock_user.id = 1
mock_user.is_superuser = False
mock_user.selected_organization_id = None
mock_user.selected_organization_id = selected_organization_id
app.dependency_overrides[get_user] = lambda: mock_user
return app
@ -210,3 +211,38 @@ class TestMaskedKeyRejection:
)
assert response.status_code == 200
def test_preference_only_update_does_not_validate_or_save_model_config(self):
"""Saving a test phone number through the legacy endpoint must not touch models."""
app = _make_test_app(selected_organization_id=11)
client = TestClient(app)
preferences = SimpleNamespace(test_phone_number=None, timezone=None)
with (
patch("api.routes.user.db_client") as mock_db,
patch("api.routes.user.UserConfigurationValidator") as mock_validator,
patch(
"api.routes.user.get_organization_preferences",
new=AsyncMock(return_value=preferences),
),
patch(
"api.routes.user.upsert_organization_preferences",
new=AsyncMock(return_value=preferences),
) as upsert_preferences,
):
existing = _existing_openai_config()
mock_db.get_user_configurations = AsyncMock(return_value=existing)
mock_db.update_user_configuration = AsyncMock()
mock_db.get_organization_by_id = AsyncMock(return_value=None)
mock_validator.return_value.validate = AsyncMock()
response = client.put(
"/user/configurations/user",
json={"test_phone_number": "+15551234567"},
)
assert response.status_code == 200
assert response.json()["test_phone_number"] == "+15551234567"
mock_db.update_user_configuration.assert_not_called()
mock_validator.return_value.validate.assert_not_called()
upsert_preferences.assert_awaited_once()

View file

@ -103,6 +103,61 @@ def test_initiate_call_executes_as_workflow_owner_for_shared_org_workflow():
assert initiate_kwargs["workflow_id"] == workflow.id
assert initiate_kwargs["user_id"] == workflow.user_id
assert "user_id=99" in initiate_kwargs["webhook_url"]
mock_db.get_user_configurations.assert_not_called()
def test_initiate_call_uses_organization_preference_phone_number():
app = _make_test_app()
client = TestClient(app)
workflow = _workflow()
provider = _provider()
quota_mock = AsyncMock(
return_value=SimpleNamespace(has_quota=True, error_message="")
)
with (
patch("api.routes.telephony.db_client") as mock_db,
patch(
"api.routes.telephony.check_dograh_quota_by_user_id",
new=quota_mock,
),
patch(
"api.routes.telephony.get_default_telephony_provider",
new=AsyncMock(return_value=provider),
),
patch(
"api.routes.telephony.get_backend_endpoints",
new=AsyncMock(return_value=("https://api.example.com", "wss://ignored")),
),
):
mock_db.get_user_configurations = AsyncMock(
return_value=SimpleNamespace(test_phone_number="+15550000000")
)
mock_db.get_configuration = Mock(
return_value=SimpleNamespace(value={"test_phone_number": "+15557654321"})
)
mock_db.get_default_telephony_configuration = AsyncMock(
return_value=SimpleNamespace(id=55)
)
mock_db.get_workflow = AsyncMock(return_value=workflow)
mock_db.create_workflow_run = AsyncMock(
return_value=SimpleNamespace(
id=501,
name="WR-TEL-OUT-00000001",
initial_context={},
)
)
mock_db.update_workflow_run = AsyncMock()
response = client.post(
"/telephony/initiate-call",
json={"workflow_id": workflow.id},
)
assert response.status_code == 200
assert provider.initiate_call.await_args.kwargs["to_number"] == "+15557654321"
mock_db.get_user_configurations.assert_not_called()
def test_initiate_call_rejects_existing_run_for_different_workflow():

View file

@ -51,6 +51,38 @@ async def _create_user_and_workflow(
return user, workflow
@pytest.mark.asyncio
async def test_text_chat_session_creation_requires_selected_organization():
from httpx import ASGITransport, AsyncClient
from api.app import app
from api.services.auth.depends import get_user
user = UserModel(provider_id="textchat-user-no-selected-org")
async def mock_get_user():
return user
original_override = app.dependency_overrides.get(get_user)
app.dependency_overrides[get_user] = mock_get_user
try:
async with AsyncClient(
transport=ASGITransport(app=app), base_url="http://test"
) as client:
response = await client.post(
"/api/v1/workflow/123/text-chat/sessions", json={}
)
finally:
if original_override:
app.dependency_overrides[get_user] = original_override
else:
app.dependency_overrides.pop(get_user, None)
assert response.status_code == 400
assert response.json() == {"detail": "No organization selected"}
@pytest.mark.asyncio
async def test_text_chat_session_creation_executes_initial_assistant_turn(
db_session,