feat: add model config v2

This commit is contained in:
Abhishek Kumar 2026-06-09 14:57:21 +05:30
parent 49e68b49d5
commit 94686b73c4
29 changed files with 4680 additions and 171 deletions

View file

@ -10,6 +10,7 @@ from sqlalchemy.orm import joinedload
from api.db.base_client import BaseDBClient
from api.db.filters import apply_workflow_run_filters
from api.db.models import (
OrganizationConfigurationModel,
OrganizationModel,
OrganizationUsageCycleModel,
UserConfigurationModel,
@ -17,6 +18,7 @@ from api.db.models import (
WorkflowModel,
WorkflowRunModel,
)
from api.enums import OrganizationConfigurationKey
from api.schemas.user_configuration import UserConfiguration
@ -440,8 +442,19 @@ class OrganizationUsageClient(BaseDBClient):
"""Get daily usage breakdown for an organization with pricing."""
async with self.async_session() as session:
# Get user timezone if user_id is provided
# Get org timezone preference first, then fall back to legacy user config.
user_timezone = "UTC" # Default timezone
pref_result = await session.execute(
select(OrganizationConfigurationModel).where(
OrganizationConfigurationModel.organization_id == organization_id,
OrganizationConfigurationModel.key
== OrganizationConfigurationKey.MODEL_CONFIGURATION_PREFERENCES.value,
)
)
pref_obj = pref_result.scalar_one_or_none()
if pref_obj and pref_obj.value:
user_timezone = pref_obj.value.get("timezone") or user_timezone
if user_id:
config_result = await session.execute(
select(UserConfigurationModel).where(
@ -453,7 +466,7 @@ class OrganizationUsageClient(BaseDBClient):
user_config = UserConfiguration.model_validate(
config_obj.configuration
)
if user_config.timezone:
if user_config.timezone and user_timezone == "UTC":
user_timezone = user_config.timezone
# Validate timezone string

View file

@ -89,6 +89,12 @@ class OrganizationConfigurationKey(Enum):
LANGFUSE_CREDENTIALS = (
"LANGFUSE_CREDENTIALS" # Org-level Langfuse tracing credentials
)
MODEL_CONFIGURATION_V2 = (
"MODEL_CONFIGURATION_V2" # Org-level v2 AI model configuration
)
MODEL_CONFIGURATION_PREFERENCES = (
"MODEL_CONFIGURATION_PREFERENCES" # Org-level model configuration preferences
)
class WorkflowStatus(Enum):

View file

@ -3,9 +3,12 @@ from loguru import logger
from api.db import db_client
from api.db.models import UserModel
from api.enums import PostHogEvent
from api.enums import OrganizationConfigurationKey, PostHogEvent
from api.schemas.auth import AuthResponse, LoginRequest, SignupRequest, UserResponse
from api.services.auth.depends import create_user_configuration_with_mps_key, get_user
from api.services.configuration.ai_model_configuration import (
convert_legacy_ai_model_configuration_to_v2,
)
from api.services.posthog_client import capture_event
from api.utils.auth import create_jwt_token, hash_password, verify_password
@ -47,6 +50,12 @@ async def signup(request: SignupRequest):
)
if mps_config:
await db_client.update_user_configuration(user.id, mps_config)
model_config_v2 = convert_legacy_ai_model_configuration_to_v2(mps_config)
await db_client.upsert_configuration(
organization.id,
OrganizationConfigurationKey.MODEL_CONFIGURATION_V2.value,
model_config_v2.model_dump(mode="json", exclude_none=True),
)
except Exception:
logger.warning(
"Failed to create default configuration for OSS user", exc_info=True

View file

@ -369,6 +369,10 @@ async def search_chunks(
try:
# Import here to avoid circular dependency
from api.services.configuration.ai_model_configuration import (
apply_managed_embeddings_base_url,
get_resolved_ai_model_configuration,
)
from api.services.configuration.registry import ServiceProviders
from api.services.gen_ai import (
AzureOpenAIEmbeddingService,
@ -376,10 +380,15 @@ async def search_chunks(
)
# Try to get user's embeddings configuration
user_config = await db_client.get_user_configurations(user.id)
resolved_config = await get_resolved_ai_model_configuration(
user_id=user.id,
organization_id=user.selected_organization_id,
)
user_config = resolved_config.effective
embeddings_api_key = None
embeddings_model = None
embeddings_provider = None
embeddings_base_url = None
embeddings_endpoint = None
embeddings_api_version = None
@ -388,6 +397,10 @@ async def search_chunks(
embeddings_model = user_config.embeddings.model
embeddings_provider = getattr(user_config.embeddings, "provider", None)
embeddings_endpoint = getattr(user_config.embeddings, "endpoint", None)
embeddings_base_url = apply_managed_embeddings_base_url(
provider=embeddings_provider,
base_url=getattr(user_config.embeddings, "base_url", None),
)
embeddings_api_version = getattr(
user_config.embeddings, "api_version", None
)
@ -406,9 +419,7 @@ async def search_chunks(
db_client=db_client,
api_key=embeddings_api_key,
model_id=embeddings_model or "text-embedding-3-small",
base_url=getattr(user_config.embeddings, "base_url", None)
if user_config.embeddings
else None,
base_url=embeddings_base_url,
)
# Perform search

View file

@ -1,6 +1,6 @@
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends, HTTPException, Query
from loguru import logger
from pydantic import BaseModel
from sqlalchemy.exc import IntegrityError
@ -10,6 +10,14 @@ from api.db import db_client
from api.db.models import UserModel
from api.db.telephony_configuration_client import TelephonyConfigurationInUseError
from api.enums import OrganizationConfigurationKey, PostHogEvent
from api.schemas.ai_model_configuration import (
DOGRAH_DEFAULT_LANGUAGE,
DOGRAH_DEFAULT_VOICE,
DOGRAH_SPEED_OPTIONS,
OrganizationAIModelConfigurationPreferences,
OrganizationAIModelConfigurationResponse,
OrganizationAIModelConfigurationV2,
)
from api.schemas.telephony_config import (
TelephonyConfigRequest,
TelephonyConfigurationCreateRequest,
@ -27,7 +35,28 @@ from api.schemas.telephony_phone_number import (
ProviderSyncStatus,
)
from api.services.auth.depends import get_user
from api.services.configuration.masking import is_mask_of, mask_key
from api.services.configuration.ai_model_configuration import (
check_for_masked_keys_in_ai_model_configuration_v2,
compile_ai_model_configuration_v2,
convert_legacy_ai_model_configuration_to_v2,
get_organization_ai_model_configuration_preferences,
get_organization_ai_model_configuration_v2,
get_resolved_ai_model_configuration,
mask_ai_model_configuration_v2,
merge_ai_model_configuration_v2_secrets,
migrate_workflow_model_configurations_to_v2,
upsert_organization_ai_model_configuration_preferences,
upsert_organization_ai_model_configuration_v2,
)
from api.services.configuration.check_validity import UserConfigurationValidator
from api.services.configuration.defaults import DEFAULT_SERVICE_PROVIDERS
from api.services.configuration.masking import is_mask_of, mask_key, mask_user_config
from api.services.configuration.registry import (
DOGRAH_STT_LANGUAGES,
REGISTRY,
ServiceProviders,
ServiceType,
)
from api.services.posthog_client import capture_event
from api.services.telephony import registry as telephony_registry
from api.services.telephony.factory import get_telephony_provider_by_id
@ -159,6 +188,208 @@ async def get_telephony_config_warnings(user: UserModel = Depends(get_user)):
)
# ---------------------------------------------------------------------------
# AI model configurations v2
# ---------------------------------------------------------------------------
def _require_selected_organization(user: UserModel) -> int:
if not user.selected_organization_id:
raise HTTPException(status_code=400, detail="No organization selected")
return user.selected_organization_id
def _byok_provider_schemas(service_type: ServiceType) -> dict[str, dict]:
return {
provider: model_cls.model_json_schema()
for provider, model_cls in REGISTRY[service_type].items()
if provider != ServiceProviders.DOGRAH.value
}
async def _model_configuration_v2_response(
*,
user: UserModel,
configuration: OrganizationAIModelConfigurationV2 | None = None,
) -> OrganizationAIModelConfigurationResponse:
resolved = await get_resolved_ai_model_configuration(
user_id=user.id,
organization_id=user.selected_organization_id,
)
raw_configuration = (
configuration
if configuration is not None
else resolved.organization_configuration
)
return OrganizationAIModelConfigurationResponse(
configuration=mask_ai_model_configuration_v2(raw_configuration),
effective_configuration=mask_user_config(resolved.effective),
preferences=resolved.preferences
or OrganizationAIModelConfigurationPreferences(),
source=resolved.source,
)
@router.get("/model-configurations/v2/defaults")
async def get_model_configuration_v2_defaults(user: UserModel = Depends(get_user)):
_require_selected_organization(user)
byok_default_providers = {
service: provider
for service, provider in DEFAULT_SERVICE_PROVIDERS.items()
if provider != ServiceProviders.DOGRAH.value
}
return {
"dograh": {
"voices": [DOGRAH_DEFAULT_VOICE],
"speeds": list(DOGRAH_SPEED_OPTIONS),
"languages": DOGRAH_STT_LANGUAGES,
"defaults": {
"voice": DOGRAH_DEFAULT_VOICE,
"speed": 1.0,
"language": DOGRAH_DEFAULT_LANGUAGE,
},
},
"byok": {
"pipeline": {
"llm": _byok_provider_schemas(ServiceType.LLM),
"tts": _byok_provider_schemas(ServiceType.TTS),
"stt": _byok_provider_schemas(ServiceType.STT),
"embeddings": _byok_provider_schemas(ServiceType.EMBEDDINGS),
"default_providers": byok_default_providers,
},
"realtime": {
"realtime": _byok_provider_schemas(ServiceType.REALTIME),
"llm": _byok_provider_schemas(ServiceType.LLM),
"embeddings": _byok_provider_schemas(ServiceType.EMBEDDINGS),
"default_providers": byok_default_providers,
},
},
}
@router.get(
"/model-configurations/v2",
response_model=OrganizationAIModelConfigurationResponse,
)
async def get_model_configuration_v2(user: UserModel = Depends(get_user)):
_require_selected_organization(user)
return await _model_configuration_v2_response(user=user)
@router.put(
"/model-configurations/v2",
response_model=OrganizationAIModelConfigurationResponse,
)
async def save_model_configuration_v2(
request: OrganizationAIModelConfigurationV2,
user: UserModel = Depends(get_user),
):
organization_id = _require_selected_organization(user)
existing = await get_organization_ai_model_configuration_v2(organization_id)
configuration = merge_ai_model_configuration_v2_secrets(request, existing)
try:
check_for_masked_keys_in_ai_model_configuration_v2(configuration)
effective = compile_ai_model_configuration_v2(configuration)
await UserConfigurationValidator().validate(
effective,
organization_id=organization_id,
created_by=user.provider_id,
)
except ValueError as exc:
raise HTTPException(status_code=422, detail=exc.args[0])
await upsert_organization_ai_model_configuration_v2(
organization_id,
configuration,
)
return await _model_configuration_v2_response(
user=user,
configuration=configuration,
)
@router.get("/model-configurations/v2/migration-preview")
async def preview_model_configuration_v2_migration(user: UserModel = Depends(get_user)):
_require_selected_organization(user)
legacy = await db_client.get_user_configurations(user.id)
try:
configuration = convert_legacy_ai_model_configuration_to_v2(legacy)
except ValueError as exc:
raise HTTPException(status_code=422, detail=str(exc))
return {
"configuration": mask_ai_model_configuration_v2(configuration),
"effective_configuration": mask_user_config(
compile_ai_model_configuration_v2(configuration)
),
}
@router.post(
"/model-configurations/v2/migrate",
response_model=OrganizationAIModelConfigurationResponse,
)
async def migrate_model_configuration_v2(
force: bool = Query(default=False),
user: UserModel = Depends(get_user),
):
organization_id = _require_selected_organization(user)
existing = await get_organization_ai_model_configuration_v2(organization_id)
if existing is not None and not force:
raise HTTPException(
status_code=409,
detail="Organization already has a v2 model configuration",
)
legacy = await db_client.get_user_configurations(user.id)
try:
configuration = convert_legacy_ai_model_configuration_to_v2(legacy)
effective = compile_ai_model_configuration_v2(configuration)
await UserConfigurationValidator().validate(
effective,
organization_id=organization_id,
created_by=user.provider_id,
)
except ValueError as exc:
raise HTTPException(status_code=422, detail=exc.args[0])
await upsert_organization_ai_model_configuration_v2(
organization_id,
configuration,
)
await migrate_workflow_model_configurations_to_v2(
organization_id=organization_id,
fallback_user_config=legacy,
)
return await _model_configuration_v2_response(
user=user,
configuration=configuration,
)
@router.get(
"/model-configurations/preferences",
response_model=OrganizationAIModelConfigurationPreferences,
)
async def get_model_configuration_preferences(user: UserModel = Depends(get_user)):
organization_id = _require_selected_organization(user)
return await get_organization_ai_model_configuration_preferences(organization_id)
@router.put(
"/model-configurations/preferences",
response_model=OrganizationAIModelConfigurationPreferences,
)
async def save_model_configuration_preferences(
request: OrganizationAIModelConfigurationPreferences,
user: UserModel = Depends(get_user),
):
organization_id = _require_selected_organization(user)
return await upsert_organization_ai_model_configuration_preferences(
organization_id,
request,
)
def preserve_masked_fields(provider: str, request_dict: dict, existing: dict):
"""If the client re-submitted a masked sensitive field, restore the original."""
for field_name in _sensitive_fields(provider):

View file

@ -82,7 +82,15 @@ async def initiate_call(
"""Initiate a call using the configured telephony provider from web browser. This is
supposed to be a test call method for the draft version of the agent."""
from api.services.configuration.ai_model_configuration import (
get_organization_ai_model_configuration_preferences,
)
user_configuration = await db_client.get_user_configurations(user.id)
preferences = await get_organization_ai_model_configuration_preferences(
user.selected_organization_id,
db=db_client,
)
# Resolve which telephony config to use: explicit request value, otherwise
# the org's default outbound config.
@ -116,7 +124,11 @@ async def initiate_call(
detail="telephony_not_configured",
)
phone_number = request.phone_number or user_configuration.test_phone_number
phone_number = (
request.phone_number
or preferences.test_phone_number
or user_configuration.test_phone_number
)
if not phone_number:
raise HTTPException(

View file

@ -10,6 +10,11 @@ from api.db.models import (
UserModel,
)
from api.services.auth.depends import get_user
from api.services.configuration.ai_model_configuration import (
get_organization_ai_model_configuration_preferences,
get_resolved_ai_model_configuration,
upsert_organization_ai_model_configuration_preferences,
)
from api.services.configuration.check_validity import (
APIKeyStatusResponse,
UserConfigurationValidator,
@ -91,8 +96,18 @@ class UserConfigurationRequestResponseSchema(BaseModel):
async def get_user_configurations(
user: UserModel = Depends(get_user),
) -> UserConfigurationRequestResponseSchema:
user_configurations = await db_client.get_user_configurations(user.id)
masked_config = mask_user_config(user_configurations)
resolved_config = await get_resolved_ai_model_configuration(
user_id=user.id,
organization_id=user.selected_organization_id,
)
masked_config = mask_user_config(resolved_config.effective)
if resolved_config.preferences:
if resolved_config.preferences.test_phone_number is not None:
masked_config["test_phone_number"] = (
resolved_config.preferences.test_phone_number
)
if resolved_config.preferences.timezone is not None:
masked_config["timezone"] = resolved_config.preferences.timezone
# Add organization pricing info if available
if user.selected_organization_id:
@ -144,8 +159,31 @@ async def update_user_configurations(
user.id, user_configurations
)
if user.selected_organization_id and (
request.test_phone_number is not None or request.timezone is not None
):
preferences = await get_organization_ai_model_configuration_preferences(
user.selected_organization_id
)
if request.test_phone_number is not None:
preferences.test_phone_number = request.test_phone_number
if request.timezone is not None:
preferences.timezone = request.timezone
await upsert_organization_ai_model_configuration_preferences(
user.selected_organization_id,
preferences,
)
# Return masked version of updated config
masked_config = mask_user_config(user_configurations)
if user.selected_organization_id:
preferences = await get_organization_ai_model_configuration_preferences(
user.selected_organization_id
)
if preferences.test_phone_number is not None:
masked_config["test_phone_number"] = preferences.test_phone_number
if preferences.timezone is not None:
masked_config["timezone"] = preferences.timezone
# Add organization pricing info if available
if user.selected_organization_id:
@ -165,7 +203,11 @@ async def validate_user_configurations(
validity_ttl_seconds: int = Query(default=60, ge=0, le=86400),
user: UserModel = Depends(get_user),
) -> APIKeyStatusResponse:
configurations = await db_client.get_user_configurations(user.id)
resolved_config = await get_resolved_ai_model_configuration(
user_id=user.id,
organization_id=user.selected_organization_id,
)
configurations = resolved_config.effective
if (
configurations.last_validated_at

View file

@ -16,9 +16,18 @@ from api.db.agent_trigger_client import TriggerPathConflictError
from api.db.models import UserModel
from api.db.workflow_template_client import WorkflowTemplateClient
from api.enums import CallType, PostHogEvent, StorageBackend
from api.schemas.ai_model_configuration import OrganizationAIModelConfigurationV2
from api.schemas.workflow import WorkflowRunResponseSchema
from api.sdk_expose import sdk_expose
from api.services.auth.depends import get_user
from api.services.configuration.ai_model_configuration import (
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY,
check_for_masked_keys_in_ai_model_configuration_v2,
compile_ai_model_configuration_v2,
convert_legacy_ai_model_configuration_to_v2,
get_resolved_ai_model_configuration,
merge_ai_model_configuration_v2_secrets,
)
from api.services.configuration.check_validity import UserConfigurationValidator
from api.services.configuration.masking import (
mask_workflow_configurations,
@ -955,12 +964,74 @@ async def update_workflow(
existing_def,
)
# Validate model_overrides: resolve onto global config, then
# run the same validator used by the user-configurations endpoint.
# Also stamp the current global API key into the override so the override
# remains functional if the global config later switches to a different provider.
# Validate model overrides. v2 uses a complete workflow-level model
# configuration; legacy v1 uses partial service overlays.
workflow_configurations = request.workflow_configurations
if workflow_configurations and workflow_configurations.get("model_overrides"):
if workflow_configurations and workflow_configurations.get(
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY
):
existing_workflow = await db_client.get_workflow(
workflow_id, organization_id=user.selected_organization_id
)
if existing_workflow is None:
raise HTTPException(
status_code=404, detail=f"Workflow with id {workflow_id} not found"
)
existing_draft = await db_client.get_draft_version(workflow_id)
existing_configs = (
existing_draft.workflow_configurations
if existing_draft
else existing_workflow.released_definition.workflow_configurations
)
existing_v2_override = (existing_configs or {}).get(
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY
)
try:
incoming_v2_override = (
OrganizationAIModelConfigurationV2.model_validate(
workflow_configurations[
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY
]
)
)
existing_v2_override_config = (
OrganizationAIModelConfigurationV2.model_validate(
existing_v2_override
)
if existing_v2_override
else None
)
v2_override = merge_ai_model_configuration_v2_secrets(
incoming_v2_override,
existing_v2_override_config,
)
if existing_v2_override_config is None:
resolved_config = await get_resolved_ai_model_configuration(
user_id=user.id,
organization_id=user.selected_organization_id,
)
v2_override = merge_ai_model_configuration_v2_secrets(
v2_override,
resolved_config.organization_configuration,
)
check_for_masked_keys_in_ai_model_configuration_v2(v2_override)
effective = compile_ai_model_configuration_v2(v2_override)
await UserConfigurationValidator().validate(
effective,
organization_id=user.selected_organization_id,
created_by=user.provider_id,
)
except (ValidationError, ValueError) as e:
raise HTTPException(status_code=422, detail=str(e))
workflow_configurations = {
**workflow_configurations,
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY: v2_override.model_dump(
mode="json",
exclude_none=True,
),
}
workflow_configurations.pop("model_overrides", None)
elif workflow_configurations and workflow_configurations.get("model_overrides"):
existing_workflow = await db_client.get_workflow(
workflow_id, organization_id=user.selected_organization_id
)
@ -978,24 +1049,46 @@ async def update_workflow(
workflow_configurations,
existing_configs,
)
user_config = await db_client.get_user_configurations(user.id)
resolved_config = await get_resolved_ai_model_configuration(
user_id=user.id,
organization_id=user.selected_organization_id,
)
user_config = resolved_config.effective
try:
enriched_overrides = enrich_overrides_with_api_keys(
workflow_configurations["model_overrides"],
user_config,
)
effective = resolve_effective_config(user_config, enriched_overrides)
await UserConfigurationValidator().validate(
effective,
organization_id=user.selected_organization_id,
created_by=user.provider_id,
)
if resolved_config.source == "organization_v2":
v2_override = convert_legacy_ai_model_configuration_to_v2(effective)
await UserConfigurationValidator().validate(
compile_ai_model_configuration_v2(v2_override),
organization_id=user.selected_organization_id,
created_by=user.provider_id,
)
else:
await UserConfigurationValidator().validate(
effective,
organization_id=user.selected_organization_id,
created_by=user.provider_id,
)
except ValueError as e:
raise HTTPException(status_code=422, detail=str(e))
workflow_configurations = {
**workflow_configurations,
"model_overrides": enriched_overrides,
}
if resolved_config.source == "organization_v2":
workflow_configurations = {
**workflow_configurations,
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY: v2_override.model_dump(
mode="json",
exclude_none=True,
),
}
workflow_configurations.pop("model_overrides", None)
else:
workflow_configurations = {
**workflow_configurations,
"model_overrides": enriched_overrides,
}
# Reject upfront if any new trigger path collides with another
# workflow's trigger — keeps the workflow record from

View file

@ -0,0 +1,176 @@
from __future__ import annotations
from typing import Literal
from pydantic import BaseModel, Field, model_validator
from api.schemas.user_configuration import EffectiveAIModelConfiguration
from api.services.configuration.registry import (
DograhEmbeddingsConfiguration,
DograhLLMService,
DograhSTTService,
DograhTTSService,
EmbeddingsConfig,
LLMConfig,
RealtimeConfig,
ServiceProviders,
STTConfig,
TTSConfig,
)
DOGRAH_SPEED_OPTIONS: tuple[float, ...] = (0.8, 1.0, 1.2)
DOGRAH_DEFAULT_VOICE = "default"
DOGRAH_DEFAULT_LANGUAGE = "multi"
class DograhManagedAIModelConfiguration(BaseModel):
api_key: str
voice: str = DOGRAH_DEFAULT_VOICE
speed: float = Field(default=1.0)
language: str = DOGRAH_DEFAULT_LANGUAGE
@model_validator(mode="after")
def validate_speed(self):
if self.speed not in DOGRAH_SPEED_OPTIONS:
allowed = ", ".join(str(speed) for speed in DOGRAH_SPEED_OPTIONS)
raise ValueError(f"Dograh speed must be one of: {allowed}")
return self
class BYOKPipelineAIModelConfiguration(BaseModel):
llm: LLMConfig
tts: TTSConfig
stt: STTConfig
embeddings: EmbeddingsConfig | None = None
@model_validator(mode="after")
def reject_dograh_providers(self):
_reject_dograh_provider("llm", self.llm)
_reject_dograh_provider("tts", self.tts)
_reject_dograh_provider("stt", self.stt)
_reject_dograh_provider("embeddings", self.embeddings)
return self
class BYOKRealtimeAIModelConfiguration(BaseModel):
realtime: RealtimeConfig
llm: LLMConfig
embeddings: EmbeddingsConfig | None = None
@model_validator(mode="after")
def reject_dograh_providers(self):
_reject_dograh_provider("llm", self.llm)
_reject_dograh_provider("embeddings", self.embeddings)
return self
class BYOKAIModelConfiguration(BaseModel):
mode: Literal["pipeline", "realtime"]
pipeline: BYOKPipelineAIModelConfiguration | None = None
realtime: BYOKRealtimeAIModelConfiguration | None = None
@model_validator(mode="after")
def validate_selected_mode(self):
if self.mode == "pipeline" and self.pipeline is None:
raise ValueError("byok.pipeline is required when byok.mode is pipeline")
if self.mode == "realtime" and self.realtime is None:
raise ValueError("byok.realtime is required when byok.mode is realtime")
return self
class OrganizationAIModelConfigurationV2(BaseModel):
version: Literal[2] = 2
mode: Literal["dograh", "byok"]
dograh: DograhManagedAIModelConfiguration | None = None
byok: BYOKAIModelConfiguration | None = None
@model_validator(mode="after")
def validate_selected_mode(self):
if self.mode == "dograh" and self.dograh is None:
raise ValueError("dograh configuration is required when mode is dograh")
if self.mode == "byok" and self.byok is None:
raise ValueError("byok configuration is required when mode is byok")
return self
class OrganizationAIModelConfigurationPreferences(BaseModel):
test_phone_number: str | None = None
timezone: str | None = None
class OrganizationAIModelConfigurationResponse(BaseModel):
configuration: dict | None
effective_configuration: dict
preferences: OrganizationAIModelConfigurationPreferences
source: Literal["organization_v2", "legacy_user_v1", "empty"]
def compile_ai_model_configuration_v2(
configuration: OrganizationAIModelConfigurationV2,
) -> EffectiveAIModelConfiguration:
if configuration.mode == "dograh":
if configuration.dograh is None:
raise ValueError("dograh configuration is required")
return _compile_dograh_configuration(configuration.dograh)
if configuration.byok is None:
raise ValueError("byok configuration is required")
if configuration.byok.mode == "pipeline":
if configuration.byok.pipeline is None:
raise ValueError("byok.pipeline is required")
pipeline = configuration.byok.pipeline
return EffectiveAIModelConfiguration(
llm=pipeline.llm,
tts=pipeline.tts,
stt=pipeline.stt,
embeddings=pipeline.embeddings,
is_realtime=False,
)
if configuration.byok.realtime is None:
raise ValueError("byok.realtime is required")
realtime = configuration.byok.realtime
return EffectiveAIModelConfiguration(
llm=realtime.llm,
realtime=realtime.realtime,
embeddings=realtime.embeddings,
is_realtime=True,
)
def _compile_dograh_configuration(
configuration: DograhManagedAIModelConfiguration,
) -> EffectiveAIModelConfiguration:
return EffectiveAIModelConfiguration(
llm=DograhLLMService(
provider=ServiceProviders.DOGRAH,
api_key=configuration.api_key,
model="default",
),
tts=DograhTTSService(
provider=ServiceProviders.DOGRAH,
api_key=configuration.api_key,
model="default",
voice=configuration.voice,
speed=configuration.speed,
),
stt=DograhSTTService(
provider=ServiceProviders.DOGRAH,
api_key=configuration.api_key,
model="default",
language=configuration.language,
),
embeddings=DograhEmbeddingsConfiguration(
provider=ServiceProviders.DOGRAH,
api_key=configuration.api_key,
model="default",
),
is_realtime=False,
)
def _reject_dograh_provider(section: str, service) -> None:
if service is None:
return
if getattr(service, "provider", None) == ServiceProviders.DOGRAH:
raise ValueError(f"BYOK {section} cannot use Dograh provider")

View file

@ -11,7 +11,7 @@ from api.services.configuration.registry import (
)
class UserConfiguration(BaseModel):
class EffectiveAIModelConfiguration(BaseModel):
llm: LLMConfig | None = None
stt: STTConfig | None = None
tts: TTSConfig | None = None
@ -31,3 +31,7 @@ class UserConfiguration(BaseModel):
if isinstance(realtime, dict) and not realtime.get("api_key"):
data.pop("realtime", None)
return data
# Backward-compatible alias for legacy persistence and existing call sites.
UserConfiguration = EffectiveAIModelConfiguration

View file

@ -119,6 +119,19 @@ async def get_user(
await db_client.update_user_configuration(
user_model.id, mps_config
)
from api.enums import OrganizationConfigurationKey
from api.services.configuration.ai_model_configuration import (
convert_legacy_ai_model_configuration_to_v2,
)
model_config_v2 = convert_legacy_ai_model_configuration_to_v2(
mps_config
)
await db_client.upsert_configuration(
organization.id,
OrganizationConfigurationKey.MODEL_CONFIGURATION_V2.value,
model_config_v2.model_dump(mode="json", exclude_none=True),
)
except Exception as exc:
raise HTTPException(

View file

@ -0,0 +1,532 @@
from __future__ import annotations
import copy
from dataclasses import dataclass
from inspect import isawaitable
from typing import Literal
from loguru import logger
from pydantic import ValidationError
from sqlalchemy import select, update
from sqlalchemy.orm import selectinload
from api.constants import MPS_API_URL
from api.db import db_client
from api.db.models import WorkflowDefinitionModel, WorkflowModel
from api.enums import OrganizationConfigurationKey
from api.schemas.ai_model_configuration import (
DOGRAH_DEFAULT_LANGUAGE,
DOGRAH_DEFAULT_VOICE,
DOGRAH_SPEED_OPTIONS,
BYOKAIModelConfiguration,
BYOKPipelineAIModelConfiguration,
BYOKRealtimeAIModelConfiguration,
DograhManagedAIModelConfiguration,
OrganizationAIModelConfigurationPreferences,
OrganizationAIModelConfigurationV2,
compile_ai_model_configuration_v2,
)
from api.schemas.user_configuration import EffectiveAIModelConfiguration
from api.services.configuration.masking import (
SERVICE_SECRET_FIELDS,
contains_masked_key,
mask_key,
resolve_masked_api_keys,
)
from api.services.configuration.registry import ServiceProviders
from api.services.configuration.resolve import resolve_effective_config
AIModelConfigurationSource = Literal["organization_v2", "legacy_user_v1", "empty"]
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY = "model_configuration_v2_override"
@dataclass
class ResolvedAIModelConfiguration:
effective: EffectiveAIModelConfiguration
source: AIModelConfigurationSource
organization_configuration: OrganizationAIModelConfigurationV2 | None = None
preferences: OrganizationAIModelConfigurationPreferences | None = None
@dataclass
class WorkflowAIModelConfigurationMigrationResult:
workflow_count: int = 0
definition_count: int = 0
workflow_ids: list[int] | None = None
async def get_resolved_ai_model_configuration(
*,
user_id: int | None,
organization_id: int | None,
) -> ResolvedAIModelConfiguration:
preferences = await get_organization_ai_model_configuration_preferences(
organization_id
)
organization_configuration = await get_organization_ai_model_configuration_v2(
organization_id
)
if organization_configuration is not None:
return ResolvedAIModelConfiguration(
effective=compile_ai_model_configuration_v2(organization_configuration),
source="organization_v2",
organization_configuration=organization_configuration,
preferences=preferences,
)
if user_id is None:
return ResolvedAIModelConfiguration(
effective=EffectiveAIModelConfiguration(),
source="empty",
preferences=preferences,
)
legacy = await db_client.get_user_configurations(user_id)
return ResolvedAIModelConfiguration(
effective=legacy,
source="legacy_user_v1" if _has_model_services(legacy) else "empty",
preferences=preferences,
)
async def get_effective_ai_model_configuration_for_workflow(
*,
user_id: int | None,
organization_id: int | None,
workflow_configurations: dict | None,
) -> EffectiveAIModelConfiguration:
workflow_configurations = workflow_configurations or {}
v2_override = workflow_configurations.get(
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY
)
if v2_override:
return compile_ai_model_configuration_v2(
OrganizationAIModelConfigurationV2.model_validate(v2_override)
)
resolved_config = await get_resolved_ai_model_configuration(
user_id=user_id,
organization_id=organization_id,
)
return resolve_effective_config(
resolved_config.effective,
workflow_configurations.get("model_overrides"),
)
async def get_organization_ai_model_configuration_v2(
organization_id: int | None,
) -> OrganizationAIModelConfigurationV2 | None:
if organization_id is None:
return None
row = await db_client.get_configuration(
organization_id,
OrganizationConfigurationKey.MODEL_CONFIGURATION_V2.value,
)
if row is None or not row.value:
return None
try:
return OrganizationAIModelConfigurationV2.model_validate(row.value)
except ValidationError as exc:
logger.warning(
"Invalid org AI model configuration v2 for organization "
f"{organization_id}: {exc}. Falling back to legacy configuration."
)
return None
async def get_organization_ai_model_configuration_preferences(
organization_id: int | None,
db=None,
) -> OrganizationAIModelConfigurationPreferences:
if organization_id is None:
return OrganizationAIModelConfigurationPreferences()
db = db or db_client
row = db.get_configuration(
organization_id,
OrganizationConfigurationKey.MODEL_CONFIGURATION_PREFERENCES.value,
)
if isawaitable(row):
row = await row
if row is None or not row.value:
return OrganizationAIModelConfigurationPreferences()
if not isinstance(row.value, dict):
return OrganizationAIModelConfigurationPreferences()
try:
return OrganizationAIModelConfigurationPreferences.model_validate(row.value)
except ValidationError as exc:
logger.warning(
"Invalid org AI model configuration preferences for organization "
f"{organization_id}: {exc}. Returning defaults."
)
return OrganizationAIModelConfigurationPreferences()
async def upsert_organization_ai_model_configuration_v2(
organization_id: int,
configuration: OrganizationAIModelConfigurationV2,
) -> OrganizationAIModelConfigurationV2:
await db_client.upsert_configuration(
organization_id,
OrganizationConfigurationKey.MODEL_CONFIGURATION_V2.value,
configuration.model_dump(mode="json", exclude_none=True),
)
return configuration
async def upsert_organization_ai_model_configuration_preferences(
organization_id: int,
preferences: OrganizationAIModelConfigurationPreferences,
) -> OrganizationAIModelConfigurationPreferences:
await db_client.upsert_configuration(
organization_id,
OrganizationConfigurationKey.MODEL_CONFIGURATION_PREFERENCES.value,
preferences.model_dump(mode="json", exclude_none=True),
)
return preferences
async def migrate_workflow_model_configurations_to_v2(
*,
organization_id: int,
fallback_user_config: EffectiveAIModelConfiguration,
) -> WorkflowAIModelConfigurationMigrationResult:
workflows = await _list_workflows_for_model_configuration_migration(organization_id)
owner_configs: dict[int, EffectiveAIModelConfiguration] = {}
workflow_updates: list[tuple[int, dict]] = []
definition_updates: list[tuple[int, dict]] = []
migrated_workflow_ids: set[int] = set()
for workflow in workflows:
base_config = fallback_user_config
if workflow.user_id is not None:
if workflow.user_id not in owner_configs:
owner_configs[
workflow.user_id
] = await db_client.get_user_configurations(workflow.user_id)
base_config = owner_configs[workflow.user_id]
workflow_configs, workflow_changed = (
migrate_workflow_configuration_model_override_to_v2(
workflow.workflow_configurations,
base_config,
)
)
if workflow_changed:
workflow_updates.append((workflow.id, workflow_configs))
migrated_workflow_ids.add(workflow.id)
for definition in workflow.definitions:
definition_configs, definition_changed = (
migrate_workflow_configuration_model_override_to_v2(
definition.workflow_configurations,
base_config,
)
)
if definition_changed:
definition_updates.append((definition.id, definition_configs))
migrated_workflow_ids.add(workflow.id)
if workflow_updates or definition_updates:
async with db_client.async_session() as session:
for workflow_id, workflow_configs in workflow_updates:
await session.execute(
update(WorkflowModel)
.where(WorkflowModel.id == workflow_id)
.values(workflow_configurations=workflow_configs)
)
for definition_id, definition_configs in definition_updates:
await session.execute(
update(WorkflowDefinitionModel)
.where(WorkflowDefinitionModel.id == definition_id)
.values(workflow_configurations=definition_configs)
)
await session.commit()
return WorkflowAIModelConfigurationMigrationResult(
workflow_count=len(migrated_workflow_ids),
definition_count=len(definition_updates),
workflow_ids=sorted(migrated_workflow_ids),
)
def migrate_workflow_configuration_model_override_to_v2(
workflow_configurations: dict | None,
base_config: EffectiveAIModelConfiguration,
) -> tuple[dict, bool]:
if not isinstance(workflow_configurations, dict):
return {}, False
migrated = copy.deepcopy(workflow_configurations)
model_overrides = migrated.get("model_overrides")
existing_v2_override = migrated.get(WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY)
if not isinstance(model_overrides, dict):
if "model_overrides" in migrated:
migrated.pop("model_overrides", None)
return migrated, True
return migrated, False
if not existing_v2_override:
effective = resolve_effective_config(base_config, model_overrides)
v2_override = convert_legacy_ai_model_configuration_to_v2(effective)
migrated[WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY] = v2_override.model_dump(
mode="json", exclude_none=True
)
migrated.pop("model_overrides", None)
return migrated, True
def merge_ai_model_configuration_v2_secrets(
incoming: OrganizationAIModelConfigurationV2,
existing: OrganizationAIModelConfigurationV2 | None,
) -> OrganizationAIModelConfigurationV2:
if existing is None:
return incoming
incoming_dict = incoming.model_dump(mode="json", exclude_none=True)
existing_dict = existing.model_dump(mode="json", exclude_none=True)
if incoming_dict.get("mode") == "dograh" and existing_dict.get("mode") == "dograh":
incoming_dograh = incoming_dict.get("dograh") or {}
existing_dograh = existing_dict.get("dograh") or {}
incoming_key = incoming_dograh.get("api_key")
existing_key = existing_dograh.get("api_key")
if incoming_key and existing_key and contains_masked_key(incoming_key):
incoming_dograh["api_key"] = resolve_masked_api_keys(
incoming_key,
existing_key,
)
if incoming_dict.get("mode") == "byok" and existing_dict.get("mode") == "byok":
_merge_byok_secret_fields(incoming_dict.get("byok"), existing_dict.get("byok"))
return OrganizationAIModelConfigurationV2.model_validate(incoming_dict)
def check_for_masked_keys_in_ai_model_configuration_v2(
configuration: OrganizationAIModelConfigurationV2,
) -> None:
data = configuration.model_dump(mode="json", exclude_none=True)
_raise_if_masked_secret(data)
def mask_ai_model_configuration_v2(
configuration: OrganizationAIModelConfigurationV2 | None,
) -> dict | None:
if configuration is None:
return None
data = configuration.model_dump(mode="json", exclude_none=True)
_mask_secret_fields(data)
return data
def convert_legacy_ai_model_configuration_to_v2(
configuration: EffectiveAIModelConfiguration,
) -> OrganizationAIModelConfigurationV2:
dograh_key = _first_dograh_api_key(configuration)
if dograh_key:
return _convert_any_dograh_legacy_configuration(configuration, dograh_key)
if configuration.is_realtime:
if configuration.realtime is None or configuration.llm is None:
raise ValueError("Realtime legacy configuration is incomplete")
return OrganizationAIModelConfigurationV2(
mode="byok",
byok=BYOKAIModelConfiguration(
mode="realtime",
realtime=BYOKRealtimeAIModelConfiguration(
realtime=configuration.realtime,
llm=configuration.llm,
embeddings=configuration.embeddings,
),
),
)
if (
configuration.llm is None
or configuration.tts is None
or configuration.stt is None
):
raise ValueError("Pipeline legacy configuration is incomplete")
return OrganizationAIModelConfigurationV2(
mode="byok",
byok=BYOKAIModelConfiguration(
mode="pipeline",
pipeline=BYOKPipelineAIModelConfiguration(
llm=configuration.llm,
tts=configuration.tts,
stt=configuration.stt,
embeddings=configuration.embeddings,
),
),
)
def dograh_embeddings_base_url() -> str:
return f"{MPS_API_URL}/api/v1/llm"
def apply_managed_embeddings_base_url(
*,
provider: str | None,
base_url: str | None,
) -> str | None:
if provider == ServiceProviders.DOGRAH.value or provider == ServiceProviders.DOGRAH:
return dograh_embeddings_base_url()
return base_url
def _merge_byok_secret_fields(incoming_byok: dict | None, existing_byok: dict | None):
if not isinstance(incoming_byok, dict) or not isinstance(existing_byok, dict):
return
incoming_mode = incoming_byok.get("mode")
existing_mode = existing_byok.get("mode")
if incoming_mode != existing_mode:
return
section_names = (
("llm", "tts", "stt", "embeddings")
if incoming_mode == "pipeline"
else ("realtime", "llm", "embeddings")
)
incoming_container = incoming_byok.get(incoming_mode)
existing_container = existing_byok.get(existing_mode)
if not isinstance(incoming_container, dict) or not isinstance(
existing_container, dict
):
return
for section_name in section_names:
incoming_section = incoming_container.get(section_name)
existing_section = existing_container.get(section_name)
if isinstance(incoming_section, dict) and isinstance(existing_section, dict):
_merge_service_secret_fields(incoming_section, existing_section)
async def _list_workflows_for_model_configuration_migration(
organization_id: int,
) -> list[WorkflowModel]:
async with db_client.async_session() as session:
result = await session.execute(
select(WorkflowModel)
.options(selectinload(WorkflowModel.definitions))
.where(WorkflowModel.organization_id == organization_id)
)
return list(result.scalars().unique().all())
def _merge_service_secret_fields(incoming: dict, existing: dict):
if (
incoming.get("provider") is not None
and existing.get("provider") is not None
and incoming.get("provider") != existing.get("provider")
):
return
for secret_field in SERVICE_SECRET_FIELDS:
if secret_field not in existing:
continue
incoming_secret = incoming.get(secret_field)
existing_secret = existing[secret_field]
if incoming_secret is None:
incoming[secret_field] = existing_secret
elif contains_masked_key(incoming_secret):
incoming[secret_field] = resolve_masked_api_keys(
incoming_secret,
existing_secret,
)
def _raise_if_masked_secret(value):
if isinstance(value, dict):
for key, nested in value.items():
if key in SERVICE_SECRET_FIELDS and contains_masked_key(nested):
raise ValueError(
f"The {key} appears to be masked. Please provide the actual "
"value, not the masked value."
)
_raise_if_masked_secret(nested)
elif isinstance(value, list):
for item in value:
_raise_if_masked_secret(item)
def _mask_secret_fields(value):
if isinstance(value, dict):
for key, nested in list(value.items()):
if key in SERVICE_SECRET_FIELDS and nested:
value[key] = _mask_secret_value(nested)
else:
_mask_secret_fields(nested)
elif isinstance(value, list):
for item in value:
_mask_secret_fields(item)
def _mask_secret_value(value):
if isinstance(value, list):
return [mask_key(item) for item in value]
return mask_key(value)
def _has_model_services(configuration: EffectiveAIModelConfiguration) -> bool:
return any(
service is not None
for service in (
configuration.llm,
configuration.tts,
configuration.stt,
configuration.embeddings,
configuration.realtime,
)
)
def _convert_any_dograh_legacy_configuration(
configuration: EffectiveAIModelConfiguration,
dograh_key: str,
) -> OrganizationAIModelConfigurationV2:
speed = getattr(configuration.tts, "speed", 1.0)
if speed not in DOGRAH_SPEED_OPTIONS:
speed = 1.0
return OrganizationAIModelConfigurationV2(
mode="dograh",
dograh=DograhManagedAIModelConfiguration(
api_key=dograh_key,
voice=getattr(configuration.tts, "voice", DOGRAH_DEFAULT_VOICE)
or DOGRAH_DEFAULT_VOICE,
speed=speed,
language=getattr(configuration.stt, "language", DOGRAH_DEFAULT_LANGUAGE)
or DOGRAH_DEFAULT_LANGUAGE,
),
)
def _first_dograh_api_key(configuration: EffectiveAIModelConfiguration) -> str | None:
for service in (
configuration.llm,
configuration.tts,
configuration.stt,
configuration.embeddings,
configuration.realtime,
):
if service is None or _provider(service) != ServiceProviders.DOGRAH:
continue
try:
return _single_api_key(service)
except ValueError:
continue
return None
def _provider(service):
return getattr(service, "provider", None)
def _single_api_key(service) -> str:
if hasattr(service, "get_all_api_keys"):
keys = service.get_all_api_keys()
if len(keys) != 1:
raise ValueError("Expected exactly one API key")
return keys[0]
key = getattr(service, "api_key", None)
if not key:
raise ValueError("Expected an API key")
return key

View file

@ -151,21 +151,35 @@ def mask_workflow_configurations(config: Optional[Dict]) -> Optional[Dict]:
masked = copy.deepcopy(config)
model_overrides = masked.get("model_overrides")
if not isinstance(model_overrides, dict):
return masked
if isinstance(model_overrides, dict):
for section in MODEL_OVERRIDE_FIELDS:
override = model_overrides.get(section)
if not isinstance(override, dict):
continue
for secret_field in SERVICE_SECRET_FIELDS:
raw = override.get(secret_field)
if raw:
override[secret_field] = _mask_secret_value(raw)
for section in MODEL_OVERRIDE_FIELDS:
override = model_overrides.get(section)
if not isinstance(override, dict):
continue
for secret_field in SERVICE_SECRET_FIELDS:
raw = override.get(secret_field)
if raw:
override[secret_field] = _mask_secret_value(raw)
v2_override = masked.get("model_configuration_v2_override")
if isinstance(v2_override, dict):
_mask_nested_service_secrets(v2_override)
return masked
def _mask_nested_service_secrets(value):
if isinstance(value, dict):
for key, nested in list(value.items()):
if key in SERVICE_SECRET_FIELDS and nested:
value[key] = _mask_secret_value(nested)
else:
_mask_nested_service_secrets(nested)
elif isinstance(value, list):
for item in value:
_mask_nested_service_secrets(item)
# ---------------------------------------------------------------------------
# Workflow definition helpers mask / merge node API keys
# ---------------------------------------------------------------------------

View file

@ -1472,11 +1472,26 @@ class AzureOpenAIEmbeddingsConfiguration(BaseEmbeddingsConfiguration):
)
DOGRAH_EMBEDDING_MODELS = ["default"]
@register_embeddings
class DograhEmbeddingsConfiguration(BaseEmbeddingsConfiguration):
model_config = DOGRAH_PROVIDER_MODEL_CONFIG
provider: Literal[ServiceProviders.DOGRAH] = ServiceProviders.DOGRAH
model: str = Field(
default="default",
description="Dograh-managed embedding model.",
json_schema_extra={"examples": DOGRAH_EMBEDDING_MODELS},
)
EmbeddingsConfig = Annotated[
Union[
OpenAIEmbeddingsConfiguration,
OpenRouterEmbeddingsConfiguration,
AzureOpenAIEmbeddingsConfiguration,
DograhEmbeddingsConfiguration,
],
Field(discriminator="provider"),
]

View file

@ -195,14 +195,17 @@ async def run_pipeline_telephony(
# Resolve effective user config here so the transport can tune its
# bot-stopped-speaking fallback based on is_realtime; pass the resolved
# values into _run_pipeline so it doesn't fetch them again.
from api.services.configuration.resolve import resolve_effective_config
from api.services.configuration.ai_model_configuration import (
get_effective_ai_model_configuration_for_workflow,
)
user_config = await db_client.get_user_configurations(user_id)
run_configs = (
(workflow_run.definition.workflow_configurations or {}) if workflow_run else {}
)
user_config = resolve_effective_config(
user_config, run_configs.get("model_overrides")
user_config = await get_effective_ai_model_configuration_for_workflow(
user_id=user_id,
organization_id=workflow.organization_id if workflow else None,
workflow_configurations=run_configs,
)
is_realtime = bool(user_config.is_realtime and user_config.realtime is not None)
@ -272,15 +275,18 @@ async def run_pipeline_smallwebrtc(
# Resolve workflow_run + effective user_config here so the transport can
# tune its bot-stopped-speaking fallback based on is_realtime. _run_pipeline
# reuses these via kwargs so we don't fetch twice.
from api.services.configuration.resolve import resolve_effective_config
from api.services.configuration.ai_model_configuration import (
get_effective_ai_model_configuration_for_workflow,
)
workflow_run = await db_client.get_workflow_run(workflow_run_id, user_id)
user_config = await db_client.get_user_configurations(user_id)
run_configs = (
(workflow_run.definition.workflow_configurations or {}) if workflow_run else {}
)
user_config = resolve_effective_config(
user_config, run_configs.get("model_overrides")
user_config = await get_effective_ai_model_configuration_for_workflow(
user_id=user_id,
organization_id=workflow.organization_id if workflow else None,
workflow_configurations=run_configs,
)
is_realtime = bool(user_config.is_realtime and user_config.realtime is not None)
@ -380,11 +386,14 @@ async def _run_pipeline(
# Resolve model overrides from the version onto global user config (skip
# when the caller already resolved it).
if resolved_user_config is None:
from api.services.configuration.resolve import resolve_effective_config
from api.services.configuration.ai_model_configuration import (
get_effective_ai_model_configuration_for_workflow,
)
user_config = await db_client.get_user_configurations(user_id)
user_config = resolve_effective_config(
user_config, run_configs.get("model_overrides")
user_config = await get_effective_ai_model_configuration_for_workflow(
user_id=user_id,
organization_id=workflow.organization_id,
workflow_configurations=run_configs,
)
else:
user_config = resolved_user_config
@ -508,10 +517,17 @@ async def _run_pipeline(
embeddings_endpoint = None
embeddings_api_version = None
if user_config and user_config.embeddings:
from api.services.configuration.ai_model_configuration import (
apply_managed_embeddings_base_url,
)
embeddings_api_key = user_config.embeddings.api_key
embeddings_model = user_config.embeddings.model
embeddings_provider = getattr(user_config.embeddings, "provider", None)
embeddings_base_url = getattr(user_config.embeddings, "base_url", None)
embeddings_base_url = apply_managed_embeddings_base_url(
provider=embeddings_provider,
base_url=getattr(user_config.embeddings, "base_url", None),
)
embeddings_endpoint = getattr(user_config.embeddings, "endpoint", None)
embeddings_api_version = getattr(user_config.embeddings, "api_version", None)

View file

@ -10,8 +10,10 @@ from loguru import logger
from api.db import db_client
from api.db.models import UserModel
from api.services.configuration.ai_model_configuration import (
get_effective_ai_model_configuration_for_workflow,
)
from api.services.configuration.registry import ServiceProviders
from api.services.configuration.resolve import resolve_effective_config
from api.services.mps_service_key_client import mps_service_key_client
@ -48,17 +50,20 @@ async def check_dograh_quota(
if quota is insufficient.
"""
try:
# Get user configurations
user_config = await db_client.get_user_configurations(user.id)
organization_id = user.selected_organization_id
workflow_configurations = None
if workflow_id is not None:
workflow = await db_client.get_workflow_by_id(workflow_id)
if workflow:
model_overrides = (workflow.workflow_configurations or {}).get(
"model_overrides"
)
if model_overrides:
user_config = resolve_effective_config(user_config, model_overrides)
organization_id = workflow.organization_id
workflow_configurations = workflow.workflow_configurations
user_config = await get_effective_ai_model_configuration_for_workflow(
user_id=user.id,
organization_id=organization_id,
workflow_configurations=workflow_configurations,
)
# Check if user is using any Dograh service
using_dograh = False
@ -76,6 +81,13 @@ async def check_dograh_quota(
using_dograh = True
dograh_api_keys.add(user_config.tts.api_key)
if (
user_config.embeddings
and user_config.embeddings.provider == ServiceProviders.DOGRAH
):
using_dograh = True
dograh_api_keys.add(user_config.embeddings.api_key)
# If not using Dograh, quota check passes
if not using_dograh:
return QuotaCheckResult(has_quota=True)
@ -84,7 +96,9 @@ async def check_dograh_quota(
for api_key in dograh_api_keys:
try:
usage = await mps_service_key_client.check_service_key_usage(
api_key, created_by=user.provider_id
api_key,
organization_id=organization_id,
created_by=user.provider_id,
)
remaining = usage.get("remaining_credits", 0.0)

View file

@ -2,7 +2,6 @@
import random
from api.db import db_client
from api.db.models import WorkflowRunModel
from api.services.workflow.dto import QANodeData
@ -54,7 +53,27 @@ async def resolve_user_llm_config(
llm_config: dict = {}
if user_id:
user_configuration = await db_client.get_user_configurations(user_id)
from api.services.configuration.ai_model_configuration import (
get_effective_ai_model_configuration_for_workflow,
)
workflow_configurations = {}
if workflow_run.definition:
workflow_configurations = (
workflow_run.definition.workflow_configurations or {}
)
elif workflow_run.workflow:
workflow_configurations = (
workflow_run.workflow.workflow_configurations or {}
)
user_configuration = await get_effective_ai_model_configuration_for_workflow(
user_id=user_id,
organization_id=workflow_run.workflow.organization_id
if workflow_run.workflow
else None,
workflow_configurations=workflow_configurations,
)
llm_config = user_configuration.model_dump(exclude_none=True).get("llm", {})
provider = llm_config.get("provider", "openai")

View file

@ -32,7 +32,6 @@ from pipecat.utils.run_context import set_current_org_id
from api.db import db_client
from api.enums import WorkflowRunMode, WorkflowRunState
from api.services.configuration.resolve import resolve_effective_config
from api.services.pipecat.audio_config import create_audio_config
from api.services.pipecat.pipeline_builder import create_pipeline_task
from api.services.pipecat.pipeline_metrics_aggregator import (
@ -410,9 +409,14 @@ async def execute_text_chat_pending_turn(
run_definition = workflow_run.definition
run_configs = run_definition.workflow_configurations or {}
user_config = await db_client.get_user_configurations(workflow_run.workflow.user.id)
user_config = resolve_effective_config(
user_config, run_configs.get("model_overrides")
from api.services.configuration.ai_model_configuration import (
get_effective_ai_model_configuration_for_workflow,
)
user_config = await get_effective_ai_model_configuration_for_workflow(
user_id=workflow_run.workflow.user.id,
organization_id=workflow.organization_id,
workflow_configurations=run_configs,
)
if user_config.llm is None:
raise ValueError("Text chat requires an LLM configuration")
@ -466,9 +470,17 @@ async def execute_text_chat_pending_turn(
embeddings_model = None
embeddings_base_url = None
if user_config.embeddings:
from api.services.configuration.ai_model_configuration import (
apply_managed_embeddings_base_url,
)
embeddings_api_key = user_config.embeddings.api_key
embeddings_model = user_config.embeddings.model
embeddings_base_url = getattr(user_config.embeddings, "base_url", None)
embeddings_provider = getattr(user_config.embeddings, "provider", None)
embeddings_base_url = apply_managed_embeddings_base_url(
provider=embeddings_provider,
base_url=getattr(user_config.embeddings, "base_url", None),
)
has_recordings = await db_client.has_active_recordings(workflow.organization_id)
context_compaction_enabled = (workflow.workflow_configurations or {}).get(

View file

@ -157,12 +157,24 @@ async def process_knowledge_base_document(
embeddings_endpoint = None
embeddings_api_version = None
if document.created_by:
user_config = await db_client.get_user_configurations(document.created_by)
from api.services.configuration.ai_model_configuration import (
apply_managed_embeddings_base_url,
get_resolved_ai_model_configuration,
)
resolved_config = await get_resolved_ai_model_configuration(
user_id=document.created_by,
organization_id=document.organization_id,
)
user_config = resolved_config.effective
if user_config.embeddings:
embeddings_provider = getattr(user_config.embeddings, "provider", None)
embeddings_api_key = user_config.embeddings.api_key
embeddings_model = user_config.embeddings.model
embeddings_base_url = getattr(user_config.embeddings, "base_url", None)
embeddings_base_url = apply_managed_embeddings_base_url(
provider=embeddings_provider,
base_url=getattr(user_config.embeddings, "base_url", None),
)
embeddings_endpoint = getattr(user_config.embeddings, "endpoint", None)
embeddings_api_version = getattr(
user_config.embeddings, "api_version", None

View file

@ -0,0 +1,295 @@
import pytest
from pydantic import ValidationError
from api.schemas.ai_model_configuration import (
DograhManagedAIModelConfiguration,
OrganizationAIModelConfigurationV2,
compile_ai_model_configuration_v2,
)
from api.schemas.user_configuration import UserConfiguration
from api.services.configuration.ai_model_configuration import (
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY,
check_for_masked_keys_in_ai_model_configuration_v2,
convert_legacy_ai_model_configuration_to_v2,
mask_ai_model_configuration_v2,
merge_ai_model_configuration_v2_secrets,
migrate_workflow_configuration_model_override_to_v2,
)
from api.services.configuration.masking import mask_key
from api.services.configuration.registry import (
DeepgramSTTConfiguration,
DograhLLMService,
DograhSTTService,
DograhTTSService,
ElevenlabsTTSConfiguration,
OpenAIEmbeddingsConfiguration,
OpenAILLMService,
)
def test_dograh_v2_compiles_to_effective_managed_pipeline_with_embeddings():
config = OrganizationAIModelConfigurationV2(
mode="dograh",
dograh=DograhManagedAIModelConfiguration(
api_key="mps-secret",
voice="default",
speed=1.2,
language="multi",
),
)
effective = compile_ai_model_configuration_v2(config)
assert effective.is_realtime is False
assert effective.llm.provider == "dograh"
assert effective.llm.model == "default"
assert effective.tts.provider == "dograh"
assert effective.tts.speed == 1.2
assert effective.stt.provider == "dograh"
assert effective.stt.language == "multi"
assert effective.embeddings.provider == "dograh"
assert effective.embeddings.model == "default"
def test_dograh_v2_rejects_non_predefined_speed():
with pytest.raises(ValidationError):
OrganizationAIModelConfigurationV2(
mode="dograh",
dograh=DograhManagedAIModelConfiguration(
api_key="mps-secret",
speed=1.5,
),
)
def test_byok_v2_rejects_dograh_provider():
with pytest.raises(ValidationError):
OrganizationAIModelConfigurationV2.model_validate(
{
"mode": "byok",
"byok": {
"mode": "pipeline",
"pipeline": {
"llm": {
"provider": "dograh",
"api_key": "mps-secret",
"model": "default",
},
"tts": {
"provider": "dograh",
"api_key": "mps-secret",
"model": "default",
"voice": "default",
},
"stt": {
"provider": "dograh",
"api_key": "mps-secret",
"model": "default",
},
},
},
}
)
def test_masked_dograh_key_is_preserved_when_saving_same_mode():
existing = OrganizationAIModelConfigurationV2(
mode="dograh",
dograh=DograhManagedAIModelConfiguration(api_key="mps-real-secret"),
)
incoming = OrganizationAIModelConfigurationV2(
mode="dograh",
dograh=DograhManagedAIModelConfiguration(api_key=mask_key("mps-real-secret")),
)
merged = merge_ai_model_configuration_v2_secrets(incoming, existing)
assert merged.dograh.api_key == "mps-real-secret"
check_for_masked_keys_in_ai_model_configuration_v2(merged)
def test_masked_v2_configuration_masks_nested_service_keys():
config = OrganizationAIModelConfigurationV2(
mode="byok",
byok={
"mode": "pipeline",
"pipeline": {
"llm": {
"provider": "openai",
"api_key": "sk-real-secret",
"model": "gpt-4.1",
},
"tts": {
"provider": "elevenlabs",
"api_key": "el-real-secret",
"model": "eleven_flash_v2_5",
"voice": "Rachel",
},
"stt": {
"provider": "deepgram",
"api_key": "dg-real-secret",
"model": "nova-3-general",
},
},
},
)
masked = mask_ai_model_configuration_v2(config)
assert masked["byok"]["pipeline"]["llm"]["api_key"] == mask_key("sk-real-secret")
assert masked["byok"]["pipeline"]["tts"]["api_key"] == mask_key("el-real-secret")
assert masked["byok"]["pipeline"]["stt"]["api_key"] == mask_key("dg-real-secret")
def test_legacy_all_dograh_pipeline_converts_to_dograh_v2():
legacy = UserConfiguration(
llm=DograhLLMService(
provider="dograh",
api_key=["mps-secret"],
model="default",
),
tts=DograhTTSService(
provider="dograh",
api_key=["mps-secret"],
model="default",
voice="default",
speed=1.0,
),
stt=DograhSTTService(
provider="dograh",
api_key=["mps-secret"],
model="default",
language="multi",
),
)
config = convert_legacy_ai_model_configuration_to_v2(legacy)
assert config.mode == "dograh"
assert config.dograh.api_key == "mps-secret"
def test_legacy_mixed_dograh_pipeline_converts_to_dograh_v2():
legacy = UserConfiguration(
llm=OpenAILLMService(
provider="openai",
api_key="sk-llm",
model="gpt-4.1",
),
tts=DograhTTSService(
provider="dograh",
api_key="mps-tts",
model="default",
voice="default",
),
stt=DograhSTTService(
provider="dograh",
api_key="mps-stt",
model="default",
),
embeddings=OpenAIEmbeddingsConfiguration(
provider="openai",
api_key="sk-emb",
model="text-embedding-3-small",
),
)
config = convert_legacy_ai_model_configuration_to_v2(legacy)
assert config.mode == "dograh"
assert config.dograh.api_key == "mps-tts"
assert config.dograh.voice == "default"
def test_legacy_byok_pipeline_converts_to_byok_v2():
legacy = UserConfiguration(
llm=OpenAILLMService(
provider="openai",
api_key="sk-llm",
model="gpt-4.1",
),
tts=ElevenlabsTTSConfiguration(
provider="elevenlabs",
api_key="el-tts",
model="eleven_flash_v2_5",
voice="Rachel",
),
stt=DeepgramSTTConfiguration(
provider="deepgram",
api_key="dg-stt",
model="nova-3-general",
),
embeddings=OpenAIEmbeddingsConfiguration(
provider="openai",
api_key="sk-emb",
model="text-embedding-3-small",
),
)
config = convert_legacy_ai_model_configuration_to_v2(legacy)
assert config.mode == "byok"
assert config.byok.mode == "pipeline"
assert config.byok.pipeline.llm.provider == "openai"
assert config.byok.pipeline.tts.provider == "elevenlabs"
def test_workflow_model_override_migration_removes_v1_override_and_sets_v2():
base = UserConfiguration(
llm=OpenAILLMService(
provider="openai",
api_key="sk-llm",
model="gpt-4.1",
),
tts=ElevenlabsTTSConfiguration(
provider="elevenlabs",
api_key="el-tts",
model="eleven_flash_v2_5",
voice="Rachel",
),
stt=DeepgramSTTConfiguration(
provider="deepgram",
api_key="dg-stt",
model="nova-3-general",
),
)
workflow_configurations = {
"ambient_noise_configuration": {"enabled": False},
"model_overrides": {
"tts": {
"provider": "dograh",
"api_key": "mps-workflow",
"model": "default",
"voice": "default",
}
},
}
migrated, changed = migrate_workflow_configuration_model_override_to_v2(
workflow_configurations,
base,
)
assert changed is True
assert "model_overrides" not in migrated
assert migrated["ambient_noise_configuration"] == {"enabled": False}
v2_override = migrated[WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY]
assert v2_override["mode"] == "dograh"
assert v2_override["dograh"]["api_key"] == "mps-workflow"
def test_workflow_model_override_migration_removes_invalid_v1_override_marker():
base = UserConfiguration()
workflow_configurations = {
"ambient_noise_configuration": {"enabled": False},
"model_overrides": None,
}
migrated, changed = migrate_workflow_configuration_model_override_to_v2(
workflow_configurations,
base,
)
assert changed is True
assert "model_overrides" not in migrated
assert migrated["ambient_noise_configuration"] == {"enabled": False}