mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-10 08:05:22 +02:00
feat: add model config v2
This commit is contained in:
parent
49e68b49d5
commit
94686b73c4
29 changed files with 4680 additions and 171 deletions
|
|
@ -10,6 +10,7 @@ from sqlalchemy.orm import joinedload
|
|||
from api.db.base_client import BaseDBClient
|
||||
from api.db.filters import apply_workflow_run_filters
|
||||
from api.db.models import (
|
||||
OrganizationConfigurationModel,
|
||||
OrganizationModel,
|
||||
OrganizationUsageCycleModel,
|
||||
UserConfigurationModel,
|
||||
|
|
@ -17,6 +18,7 @@ from api.db.models import (
|
|||
WorkflowModel,
|
||||
WorkflowRunModel,
|
||||
)
|
||||
from api.enums import OrganizationConfigurationKey
|
||||
from api.schemas.user_configuration import UserConfiguration
|
||||
|
||||
|
||||
|
|
@ -440,8 +442,19 @@ class OrganizationUsageClient(BaseDBClient):
|
|||
"""Get daily usage breakdown for an organization with pricing."""
|
||||
|
||||
async with self.async_session() as session:
|
||||
# Get user timezone if user_id is provided
|
||||
# Get org timezone preference first, then fall back to legacy user config.
|
||||
user_timezone = "UTC" # Default timezone
|
||||
pref_result = await session.execute(
|
||||
select(OrganizationConfigurationModel).where(
|
||||
OrganizationConfigurationModel.organization_id == organization_id,
|
||||
OrganizationConfigurationModel.key
|
||||
== OrganizationConfigurationKey.MODEL_CONFIGURATION_PREFERENCES.value,
|
||||
)
|
||||
)
|
||||
pref_obj = pref_result.scalar_one_or_none()
|
||||
if pref_obj and pref_obj.value:
|
||||
user_timezone = pref_obj.value.get("timezone") or user_timezone
|
||||
|
||||
if user_id:
|
||||
config_result = await session.execute(
|
||||
select(UserConfigurationModel).where(
|
||||
|
|
@ -453,7 +466,7 @@ class OrganizationUsageClient(BaseDBClient):
|
|||
user_config = UserConfiguration.model_validate(
|
||||
config_obj.configuration
|
||||
)
|
||||
if user_config.timezone:
|
||||
if user_config.timezone and user_timezone == "UTC":
|
||||
user_timezone = user_config.timezone
|
||||
|
||||
# Validate timezone string
|
||||
|
|
|
|||
|
|
@ -89,6 +89,12 @@ class OrganizationConfigurationKey(Enum):
|
|||
LANGFUSE_CREDENTIALS = (
|
||||
"LANGFUSE_CREDENTIALS" # Org-level Langfuse tracing credentials
|
||||
)
|
||||
MODEL_CONFIGURATION_V2 = (
|
||||
"MODEL_CONFIGURATION_V2" # Org-level v2 AI model configuration
|
||||
)
|
||||
MODEL_CONFIGURATION_PREFERENCES = (
|
||||
"MODEL_CONFIGURATION_PREFERENCES" # Org-level model configuration preferences
|
||||
)
|
||||
|
||||
|
||||
class WorkflowStatus(Enum):
|
||||
|
|
|
|||
|
|
@ -3,9 +3,12 @@ from loguru import logger
|
|||
|
||||
from api.db import db_client
|
||||
from api.db.models import UserModel
|
||||
from api.enums import PostHogEvent
|
||||
from api.enums import OrganizationConfigurationKey, PostHogEvent
|
||||
from api.schemas.auth import AuthResponse, LoginRequest, SignupRequest, UserResponse
|
||||
from api.services.auth.depends import create_user_configuration_with_mps_key, get_user
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
convert_legacy_ai_model_configuration_to_v2,
|
||||
)
|
||||
from api.services.posthog_client import capture_event
|
||||
from api.utils.auth import create_jwt_token, hash_password, verify_password
|
||||
|
||||
|
|
@ -47,6 +50,12 @@ async def signup(request: SignupRequest):
|
|||
)
|
||||
if mps_config:
|
||||
await db_client.update_user_configuration(user.id, mps_config)
|
||||
model_config_v2 = convert_legacy_ai_model_configuration_to_v2(mps_config)
|
||||
await db_client.upsert_configuration(
|
||||
organization.id,
|
||||
OrganizationConfigurationKey.MODEL_CONFIGURATION_V2.value,
|
||||
model_config_v2.model_dump(mode="json", exclude_none=True),
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to create default configuration for OSS user", exc_info=True
|
||||
|
|
|
|||
|
|
@ -369,6 +369,10 @@ async def search_chunks(
|
|||
|
||||
try:
|
||||
# Import here to avoid circular dependency
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
apply_managed_embeddings_base_url,
|
||||
get_resolved_ai_model_configuration,
|
||||
)
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from api.services.gen_ai import (
|
||||
AzureOpenAIEmbeddingService,
|
||||
|
|
@ -376,10 +380,15 @@ async def search_chunks(
|
|||
)
|
||||
|
||||
# Try to get user's embeddings configuration
|
||||
user_config = await db_client.get_user_configurations(user.id)
|
||||
resolved_config = await get_resolved_ai_model_configuration(
|
||||
user_id=user.id,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
user_config = resolved_config.effective
|
||||
embeddings_api_key = None
|
||||
embeddings_model = None
|
||||
embeddings_provider = None
|
||||
embeddings_base_url = None
|
||||
embeddings_endpoint = None
|
||||
embeddings_api_version = None
|
||||
|
||||
|
|
@ -388,6 +397,10 @@ async def search_chunks(
|
|||
embeddings_model = user_config.embeddings.model
|
||||
embeddings_provider = getattr(user_config.embeddings, "provider", None)
|
||||
embeddings_endpoint = getattr(user_config.embeddings, "endpoint", None)
|
||||
embeddings_base_url = apply_managed_embeddings_base_url(
|
||||
provider=embeddings_provider,
|
||||
base_url=getattr(user_config.embeddings, "base_url", None),
|
||||
)
|
||||
embeddings_api_version = getattr(
|
||||
user_config.embeddings, "api_version", None
|
||||
)
|
||||
|
|
@ -406,9 +419,7 @@ async def search_chunks(
|
|||
db_client=db_client,
|
||||
api_key=embeddings_api_key,
|
||||
model_id=embeddings_model or "text-embedding-3-small",
|
||||
base_url=getattr(user_config.embeddings, "base_url", None)
|
||||
if user_config.embeddings
|
||||
else None,
|
||||
base_url=embeddings_base_url,
|
||||
)
|
||||
|
||||
# Perform search
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
|
@ -10,6 +10,14 @@ from api.db import db_client
|
|||
from api.db.models import UserModel
|
||||
from api.db.telephony_configuration_client import TelephonyConfigurationInUseError
|
||||
from api.enums import OrganizationConfigurationKey, PostHogEvent
|
||||
from api.schemas.ai_model_configuration import (
|
||||
DOGRAH_DEFAULT_LANGUAGE,
|
||||
DOGRAH_DEFAULT_VOICE,
|
||||
DOGRAH_SPEED_OPTIONS,
|
||||
OrganizationAIModelConfigurationPreferences,
|
||||
OrganizationAIModelConfigurationResponse,
|
||||
OrganizationAIModelConfigurationV2,
|
||||
)
|
||||
from api.schemas.telephony_config import (
|
||||
TelephonyConfigRequest,
|
||||
TelephonyConfigurationCreateRequest,
|
||||
|
|
@ -27,7 +35,28 @@ from api.schemas.telephony_phone_number import (
|
|||
ProviderSyncStatus,
|
||||
)
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.configuration.masking import is_mask_of, mask_key
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
check_for_masked_keys_in_ai_model_configuration_v2,
|
||||
compile_ai_model_configuration_v2,
|
||||
convert_legacy_ai_model_configuration_to_v2,
|
||||
get_organization_ai_model_configuration_preferences,
|
||||
get_organization_ai_model_configuration_v2,
|
||||
get_resolved_ai_model_configuration,
|
||||
mask_ai_model_configuration_v2,
|
||||
merge_ai_model_configuration_v2_secrets,
|
||||
migrate_workflow_model_configurations_to_v2,
|
||||
upsert_organization_ai_model_configuration_preferences,
|
||||
upsert_organization_ai_model_configuration_v2,
|
||||
)
|
||||
from api.services.configuration.check_validity import UserConfigurationValidator
|
||||
from api.services.configuration.defaults import DEFAULT_SERVICE_PROVIDERS
|
||||
from api.services.configuration.masking import is_mask_of, mask_key, mask_user_config
|
||||
from api.services.configuration.registry import (
|
||||
DOGRAH_STT_LANGUAGES,
|
||||
REGISTRY,
|
||||
ServiceProviders,
|
||||
ServiceType,
|
||||
)
|
||||
from api.services.posthog_client import capture_event
|
||||
from api.services.telephony import registry as telephony_registry
|
||||
from api.services.telephony.factory import get_telephony_provider_by_id
|
||||
|
|
@ -159,6 +188,208 @@ async def get_telephony_config_warnings(user: UserModel = Depends(get_user)):
|
|||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AI model configurations v2
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _require_selected_organization(user: UserModel) -> int:
|
||||
if not user.selected_organization_id:
|
||||
raise HTTPException(status_code=400, detail="No organization selected")
|
||||
return user.selected_organization_id
|
||||
|
||||
|
||||
def _byok_provider_schemas(service_type: ServiceType) -> dict[str, dict]:
|
||||
return {
|
||||
provider: model_cls.model_json_schema()
|
||||
for provider, model_cls in REGISTRY[service_type].items()
|
||||
if provider != ServiceProviders.DOGRAH.value
|
||||
}
|
||||
|
||||
|
||||
async def _model_configuration_v2_response(
|
||||
*,
|
||||
user: UserModel,
|
||||
configuration: OrganizationAIModelConfigurationV2 | None = None,
|
||||
) -> OrganizationAIModelConfigurationResponse:
|
||||
resolved = await get_resolved_ai_model_configuration(
|
||||
user_id=user.id,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
raw_configuration = (
|
||||
configuration
|
||||
if configuration is not None
|
||||
else resolved.organization_configuration
|
||||
)
|
||||
return OrganizationAIModelConfigurationResponse(
|
||||
configuration=mask_ai_model_configuration_v2(raw_configuration),
|
||||
effective_configuration=mask_user_config(resolved.effective),
|
||||
preferences=resolved.preferences
|
||||
or OrganizationAIModelConfigurationPreferences(),
|
||||
source=resolved.source,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/model-configurations/v2/defaults")
|
||||
async def get_model_configuration_v2_defaults(user: UserModel = Depends(get_user)):
|
||||
_require_selected_organization(user)
|
||||
byok_default_providers = {
|
||||
service: provider
|
||||
for service, provider in DEFAULT_SERVICE_PROVIDERS.items()
|
||||
if provider != ServiceProviders.DOGRAH.value
|
||||
}
|
||||
return {
|
||||
"dograh": {
|
||||
"voices": [DOGRAH_DEFAULT_VOICE],
|
||||
"speeds": list(DOGRAH_SPEED_OPTIONS),
|
||||
"languages": DOGRAH_STT_LANGUAGES,
|
||||
"defaults": {
|
||||
"voice": DOGRAH_DEFAULT_VOICE,
|
||||
"speed": 1.0,
|
||||
"language": DOGRAH_DEFAULT_LANGUAGE,
|
||||
},
|
||||
},
|
||||
"byok": {
|
||||
"pipeline": {
|
||||
"llm": _byok_provider_schemas(ServiceType.LLM),
|
||||
"tts": _byok_provider_schemas(ServiceType.TTS),
|
||||
"stt": _byok_provider_schemas(ServiceType.STT),
|
||||
"embeddings": _byok_provider_schemas(ServiceType.EMBEDDINGS),
|
||||
"default_providers": byok_default_providers,
|
||||
},
|
||||
"realtime": {
|
||||
"realtime": _byok_provider_schemas(ServiceType.REALTIME),
|
||||
"llm": _byok_provider_schemas(ServiceType.LLM),
|
||||
"embeddings": _byok_provider_schemas(ServiceType.EMBEDDINGS),
|
||||
"default_providers": byok_default_providers,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@router.get(
|
||||
"/model-configurations/v2",
|
||||
response_model=OrganizationAIModelConfigurationResponse,
|
||||
)
|
||||
async def get_model_configuration_v2(user: UserModel = Depends(get_user)):
|
||||
_require_selected_organization(user)
|
||||
return await _model_configuration_v2_response(user=user)
|
||||
|
||||
|
||||
@router.put(
|
||||
"/model-configurations/v2",
|
||||
response_model=OrganizationAIModelConfigurationResponse,
|
||||
)
|
||||
async def save_model_configuration_v2(
|
||||
request: OrganizationAIModelConfigurationV2,
|
||||
user: UserModel = Depends(get_user),
|
||||
):
|
||||
organization_id = _require_selected_organization(user)
|
||||
existing = await get_organization_ai_model_configuration_v2(organization_id)
|
||||
configuration = merge_ai_model_configuration_v2_secrets(request, existing)
|
||||
try:
|
||||
check_for_masked_keys_in_ai_model_configuration_v2(configuration)
|
||||
effective = compile_ai_model_configuration_v2(configuration)
|
||||
await UserConfigurationValidator().validate(
|
||||
effective,
|
||||
organization_id=organization_id,
|
||||
created_by=user.provider_id,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=422, detail=exc.args[0])
|
||||
|
||||
await upsert_organization_ai_model_configuration_v2(
|
||||
organization_id,
|
||||
configuration,
|
||||
)
|
||||
return await _model_configuration_v2_response(
|
||||
user=user,
|
||||
configuration=configuration,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/model-configurations/v2/migration-preview")
|
||||
async def preview_model_configuration_v2_migration(user: UserModel = Depends(get_user)):
|
||||
_require_selected_organization(user)
|
||||
legacy = await db_client.get_user_configurations(user.id)
|
||||
try:
|
||||
configuration = convert_legacy_ai_model_configuration_to_v2(legacy)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=422, detail=str(exc))
|
||||
return {
|
||||
"configuration": mask_ai_model_configuration_v2(configuration),
|
||||
"effective_configuration": mask_user_config(
|
||||
compile_ai_model_configuration_v2(configuration)
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@router.post(
|
||||
"/model-configurations/v2/migrate",
|
||||
response_model=OrganizationAIModelConfigurationResponse,
|
||||
)
|
||||
async def migrate_model_configuration_v2(
|
||||
force: bool = Query(default=False),
|
||||
user: UserModel = Depends(get_user),
|
||||
):
|
||||
organization_id = _require_selected_organization(user)
|
||||
existing = await get_organization_ai_model_configuration_v2(organization_id)
|
||||
if existing is not None and not force:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="Organization already has a v2 model configuration",
|
||||
)
|
||||
|
||||
legacy = await db_client.get_user_configurations(user.id)
|
||||
try:
|
||||
configuration = convert_legacy_ai_model_configuration_to_v2(legacy)
|
||||
effective = compile_ai_model_configuration_v2(configuration)
|
||||
await UserConfigurationValidator().validate(
|
||||
effective,
|
||||
organization_id=organization_id,
|
||||
created_by=user.provider_id,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=422, detail=exc.args[0])
|
||||
|
||||
await upsert_organization_ai_model_configuration_v2(
|
||||
organization_id,
|
||||
configuration,
|
||||
)
|
||||
await migrate_workflow_model_configurations_to_v2(
|
||||
organization_id=organization_id,
|
||||
fallback_user_config=legacy,
|
||||
)
|
||||
return await _model_configuration_v2_response(
|
||||
user=user,
|
||||
configuration=configuration,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/model-configurations/preferences",
|
||||
response_model=OrganizationAIModelConfigurationPreferences,
|
||||
)
|
||||
async def get_model_configuration_preferences(user: UserModel = Depends(get_user)):
|
||||
organization_id = _require_selected_organization(user)
|
||||
return await get_organization_ai_model_configuration_preferences(organization_id)
|
||||
|
||||
|
||||
@router.put(
|
||||
"/model-configurations/preferences",
|
||||
response_model=OrganizationAIModelConfigurationPreferences,
|
||||
)
|
||||
async def save_model_configuration_preferences(
|
||||
request: OrganizationAIModelConfigurationPreferences,
|
||||
user: UserModel = Depends(get_user),
|
||||
):
|
||||
organization_id = _require_selected_organization(user)
|
||||
return await upsert_organization_ai_model_configuration_preferences(
|
||||
organization_id,
|
||||
request,
|
||||
)
|
||||
|
||||
|
||||
def preserve_masked_fields(provider: str, request_dict: dict, existing: dict):
|
||||
"""If the client re-submitted a masked sensitive field, restore the original."""
|
||||
for field_name in _sensitive_fields(provider):
|
||||
|
|
|
|||
|
|
@ -82,7 +82,15 @@ async def initiate_call(
|
|||
"""Initiate a call using the configured telephony provider from web browser. This is
|
||||
supposed to be a test call method for the draft version of the agent."""
|
||||
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
get_organization_ai_model_configuration_preferences,
|
||||
)
|
||||
|
||||
user_configuration = await db_client.get_user_configurations(user.id)
|
||||
preferences = await get_organization_ai_model_configuration_preferences(
|
||||
user.selected_organization_id,
|
||||
db=db_client,
|
||||
)
|
||||
|
||||
# Resolve which telephony config to use: explicit request value, otherwise
|
||||
# the org's default outbound config.
|
||||
|
|
@ -116,7 +124,11 @@ async def initiate_call(
|
|||
detail="telephony_not_configured",
|
||||
)
|
||||
|
||||
phone_number = request.phone_number or user_configuration.test_phone_number
|
||||
phone_number = (
|
||||
request.phone_number
|
||||
or preferences.test_phone_number
|
||||
or user_configuration.test_phone_number
|
||||
)
|
||||
|
||||
if not phone_number:
|
||||
raise HTTPException(
|
||||
|
|
|
|||
|
|
@ -10,6 +10,11 @@ from api.db.models import (
|
|||
UserModel,
|
||||
)
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
get_organization_ai_model_configuration_preferences,
|
||||
get_resolved_ai_model_configuration,
|
||||
upsert_organization_ai_model_configuration_preferences,
|
||||
)
|
||||
from api.services.configuration.check_validity import (
|
||||
APIKeyStatusResponse,
|
||||
UserConfigurationValidator,
|
||||
|
|
@ -91,8 +96,18 @@ class UserConfigurationRequestResponseSchema(BaseModel):
|
|||
async def get_user_configurations(
|
||||
user: UserModel = Depends(get_user),
|
||||
) -> UserConfigurationRequestResponseSchema:
|
||||
user_configurations = await db_client.get_user_configurations(user.id)
|
||||
masked_config = mask_user_config(user_configurations)
|
||||
resolved_config = await get_resolved_ai_model_configuration(
|
||||
user_id=user.id,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
masked_config = mask_user_config(resolved_config.effective)
|
||||
if resolved_config.preferences:
|
||||
if resolved_config.preferences.test_phone_number is not None:
|
||||
masked_config["test_phone_number"] = (
|
||||
resolved_config.preferences.test_phone_number
|
||||
)
|
||||
if resolved_config.preferences.timezone is not None:
|
||||
masked_config["timezone"] = resolved_config.preferences.timezone
|
||||
|
||||
# Add organization pricing info if available
|
||||
if user.selected_organization_id:
|
||||
|
|
@ -144,8 +159,31 @@ async def update_user_configurations(
|
|||
user.id, user_configurations
|
||||
)
|
||||
|
||||
if user.selected_organization_id and (
|
||||
request.test_phone_number is not None or request.timezone is not None
|
||||
):
|
||||
preferences = await get_organization_ai_model_configuration_preferences(
|
||||
user.selected_organization_id
|
||||
)
|
||||
if request.test_phone_number is not None:
|
||||
preferences.test_phone_number = request.test_phone_number
|
||||
if request.timezone is not None:
|
||||
preferences.timezone = request.timezone
|
||||
await upsert_organization_ai_model_configuration_preferences(
|
||||
user.selected_organization_id,
|
||||
preferences,
|
||||
)
|
||||
|
||||
# Return masked version of updated config
|
||||
masked_config = mask_user_config(user_configurations)
|
||||
if user.selected_organization_id:
|
||||
preferences = await get_organization_ai_model_configuration_preferences(
|
||||
user.selected_organization_id
|
||||
)
|
||||
if preferences.test_phone_number is not None:
|
||||
masked_config["test_phone_number"] = preferences.test_phone_number
|
||||
if preferences.timezone is not None:
|
||||
masked_config["timezone"] = preferences.timezone
|
||||
|
||||
# Add organization pricing info if available
|
||||
if user.selected_organization_id:
|
||||
|
|
@ -165,7 +203,11 @@ async def validate_user_configurations(
|
|||
validity_ttl_seconds: int = Query(default=60, ge=0, le=86400),
|
||||
user: UserModel = Depends(get_user),
|
||||
) -> APIKeyStatusResponse:
|
||||
configurations = await db_client.get_user_configurations(user.id)
|
||||
resolved_config = await get_resolved_ai_model_configuration(
|
||||
user_id=user.id,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
configurations = resolved_config.effective
|
||||
|
||||
if (
|
||||
configurations.last_validated_at
|
||||
|
|
|
|||
|
|
@ -16,9 +16,18 @@ from api.db.agent_trigger_client import TriggerPathConflictError
|
|||
from api.db.models import UserModel
|
||||
from api.db.workflow_template_client import WorkflowTemplateClient
|
||||
from api.enums import CallType, PostHogEvent, StorageBackend
|
||||
from api.schemas.ai_model_configuration import OrganizationAIModelConfigurationV2
|
||||
from api.schemas.workflow import WorkflowRunResponseSchema
|
||||
from api.sdk_expose import sdk_expose
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY,
|
||||
check_for_masked_keys_in_ai_model_configuration_v2,
|
||||
compile_ai_model_configuration_v2,
|
||||
convert_legacy_ai_model_configuration_to_v2,
|
||||
get_resolved_ai_model_configuration,
|
||||
merge_ai_model_configuration_v2_secrets,
|
||||
)
|
||||
from api.services.configuration.check_validity import UserConfigurationValidator
|
||||
from api.services.configuration.masking import (
|
||||
mask_workflow_configurations,
|
||||
|
|
@ -955,12 +964,74 @@ async def update_workflow(
|
|||
existing_def,
|
||||
)
|
||||
|
||||
# Validate model_overrides: resolve onto global config, then
|
||||
# run the same validator used by the user-configurations endpoint.
|
||||
# Also stamp the current global API key into the override so the override
|
||||
# remains functional if the global config later switches to a different provider.
|
||||
# Validate model overrides. v2 uses a complete workflow-level model
|
||||
# configuration; legacy v1 uses partial service overlays.
|
||||
workflow_configurations = request.workflow_configurations
|
||||
if workflow_configurations and workflow_configurations.get("model_overrides"):
|
||||
if workflow_configurations and workflow_configurations.get(
|
||||
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY
|
||||
):
|
||||
existing_workflow = await db_client.get_workflow(
|
||||
workflow_id, organization_id=user.selected_organization_id
|
||||
)
|
||||
if existing_workflow is None:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Workflow with id {workflow_id} not found"
|
||||
)
|
||||
existing_draft = await db_client.get_draft_version(workflow_id)
|
||||
existing_configs = (
|
||||
existing_draft.workflow_configurations
|
||||
if existing_draft
|
||||
else existing_workflow.released_definition.workflow_configurations
|
||||
)
|
||||
existing_v2_override = (existing_configs or {}).get(
|
||||
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY
|
||||
)
|
||||
try:
|
||||
incoming_v2_override = (
|
||||
OrganizationAIModelConfigurationV2.model_validate(
|
||||
workflow_configurations[
|
||||
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY
|
||||
]
|
||||
)
|
||||
)
|
||||
existing_v2_override_config = (
|
||||
OrganizationAIModelConfigurationV2.model_validate(
|
||||
existing_v2_override
|
||||
)
|
||||
if existing_v2_override
|
||||
else None
|
||||
)
|
||||
v2_override = merge_ai_model_configuration_v2_secrets(
|
||||
incoming_v2_override,
|
||||
existing_v2_override_config,
|
||||
)
|
||||
if existing_v2_override_config is None:
|
||||
resolved_config = await get_resolved_ai_model_configuration(
|
||||
user_id=user.id,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
v2_override = merge_ai_model_configuration_v2_secrets(
|
||||
v2_override,
|
||||
resolved_config.organization_configuration,
|
||||
)
|
||||
check_for_masked_keys_in_ai_model_configuration_v2(v2_override)
|
||||
effective = compile_ai_model_configuration_v2(v2_override)
|
||||
await UserConfigurationValidator().validate(
|
||||
effective,
|
||||
organization_id=user.selected_organization_id,
|
||||
created_by=user.provider_id,
|
||||
)
|
||||
except (ValidationError, ValueError) as e:
|
||||
raise HTTPException(status_code=422, detail=str(e))
|
||||
workflow_configurations = {
|
||||
**workflow_configurations,
|
||||
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY: v2_override.model_dump(
|
||||
mode="json",
|
||||
exclude_none=True,
|
||||
),
|
||||
}
|
||||
workflow_configurations.pop("model_overrides", None)
|
||||
elif workflow_configurations and workflow_configurations.get("model_overrides"):
|
||||
existing_workflow = await db_client.get_workflow(
|
||||
workflow_id, organization_id=user.selected_organization_id
|
||||
)
|
||||
|
|
@ -978,24 +1049,46 @@ async def update_workflow(
|
|||
workflow_configurations,
|
||||
existing_configs,
|
||||
)
|
||||
user_config = await db_client.get_user_configurations(user.id)
|
||||
resolved_config = await get_resolved_ai_model_configuration(
|
||||
user_id=user.id,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
user_config = resolved_config.effective
|
||||
try:
|
||||
enriched_overrides = enrich_overrides_with_api_keys(
|
||||
workflow_configurations["model_overrides"],
|
||||
user_config,
|
||||
)
|
||||
effective = resolve_effective_config(user_config, enriched_overrides)
|
||||
await UserConfigurationValidator().validate(
|
||||
effective,
|
||||
organization_id=user.selected_organization_id,
|
||||
created_by=user.provider_id,
|
||||
)
|
||||
if resolved_config.source == "organization_v2":
|
||||
v2_override = convert_legacy_ai_model_configuration_to_v2(effective)
|
||||
await UserConfigurationValidator().validate(
|
||||
compile_ai_model_configuration_v2(v2_override),
|
||||
organization_id=user.selected_organization_id,
|
||||
created_by=user.provider_id,
|
||||
)
|
||||
else:
|
||||
await UserConfigurationValidator().validate(
|
||||
effective,
|
||||
organization_id=user.selected_organization_id,
|
||||
created_by=user.provider_id,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=422, detail=str(e))
|
||||
workflow_configurations = {
|
||||
**workflow_configurations,
|
||||
"model_overrides": enriched_overrides,
|
||||
}
|
||||
if resolved_config.source == "organization_v2":
|
||||
workflow_configurations = {
|
||||
**workflow_configurations,
|
||||
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY: v2_override.model_dump(
|
||||
mode="json",
|
||||
exclude_none=True,
|
||||
),
|
||||
}
|
||||
workflow_configurations.pop("model_overrides", None)
|
||||
else:
|
||||
workflow_configurations = {
|
||||
**workflow_configurations,
|
||||
"model_overrides": enriched_overrides,
|
||||
}
|
||||
|
||||
# Reject upfront if any new trigger path collides with another
|
||||
# workflow's trigger — keeps the workflow record from
|
||||
|
|
|
|||
176
api/schemas/ai_model_configuration.py
Normal file
176
api/schemas/ai_model_configuration.py
Normal file
|
|
@ -0,0 +1,176 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from api.schemas.user_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.configuration.registry import (
|
||||
DograhEmbeddingsConfiguration,
|
||||
DograhLLMService,
|
||||
DograhSTTService,
|
||||
DograhTTSService,
|
||||
EmbeddingsConfig,
|
||||
LLMConfig,
|
||||
RealtimeConfig,
|
||||
ServiceProviders,
|
||||
STTConfig,
|
||||
TTSConfig,
|
||||
)
|
||||
|
||||
DOGRAH_SPEED_OPTIONS: tuple[float, ...] = (0.8, 1.0, 1.2)
|
||||
DOGRAH_DEFAULT_VOICE = "default"
|
||||
DOGRAH_DEFAULT_LANGUAGE = "multi"
|
||||
|
||||
|
||||
class DograhManagedAIModelConfiguration(BaseModel):
|
||||
api_key: str
|
||||
voice: str = DOGRAH_DEFAULT_VOICE
|
||||
speed: float = Field(default=1.0)
|
||||
language: str = DOGRAH_DEFAULT_LANGUAGE
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_speed(self):
|
||||
if self.speed not in DOGRAH_SPEED_OPTIONS:
|
||||
allowed = ", ".join(str(speed) for speed in DOGRAH_SPEED_OPTIONS)
|
||||
raise ValueError(f"Dograh speed must be one of: {allowed}")
|
||||
return self
|
||||
|
||||
|
||||
class BYOKPipelineAIModelConfiguration(BaseModel):
|
||||
llm: LLMConfig
|
||||
tts: TTSConfig
|
||||
stt: STTConfig
|
||||
embeddings: EmbeddingsConfig | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def reject_dograh_providers(self):
|
||||
_reject_dograh_provider("llm", self.llm)
|
||||
_reject_dograh_provider("tts", self.tts)
|
||||
_reject_dograh_provider("stt", self.stt)
|
||||
_reject_dograh_provider("embeddings", self.embeddings)
|
||||
return self
|
||||
|
||||
|
||||
class BYOKRealtimeAIModelConfiguration(BaseModel):
|
||||
realtime: RealtimeConfig
|
||||
llm: LLMConfig
|
||||
embeddings: EmbeddingsConfig | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def reject_dograh_providers(self):
|
||||
_reject_dograh_provider("llm", self.llm)
|
||||
_reject_dograh_provider("embeddings", self.embeddings)
|
||||
return self
|
||||
|
||||
|
||||
class BYOKAIModelConfiguration(BaseModel):
|
||||
mode: Literal["pipeline", "realtime"]
|
||||
pipeline: BYOKPipelineAIModelConfiguration | None = None
|
||||
realtime: BYOKRealtimeAIModelConfiguration | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_selected_mode(self):
|
||||
if self.mode == "pipeline" and self.pipeline is None:
|
||||
raise ValueError("byok.pipeline is required when byok.mode is pipeline")
|
||||
if self.mode == "realtime" and self.realtime is None:
|
||||
raise ValueError("byok.realtime is required when byok.mode is realtime")
|
||||
return self
|
||||
|
||||
|
||||
class OrganizationAIModelConfigurationV2(BaseModel):
|
||||
version: Literal[2] = 2
|
||||
mode: Literal["dograh", "byok"]
|
||||
dograh: DograhManagedAIModelConfiguration | None = None
|
||||
byok: BYOKAIModelConfiguration | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_selected_mode(self):
|
||||
if self.mode == "dograh" and self.dograh is None:
|
||||
raise ValueError("dograh configuration is required when mode is dograh")
|
||||
if self.mode == "byok" and self.byok is None:
|
||||
raise ValueError("byok configuration is required when mode is byok")
|
||||
return self
|
||||
|
||||
|
||||
class OrganizationAIModelConfigurationPreferences(BaseModel):
|
||||
test_phone_number: str | None = None
|
||||
timezone: str | None = None
|
||||
|
||||
|
||||
class OrganizationAIModelConfigurationResponse(BaseModel):
|
||||
configuration: dict | None
|
||||
effective_configuration: dict
|
||||
preferences: OrganizationAIModelConfigurationPreferences
|
||||
source: Literal["organization_v2", "legacy_user_v1", "empty"]
|
||||
|
||||
|
||||
def compile_ai_model_configuration_v2(
|
||||
configuration: OrganizationAIModelConfigurationV2,
|
||||
) -> EffectiveAIModelConfiguration:
|
||||
if configuration.mode == "dograh":
|
||||
if configuration.dograh is None:
|
||||
raise ValueError("dograh configuration is required")
|
||||
return _compile_dograh_configuration(configuration.dograh)
|
||||
|
||||
if configuration.byok is None:
|
||||
raise ValueError("byok configuration is required")
|
||||
if configuration.byok.mode == "pipeline":
|
||||
if configuration.byok.pipeline is None:
|
||||
raise ValueError("byok.pipeline is required")
|
||||
pipeline = configuration.byok.pipeline
|
||||
return EffectiveAIModelConfiguration(
|
||||
llm=pipeline.llm,
|
||||
tts=pipeline.tts,
|
||||
stt=pipeline.stt,
|
||||
embeddings=pipeline.embeddings,
|
||||
is_realtime=False,
|
||||
)
|
||||
|
||||
if configuration.byok.realtime is None:
|
||||
raise ValueError("byok.realtime is required")
|
||||
realtime = configuration.byok.realtime
|
||||
return EffectiveAIModelConfiguration(
|
||||
llm=realtime.llm,
|
||||
realtime=realtime.realtime,
|
||||
embeddings=realtime.embeddings,
|
||||
is_realtime=True,
|
||||
)
|
||||
|
||||
|
||||
def _compile_dograh_configuration(
|
||||
configuration: DograhManagedAIModelConfiguration,
|
||||
) -> EffectiveAIModelConfiguration:
|
||||
return EffectiveAIModelConfiguration(
|
||||
llm=DograhLLMService(
|
||||
provider=ServiceProviders.DOGRAH,
|
||||
api_key=configuration.api_key,
|
||||
model="default",
|
||||
),
|
||||
tts=DograhTTSService(
|
||||
provider=ServiceProviders.DOGRAH,
|
||||
api_key=configuration.api_key,
|
||||
model="default",
|
||||
voice=configuration.voice,
|
||||
speed=configuration.speed,
|
||||
),
|
||||
stt=DograhSTTService(
|
||||
provider=ServiceProviders.DOGRAH,
|
||||
api_key=configuration.api_key,
|
||||
model="default",
|
||||
language=configuration.language,
|
||||
),
|
||||
embeddings=DograhEmbeddingsConfiguration(
|
||||
provider=ServiceProviders.DOGRAH,
|
||||
api_key=configuration.api_key,
|
||||
model="default",
|
||||
),
|
||||
is_realtime=False,
|
||||
)
|
||||
|
||||
|
||||
def _reject_dograh_provider(section: str, service) -> None:
|
||||
if service is None:
|
||||
return
|
||||
if getattr(service, "provider", None) == ServiceProviders.DOGRAH:
|
||||
raise ValueError(f"BYOK {section} cannot use Dograh provider")
|
||||
|
|
@ -11,7 +11,7 @@ from api.services.configuration.registry import (
|
|||
)
|
||||
|
||||
|
||||
class UserConfiguration(BaseModel):
|
||||
class EffectiveAIModelConfiguration(BaseModel):
|
||||
llm: LLMConfig | None = None
|
||||
stt: STTConfig | None = None
|
||||
tts: TTSConfig | None = None
|
||||
|
|
@ -31,3 +31,7 @@ class UserConfiguration(BaseModel):
|
|||
if isinstance(realtime, dict) and not realtime.get("api_key"):
|
||||
data.pop("realtime", None)
|
||||
return data
|
||||
|
||||
|
||||
# Backward-compatible alias for legacy persistence and existing call sites.
|
||||
UserConfiguration = EffectiveAIModelConfiguration
|
||||
|
|
|
|||
|
|
@ -119,6 +119,19 @@ async def get_user(
|
|||
await db_client.update_user_configuration(
|
||||
user_model.id, mps_config
|
||||
)
|
||||
from api.enums import OrganizationConfigurationKey
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
convert_legacy_ai_model_configuration_to_v2,
|
||||
)
|
||||
|
||||
model_config_v2 = convert_legacy_ai_model_configuration_to_v2(
|
||||
mps_config
|
||||
)
|
||||
await db_client.upsert_configuration(
|
||||
organization.id,
|
||||
OrganizationConfigurationKey.MODEL_CONFIGURATION_V2.value,
|
||||
model_config_v2.model_dump(mode="json", exclude_none=True),
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
raise HTTPException(
|
||||
|
|
|
|||
532
api/services/configuration/ai_model_configuration.py
Normal file
532
api/services/configuration/ai_model_configuration.py
Normal file
|
|
@ -0,0 +1,532 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
from inspect import isawaitable
|
||||
from typing import Literal
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from api.constants import MPS_API_URL
|
||||
from api.db import db_client
|
||||
from api.db.models import WorkflowDefinitionModel, WorkflowModel
|
||||
from api.enums import OrganizationConfigurationKey
|
||||
from api.schemas.ai_model_configuration import (
|
||||
DOGRAH_DEFAULT_LANGUAGE,
|
||||
DOGRAH_DEFAULT_VOICE,
|
||||
DOGRAH_SPEED_OPTIONS,
|
||||
BYOKAIModelConfiguration,
|
||||
BYOKPipelineAIModelConfiguration,
|
||||
BYOKRealtimeAIModelConfiguration,
|
||||
DograhManagedAIModelConfiguration,
|
||||
OrganizationAIModelConfigurationPreferences,
|
||||
OrganizationAIModelConfigurationV2,
|
||||
compile_ai_model_configuration_v2,
|
||||
)
|
||||
from api.schemas.user_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.configuration.masking import (
|
||||
SERVICE_SECRET_FIELDS,
|
||||
contains_masked_key,
|
||||
mask_key,
|
||||
resolve_masked_api_keys,
|
||||
)
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from api.services.configuration.resolve import resolve_effective_config
|
||||
|
||||
AIModelConfigurationSource = Literal["organization_v2", "legacy_user_v1", "empty"]
|
||||
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY = "model_configuration_v2_override"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResolvedAIModelConfiguration:
|
||||
effective: EffectiveAIModelConfiguration
|
||||
source: AIModelConfigurationSource
|
||||
organization_configuration: OrganizationAIModelConfigurationV2 | None = None
|
||||
preferences: OrganizationAIModelConfigurationPreferences | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class WorkflowAIModelConfigurationMigrationResult:
|
||||
workflow_count: int = 0
|
||||
definition_count: int = 0
|
||||
workflow_ids: list[int] | None = None
|
||||
|
||||
|
||||
async def get_resolved_ai_model_configuration(
|
||||
*,
|
||||
user_id: int | None,
|
||||
organization_id: int | None,
|
||||
) -> ResolvedAIModelConfiguration:
|
||||
preferences = await get_organization_ai_model_configuration_preferences(
|
||||
organization_id
|
||||
)
|
||||
organization_configuration = await get_organization_ai_model_configuration_v2(
|
||||
organization_id
|
||||
)
|
||||
if organization_configuration is not None:
|
||||
return ResolvedAIModelConfiguration(
|
||||
effective=compile_ai_model_configuration_v2(organization_configuration),
|
||||
source="organization_v2",
|
||||
organization_configuration=organization_configuration,
|
||||
preferences=preferences,
|
||||
)
|
||||
|
||||
if user_id is None:
|
||||
return ResolvedAIModelConfiguration(
|
||||
effective=EffectiveAIModelConfiguration(),
|
||||
source="empty",
|
||||
preferences=preferences,
|
||||
)
|
||||
|
||||
legacy = await db_client.get_user_configurations(user_id)
|
||||
return ResolvedAIModelConfiguration(
|
||||
effective=legacy,
|
||||
source="legacy_user_v1" if _has_model_services(legacy) else "empty",
|
||||
preferences=preferences,
|
||||
)
|
||||
|
||||
|
||||
async def get_effective_ai_model_configuration_for_workflow(
|
||||
*,
|
||||
user_id: int | None,
|
||||
organization_id: int | None,
|
||||
workflow_configurations: dict | None,
|
||||
) -> EffectiveAIModelConfiguration:
|
||||
workflow_configurations = workflow_configurations or {}
|
||||
v2_override = workflow_configurations.get(
|
||||
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY
|
||||
)
|
||||
if v2_override:
|
||||
return compile_ai_model_configuration_v2(
|
||||
OrganizationAIModelConfigurationV2.model_validate(v2_override)
|
||||
)
|
||||
|
||||
resolved_config = await get_resolved_ai_model_configuration(
|
||||
user_id=user_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
return resolve_effective_config(
|
||||
resolved_config.effective,
|
||||
workflow_configurations.get("model_overrides"),
|
||||
)
|
||||
|
||||
|
||||
async def get_organization_ai_model_configuration_v2(
|
||||
organization_id: int | None,
|
||||
) -> OrganizationAIModelConfigurationV2 | None:
|
||||
if organization_id is None:
|
||||
return None
|
||||
row = await db_client.get_configuration(
|
||||
organization_id,
|
||||
OrganizationConfigurationKey.MODEL_CONFIGURATION_V2.value,
|
||||
)
|
||||
if row is None or not row.value:
|
||||
return None
|
||||
try:
|
||||
return OrganizationAIModelConfigurationV2.model_validate(row.value)
|
||||
except ValidationError as exc:
|
||||
logger.warning(
|
||||
"Invalid org AI model configuration v2 for organization "
|
||||
f"{organization_id}: {exc}. Falling back to legacy configuration."
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
async def get_organization_ai_model_configuration_preferences(
|
||||
organization_id: int | None,
|
||||
db=None,
|
||||
) -> OrganizationAIModelConfigurationPreferences:
|
||||
if organization_id is None:
|
||||
return OrganizationAIModelConfigurationPreferences()
|
||||
db = db or db_client
|
||||
row = db.get_configuration(
|
||||
organization_id,
|
||||
OrganizationConfigurationKey.MODEL_CONFIGURATION_PREFERENCES.value,
|
||||
)
|
||||
if isawaitable(row):
|
||||
row = await row
|
||||
if row is None or not row.value:
|
||||
return OrganizationAIModelConfigurationPreferences()
|
||||
if not isinstance(row.value, dict):
|
||||
return OrganizationAIModelConfigurationPreferences()
|
||||
try:
|
||||
return OrganizationAIModelConfigurationPreferences.model_validate(row.value)
|
||||
except ValidationError as exc:
|
||||
logger.warning(
|
||||
"Invalid org AI model configuration preferences for organization "
|
||||
f"{organization_id}: {exc}. Returning defaults."
|
||||
)
|
||||
return OrganizationAIModelConfigurationPreferences()
|
||||
|
||||
|
||||
async def upsert_organization_ai_model_configuration_v2(
|
||||
organization_id: int,
|
||||
configuration: OrganizationAIModelConfigurationV2,
|
||||
) -> OrganizationAIModelConfigurationV2:
|
||||
await db_client.upsert_configuration(
|
||||
organization_id,
|
||||
OrganizationConfigurationKey.MODEL_CONFIGURATION_V2.value,
|
||||
configuration.model_dump(mode="json", exclude_none=True),
|
||||
)
|
||||
return configuration
|
||||
|
||||
|
||||
async def upsert_organization_ai_model_configuration_preferences(
|
||||
organization_id: int,
|
||||
preferences: OrganizationAIModelConfigurationPreferences,
|
||||
) -> OrganizationAIModelConfigurationPreferences:
|
||||
await db_client.upsert_configuration(
|
||||
organization_id,
|
||||
OrganizationConfigurationKey.MODEL_CONFIGURATION_PREFERENCES.value,
|
||||
preferences.model_dump(mode="json", exclude_none=True),
|
||||
)
|
||||
return preferences
|
||||
|
||||
|
||||
async def migrate_workflow_model_configurations_to_v2(
|
||||
*,
|
||||
organization_id: int,
|
||||
fallback_user_config: EffectiveAIModelConfiguration,
|
||||
) -> WorkflowAIModelConfigurationMigrationResult:
|
||||
workflows = await _list_workflows_for_model_configuration_migration(organization_id)
|
||||
owner_configs: dict[int, EffectiveAIModelConfiguration] = {}
|
||||
workflow_updates: list[tuple[int, dict]] = []
|
||||
definition_updates: list[tuple[int, dict]] = []
|
||||
migrated_workflow_ids: set[int] = set()
|
||||
|
||||
for workflow in workflows:
|
||||
base_config = fallback_user_config
|
||||
if workflow.user_id is not None:
|
||||
if workflow.user_id not in owner_configs:
|
||||
owner_configs[
|
||||
workflow.user_id
|
||||
] = await db_client.get_user_configurations(workflow.user_id)
|
||||
base_config = owner_configs[workflow.user_id]
|
||||
|
||||
workflow_configs, workflow_changed = (
|
||||
migrate_workflow_configuration_model_override_to_v2(
|
||||
workflow.workflow_configurations,
|
||||
base_config,
|
||||
)
|
||||
)
|
||||
if workflow_changed:
|
||||
workflow_updates.append((workflow.id, workflow_configs))
|
||||
migrated_workflow_ids.add(workflow.id)
|
||||
|
||||
for definition in workflow.definitions:
|
||||
definition_configs, definition_changed = (
|
||||
migrate_workflow_configuration_model_override_to_v2(
|
||||
definition.workflow_configurations,
|
||||
base_config,
|
||||
)
|
||||
)
|
||||
if definition_changed:
|
||||
definition_updates.append((definition.id, definition_configs))
|
||||
migrated_workflow_ids.add(workflow.id)
|
||||
|
||||
if workflow_updates or definition_updates:
|
||||
async with db_client.async_session() as session:
|
||||
for workflow_id, workflow_configs in workflow_updates:
|
||||
await session.execute(
|
||||
update(WorkflowModel)
|
||||
.where(WorkflowModel.id == workflow_id)
|
||||
.values(workflow_configurations=workflow_configs)
|
||||
)
|
||||
for definition_id, definition_configs in definition_updates:
|
||||
await session.execute(
|
||||
update(WorkflowDefinitionModel)
|
||||
.where(WorkflowDefinitionModel.id == definition_id)
|
||||
.values(workflow_configurations=definition_configs)
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
return WorkflowAIModelConfigurationMigrationResult(
|
||||
workflow_count=len(migrated_workflow_ids),
|
||||
definition_count=len(definition_updates),
|
||||
workflow_ids=sorted(migrated_workflow_ids),
|
||||
)
|
||||
|
||||
|
||||
def migrate_workflow_configuration_model_override_to_v2(
|
||||
workflow_configurations: dict | None,
|
||||
base_config: EffectiveAIModelConfiguration,
|
||||
) -> tuple[dict, bool]:
|
||||
if not isinstance(workflow_configurations, dict):
|
||||
return {}, False
|
||||
|
||||
migrated = copy.deepcopy(workflow_configurations)
|
||||
model_overrides = migrated.get("model_overrides")
|
||||
existing_v2_override = migrated.get(WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY)
|
||||
if not isinstance(model_overrides, dict):
|
||||
if "model_overrides" in migrated:
|
||||
migrated.pop("model_overrides", None)
|
||||
return migrated, True
|
||||
return migrated, False
|
||||
|
||||
if not existing_v2_override:
|
||||
effective = resolve_effective_config(base_config, model_overrides)
|
||||
v2_override = convert_legacy_ai_model_configuration_to_v2(effective)
|
||||
migrated[WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY] = v2_override.model_dump(
|
||||
mode="json", exclude_none=True
|
||||
)
|
||||
migrated.pop("model_overrides", None)
|
||||
return migrated, True
|
||||
|
||||
|
||||
def merge_ai_model_configuration_v2_secrets(
|
||||
incoming: OrganizationAIModelConfigurationV2,
|
||||
existing: OrganizationAIModelConfigurationV2 | None,
|
||||
) -> OrganizationAIModelConfigurationV2:
|
||||
if existing is None:
|
||||
return incoming
|
||||
|
||||
incoming_dict = incoming.model_dump(mode="json", exclude_none=True)
|
||||
existing_dict = existing.model_dump(mode="json", exclude_none=True)
|
||||
|
||||
if incoming_dict.get("mode") == "dograh" and existing_dict.get("mode") == "dograh":
|
||||
incoming_dograh = incoming_dict.get("dograh") or {}
|
||||
existing_dograh = existing_dict.get("dograh") or {}
|
||||
incoming_key = incoming_dograh.get("api_key")
|
||||
existing_key = existing_dograh.get("api_key")
|
||||
if incoming_key and existing_key and contains_masked_key(incoming_key):
|
||||
incoming_dograh["api_key"] = resolve_masked_api_keys(
|
||||
incoming_key,
|
||||
existing_key,
|
||||
)
|
||||
|
||||
if incoming_dict.get("mode") == "byok" and existing_dict.get("mode") == "byok":
|
||||
_merge_byok_secret_fields(incoming_dict.get("byok"), existing_dict.get("byok"))
|
||||
|
||||
return OrganizationAIModelConfigurationV2.model_validate(incoming_dict)
|
||||
|
||||
|
||||
def check_for_masked_keys_in_ai_model_configuration_v2(
|
||||
configuration: OrganizationAIModelConfigurationV2,
|
||||
) -> None:
|
||||
data = configuration.model_dump(mode="json", exclude_none=True)
|
||||
_raise_if_masked_secret(data)
|
||||
|
||||
|
||||
def mask_ai_model_configuration_v2(
|
||||
configuration: OrganizationAIModelConfigurationV2 | None,
|
||||
) -> dict | None:
|
||||
if configuration is None:
|
||||
return None
|
||||
data = configuration.model_dump(mode="json", exclude_none=True)
|
||||
_mask_secret_fields(data)
|
||||
return data
|
||||
|
||||
|
||||
def convert_legacy_ai_model_configuration_to_v2(
|
||||
configuration: EffectiveAIModelConfiguration,
|
||||
) -> OrganizationAIModelConfigurationV2:
|
||||
dograh_key = _first_dograh_api_key(configuration)
|
||||
if dograh_key:
|
||||
return _convert_any_dograh_legacy_configuration(configuration, dograh_key)
|
||||
|
||||
if configuration.is_realtime:
|
||||
if configuration.realtime is None or configuration.llm is None:
|
||||
raise ValueError("Realtime legacy configuration is incomplete")
|
||||
return OrganizationAIModelConfigurationV2(
|
||||
mode="byok",
|
||||
byok=BYOKAIModelConfiguration(
|
||||
mode="realtime",
|
||||
realtime=BYOKRealtimeAIModelConfiguration(
|
||||
realtime=configuration.realtime,
|
||||
llm=configuration.llm,
|
||||
embeddings=configuration.embeddings,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
if (
|
||||
configuration.llm is None
|
||||
or configuration.tts is None
|
||||
or configuration.stt is None
|
||||
):
|
||||
raise ValueError("Pipeline legacy configuration is incomplete")
|
||||
return OrganizationAIModelConfigurationV2(
|
||||
mode="byok",
|
||||
byok=BYOKAIModelConfiguration(
|
||||
mode="pipeline",
|
||||
pipeline=BYOKPipelineAIModelConfiguration(
|
||||
llm=configuration.llm,
|
||||
tts=configuration.tts,
|
||||
stt=configuration.stt,
|
||||
embeddings=configuration.embeddings,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def dograh_embeddings_base_url() -> str:
|
||||
return f"{MPS_API_URL}/api/v1/llm"
|
||||
|
||||
|
||||
def apply_managed_embeddings_base_url(
|
||||
*,
|
||||
provider: str | None,
|
||||
base_url: str | None,
|
||||
) -> str | None:
|
||||
if provider == ServiceProviders.DOGRAH.value or provider == ServiceProviders.DOGRAH:
|
||||
return dograh_embeddings_base_url()
|
||||
return base_url
|
||||
|
||||
|
||||
def _merge_byok_secret_fields(incoming_byok: dict | None, existing_byok: dict | None):
|
||||
if not isinstance(incoming_byok, dict) or not isinstance(existing_byok, dict):
|
||||
return
|
||||
incoming_mode = incoming_byok.get("mode")
|
||||
existing_mode = existing_byok.get("mode")
|
||||
if incoming_mode != existing_mode:
|
||||
return
|
||||
section_names = (
|
||||
("llm", "tts", "stt", "embeddings")
|
||||
if incoming_mode == "pipeline"
|
||||
else ("realtime", "llm", "embeddings")
|
||||
)
|
||||
incoming_container = incoming_byok.get(incoming_mode)
|
||||
existing_container = existing_byok.get(existing_mode)
|
||||
if not isinstance(incoming_container, dict) or not isinstance(
|
||||
existing_container, dict
|
||||
):
|
||||
return
|
||||
for section_name in section_names:
|
||||
incoming_section = incoming_container.get(section_name)
|
||||
existing_section = existing_container.get(section_name)
|
||||
if isinstance(incoming_section, dict) and isinstance(existing_section, dict):
|
||||
_merge_service_secret_fields(incoming_section, existing_section)
|
||||
|
||||
|
||||
async def _list_workflows_for_model_configuration_migration(
|
||||
organization_id: int,
|
||||
) -> list[WorkflowModel]:
|
||||
async with db_client.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(WorkflowModel)
|
||||
.options(selectinload(WorkflowModel.definitions))
|
||||
.where(WorkflowModel.organization_id == organization_id)
|
||||
)
|
||||
return list(result.scalars().unique().all())
|
||||
|
||||
|
||||
def _merge_service_secret_fields(incoming: dict, existing: dict):
|
||||
if (
|
||||
incoming.get("provider") is not None
|
||||
and existing.get("provider") is not None
|
||||
and incoming.get("provider") != existing.get("provider")
|
||||
):
|
||||
return
|
||||
for secret_field in SERVICE_SECRET_FIELDS:
|
||||
if secret_field not in existing:
|
||||
continue
|
||||
incoming_secret = incoming.get(secret_field)
|
||||
existing_secret = existing[secret_field]
|
||||
if incoming_secret is None:
|
||||
incoming[secret_field] = existing_secret
|
||||
elif contains_masked_key(incoming_secret):
|
||||
incoming[secret_field] = resolve_masked_api_keys(
|
||||
incoming_secret,
|
||||
existing_secret,
|
||||
)
|
||||
|
||||
|
||||
def _raise_if_masked_secret(value):
|
||||
if isinstance(value, dict):
|
||||
for key, nested in value.items():
|
||||
if key in SERVICE_SECRET_FIELDS and contains_masked_key(nested):
|
||||
raise ValueError(
|
||||
f"The {key} appears to be masked. Please provide the actual "
|
||||
"value, not the masked value."
|
||||
)
|
||||
_raise_if_masked_secret(nested)
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
_raise_if_masked_secret(item)
|
||||
|
||||
|
||||
def _mask_secret_fields(value):
|
||||
if isinstance(value, dict):
|
||||
for key, nested in list(value.items()):
|
||||
if key in SERVICE_SECRET_FIELDS and nested:
|
||||
value[key] = _mask_secret_value(nested)
|
||||
else:
|
||||
_mask_secret_fields(nested)
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
_mask_secret_fields(item)
|
||||
|
||||
|
||||
def _mask_secret_value(value):
|
||||
if isinstance(value, list):
|
||||
return [mask_key(item) for item in value]
|
||||
return mask_key(value)
|
||||
|
||||
|
||||
def _has_model_services(configuration: EffectiveAIModelConfiguration) -> bool:
|
||||
return any(
|
||||
service is not None
|
||||
for service in (
|
||||
configuration.llm,
|
||||
configuration.tts,
|
||||
configuration.stt,
|
||||
configuration.embeddings,
|
||||
configuration.realtime,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _convert_any_dograh_legacy_configuration(
|
||||
configuration: EffectiveAIModelConfiguration,
|
||||
dograh_key: str,
|
||||
) -> OrganizationAIModelConfigurationV2:
|
||||
speed = getattr(configuration.tts, "speed", 1.0)
|
||||
if speed not in DOGRAH_SPEED_OPTIONS:
|
||||
speed = 1.0
|
||||
return OrganizationAIModelConfigurationV2(
|
||||
mode="dograh",
|
||||
dograh=DograhManagedAIModelConfiguration(
|
||||
api_key=dograh_key,
|
||||
voice=getattr(configuration.tts, "voice", DOGRAH_DEFAULT_VOICE)
|
||||
or DOGRAH_DEFAULT_VOICE,
|
||||
speed=speed,
|
||||
language=getattr(configuration.stt, "language", DOGRAH_DEFAULT_LANGUAGE)
|
||||
or DOGRAH_DEFAULT_LANGUAGE,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _first_dograh_api_key(configuration: EffectiveAIModelConfiguration) -> str | None:
|
||||
for service in (
|
||||
configuration.llm,
|
||||
configuration.tts,
|
||||
configuration.stt,
|
||||
configuration.embeddings,
|
||||
configuration.realtime,
|
||||
):
|
||||
if service is None or _provider(service) != ServiceProviders.DOGRAH:
|
||||
continue
|
||||
try:
|
||||
return _single_api_key(service)
|
||||
except ValueError:
|
||||
continue
|
||||
return None
|
||||
|
||||
|
||||
def _provider(service):
|
||||
return getattr(service, "provider", None)
|
||||
|
||||
|
||||
def _single_api_key(service) -> str:
|
||||
if hasattr(service, "get_all_api_keys"):
|
||||
keys = service.get_all_api_keys()
|
||||
if len(keys) != 1:
|
||||
raise ValueError("Expected exactly one API key")
|
||||
return keys[0]
|
||||
key = getattr(service, "api_key", None)
|
||||
if not key:
|
||||
raise ValueError("Expected an API key")
|
||||
return key
|
||||
|
|
@ -151,21 +151,35 @@ def mask_workflow_configurations(config: Optional[Dict]) -> Optional[Dict]:
|
|||
|
||||
masked = copy.deepcopy(config)
|
||||
model_overrides = masked.get("model_overrides")
|
||||
if not isinstance(model_overrides, dict):
|
||||
return masked
|
||||
if isinstance(model_overrides, dict):
|
||||
for section in MODEL_OVERRIDE_FIELDS:
|
||||
override = model_overrides.get(section)
|
||||
if not isinstance(override, dict):
|
||||
continue
|
||||
for secret_field in SERVICE_SECRET_FIELDS:
|
||||
raw = override.get(secret_field)
|
||||
if raw:
|
||||
override[secret_field] = _mask_secret_value(raw)
|
||||
|
||||
for section in MODEL_OVERRIDE_FIELDS:
|
||||
override = model_overrides.get(section)
|
||||
if not isinstance(override, dict):
|
||||
continue
|
||||
for secret_field in SERVICE_SECRET_FIELDS:
|
||||
raw = override.get(secret_field)
|
||||
if raw:
|
||||
override[secret_field] = _mask_secret_value(raw)
|
||||
v2_override = masked.get("model_configuration_v2_override")
|
||||
if isinstance(v2_override, dict):
|
||||
_mask_nested_service_secrets(v2_override)
|
||||
|
||||
return masked
|
||||
|
||||
|
||||
def _mask_nested_service_secrets(value):
|
||||
if isinstance(value, dict):
|
||||
for key, nested in list(value.items()):
|
||||
if key in SERVICE_SECRET_FIELDS and nested:
|
||||
value[key] = _mask_secret_value(nested)
|
||||
else:
|
||||
_mask_nested_service_secrets(nested)
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
_mask_nested_service_secrets(item)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Workflow definition helpers – mask / merge node API keys
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -1472,11 +1472,26 @@ class AzureOpenAIEmbeddingsConfiguration(BaseEmbeddingsConfiguration):
|
|||
)
|
||||
|
||||
|
||||
DOGRAH_EMBEDDING_MODELS = ["default"]
|
||||
|
||||
|
||||
@register_embeddings
|
||||
class DograhEmbeddingsConfiguration(BaseEmbeddingsConfiguration):
|
||||
model_config = DOGRAH_PROVIDER_MODEL_CONFIG
|
||||
provider: Literal[ServiceProviders.DOGRAH] = ServiceProviders.DOGRAH
|
||||
model: str = Field(
|
||||
default="default",
|
||||
description="Dograh-managed embedding model.",
|
||||
json_schema_extra={"examples": DOGRAH_EMBEDDING_MODELS},
|
||||
)
|
||||
|
||||
|
||||
EmbeddingsConfig = Annotated[
|
||||
Union[
|
||||
OpenAIEmbeddingsConfiguration,
|
||||
OpenRouterEmbeddingsConfiguration,
|
||||
AzureOpenAIEmbeddingsConfiguration,
|
||||
DograhEmbeddingsConfiguration,
|
||||
],
|
||||
Field(discriminator="provider"),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -195,14 +195,17 @@ async def run_pipeline_telephony(
|
|||
# Resolve effective user config here so the transport can tune its
|
||||
# bot-stopped-speaking fallback based on is_realtime; pass the resolved
|
||||
# values into _run_pipeline so it doesn't fetch them again.
|
||||
from api.services.configuration.resolve import resolve_effective_config
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
get_effective_ai_model_configuration_for_workflow,
|
||||
)
|
||||
|
||||
user_config = await db_client.get_user_configurations(user_id)
|
||||
run_configs = (
|
||||
(workflow_run.definition.workflow_configurations or {}) if workflow_run else {}
|
||||
)
|
||||
user_config = resolve_effective_config(
|
||||
user_config, run_configs.get("model_overrides")
|
||||
user_config = await get_effective_ai_model_configuration_for_workflow(
|
||||
user_id=user_id,
|
||||
organization_id=workflow.organization_id if workflow else None,
|
||||
workflow_configurations=run_configs,
|
||||
)
|
||||
is_realtime = bool(user_config.is_realtime and user_config.realtime is not None)
|
||||
|
||||
|
|
@ -272,15 +275,18 @@ async def run_pipeline_smallwebrtc(
|
|||
# Resolve workflow_run + effective user_config here so the transport can
|
||||
# tune its bot-stopped-speaking fallback based on is_realtime. _run_pipeline
|
||||
# reuses these via kwargs so we don't fetch twice.
|
||||
from api.services.configuration.resolve import resolve_effective_config
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
get_effective_ai_model_configuration_for_workflow,
|
||||
)
|
||||
|
||||
workflow_run = await db_client.get_workflow_run(workflow_run_id, user_id)
|
||||
user_config = await db_client.get_user_configurations(user_id)
|
||||
run_configs = (
|
||||
(workflow_run.definition.workflow_configurations or {}) if workflow_run else {}
|
||||
)
|
||||
user_config = resolve_effective_config(
|
||||
user_config, run_configs.get("model_overrides")
|
||||
user_config = await get_effective_ai_model_configuration_for_workflow(
|
||||
user_id=user_id,
|
||||
organization_id=workflow.organization_id if workflow else None,
|
||||
workflow_configurations=run_configs,
|
||||
)
|
||||
is_realtime = bool(user_config.is_realtime and user_config.realtime is not None)
|
||||
|
||||
|
|
@ -380,11 +386,14 @@ async def _run_pipeline(
|
|||
# Resolve model overrides from the version onto global user config (skip
|
||||
# when the caller already resolved it).
|
||||
if resolved_user_config is None:
|
||||
from api.services.configuration.resolve import resolve_effective_config
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
get_effective_ai_model_configuration_for_workflow,
|
||||
)
|
||||
|
||||
user_config = await db_client.get_user_configurations(user_id)
|
||||
user_config = resolve_effective_config(
|
||||
user_config, run_configs.get("model_overrides")
|
||||
user_config = await get_effective_ai_model_configuration_for_workflow(
|
||||
user_id=user_id,
|
||||
organization_id=workflow.organization_id,
|
||||
workflow_configurations=run_configs,
|
||||
)
|
||||
else:
|
||||
user_config = resolved_user_config
|
||||
|
|
@ -508,10 +517,17 @@ async def _run_pipeline(
|
|||
embeddings_endpoint = None
|
||||
embeddings_api_version = None
|
||||
if user_config and user_config.embeddings:
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
apply_managed_embeddings_base_url,
|
||||
)
|
||||
|
||||
embeddings_api_key = user_config.embeddings.api_key
|
||||
embeddings_model = user_config.embeddings.model
|
||||
embeddings_provider = getattr(user_config.embeddings, "provider", None)
|
||||
embeddings_base_url = getattr(user_config.embeddings, "base_url", None)
|
||||
embeddings_base_url = apply_managed_embeddings_base_url(
|
||||
provider=embeddings_provider,
|
||||
base_url=getattr(user_config.embeddings, "base_url", None),
|
||||
)
|
||||
embeddings_endpoint = getattr(user_config.embeddings, "endpoint", None)
|
||||
embeddings_api_version = getattr(user_config.embeddings, "api_version", None)
|
||||
|
||||
|
|
|
|||
|
|
@ -10,8 +10,10 @@ from loguru import logger
|
|||
|
||||
from api.db import db_client
|
||||
from api.db.models import UserModel
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
get_effective_ai_model_configuration_for_workflow,
|
||||
)
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from api.services.configuration.resolve import resolve_effective_config
|
||||
from api.services.mps_service_key_client import mps_service_key_client
|
||||
|
||||
|
||||
|
|
@ -48,17 +50,20 @@ async def check_dograh_quota(
|
|||
if quota is insufficient.
|
||||
"""
|
||||
try:
|
||||
# Get user configurations
|
||||
user_config = await db_client.get_user_configurations(user.id)
|
||||
organization_id = user.selected_organization_id
|
||||
workflow_configurations = None
|
||||
|
||||
if workflow_id is not None:
|
||||
workflow = await db_client.get_workflow_by_id(workflow_id)
|
||||
if workflow:
|
||||
model_overrides = (workflow.workflow_configurations or {}).get(
|
||||
"model_overrides"
|
||||
)
|
||||
if model_overrides:
|
||||
user_config = resolve_effective_config(user_config, model_overrides)
|
||||
organization_id = workflow.organization_id
|
||||
workflow_configurations = workflow.workflow_configurations
|
||||
|
||||
user_config = await get_effective_ai_model_configuration_for_workflow(
|
||||
user_id=user.id,
|
||||
organization_id=organization_id,
|
||||
workflow_configurations=workflow_configurations,
|
||||
)
|
||||
|
||||
# Check if user is using any Dograh service
|
||||
using_dograh = False
|
||||
|
|
@ -76,6 +81,13 @@ async def check_dograh_quota(
|
|||
using_dograh = True
|
||||
dograh_api_keys.add(user_config.tts.api_key)
|
||||
|
||||
if (
|
||||
user_config.embeddings
|
||||
and user_config.embeddings.provider == ServiceProviders.DOGRAH
|
||||
):
|
||||
using_dograh = True
|
||||
dograh_api_keys.add(user_config.embeddings.api_key)
|
||||
|
||||
# If not using Dograh, quota check passes
|
||||
if not using_dograh:
|
||||
return QuotaCheckResult(has_quota=True)
|
||||
|
|
@ -84,7 +96,9 @@ async def check_dograh_quota(
|
|||
for api_key in dograh_api_keys:
|
||||
try:
|
||||
usage = await mps_service_key_client.check_service_key_usage(
|
||||
api_key, created_by=user.provider_id
|
||||
api_key,
|
||||
organization_id=organization_id,
|
||||
created_by=user.provider_id,
|
||||
)
|
||||
remaining = usage.get("remaining_credits", 0.0)
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@
|
|||
|
||||
import random
|
||||
|
||||
from api.db import db_client
|
||||
from api.db.models import WorkflowRunModel
|
||||
from api.services.workflow.dto import QANodeData
|
||||
|
||||
|
|
@ -54,7 +53,27 @@ async def resolve_user_llm_config(
|
|||
|
||||
llm_config: dict = {}
|
||||
if user_id:
|
||||
user_configuration = await db_client.get_user_configurations(user_id)
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
get_effective_ai_model_configuration_for_workflow,
|
||||
)
|
||||
|
||||
workflow_configurations = {}
|
||||
if workflow_run.definition:
|
||||
workflow_configurations = (
|
||||
workflow_run.definition.workflow_configurations or {}
|
||||
)
|
||||
elif workflow_run.workflow:
|
||||
workflow_configurations = (
|
||||
workflow_run.workflow.workflow_configurations or {}
|
||||
)
|
||||
|
||||
user_configuration = await get_effective_ai_model_configuration_for_workflow(
|
||||
user_id=user_id,
|
||||
organization_id=workflow_run.workflow.organization_id
|
||||
if workflow_run.workflow
|
||||
else None,
|
||||
workflow_configurations=workflow_configurations,
|
||||
)
|
||||
llm_config = user_configuration.model_dump(exclude_none=True).get("llm", {})
|
||||
|
||||
provider = llm_config.get("provider", "openai")
|
||||
|
|
|
|||
|
|
@ -32,7 +32,6 @@ from pipecat.utils.run_context import set_current_org_id
|
|||
|
||||
from api.db import db_client
|
||||
from api.enums import WorkflowRunMode, WorkflowRunState
|
||||
from api.services.configuration.resolve import resolve_effective_config
|
||||
from api.services.pipecat.audio_config import create_audio_config
|
||||
from api.services.pipecat.pipeline_builder import create_pipeline_task
|
||||
from api.services.pipecat.pipeline_metrics_aggregator import (
|
||||
|
|
@ -410,9 +409,14 @@ async def execute_text_chat_pending_turn(
|
|||
run_definition = workflow_run.definition
|
||||
run_configs = run_definition.workflow_configurations or {}
|
||||
|
||||
user_config = await db_client.get_user_configurations(workflow_run.workflow.user.id)
|
||||
user_config = resolve_effective_config(
|
||||
user_config, run_configs.get("model_overrides")
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
get_effective_ai_model_configuration_for_workflow,
|
||||
)
|
||||
|
||||
user_config = await get_effective_ai_model_configuration_for_workflow(
|
||||
user_id=workflow_run.workflow.user.id,
|
||||
organization_id=workflow.organization_id,
|
||||
workflow_configurations=run_configs,
|
||||
)
|
||||
if user_config.llm is None:
|
||||
raise ValueError("Text chat requires an LLM configuration")
|
||||
|
|
@ -466,9 +470,17 @@ async def execute_text_chat_pending_turn(
|
|||
embeddings_model = None
|
||||
embeddings_base_url = None
|
||||
if user_config.embeddings:
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
apply_managed_embeddings_base_url,
|
||||
)
|
||||
|
||||
embeddings_api_key = user_config.embeddings.api_key
|
||||
embeddings_model = user_config.embeddings.model
|
||||
embeddings_base_url = getattr(user_config.embeddings, "base_url", None)
|
||||
embeddings_provider = getattr(user_config.embeddings, "provider", None)
|
||||
embeddings_base_url = apply_managed_embeddings_base_url(
|
||||
provider=embeddings_provider,
|
||||
base_url=getattr(user_config.embeddings, "base_url", None),
|
||||
)
|
||||
|
||||
has_recordings = await db_client.has_active_recordings(workflow.organization_id)
|
||||
context_compaction_enabled = (workflow.workflow_configurations or {}).get(
|
||||
|
|
|
|||
|
|
@ -157,12 +157,24 @@ async def process_knowledge_base_document(
|
|||
embeddings_endpoint = None
|
||||
embeddings_api_version = None
|
||||
if document.created_by:
|
||||
user_config = await db_client.get_user_configurations(document.created_by)
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
apply_managed_embeddings_base_url,
|
||||
get_resolved_ai_model_configuration,
|
||||
)
|
||||
|
||||
resolved_config = await get_resolved_ai_model_configuration(
|
||||
user_id=document.created_by,
|
||||
organization_id=document.organization_id,
|
||||
)
|
||||
user_config = resolved_config.effective
|
||||
if user_config.embeddings:
|
||||
embeddings_provider = getattr(user_config.embeddings, "provider", None)
|
||||
embeddings_api_key = user_config.embeddings.api_key
|
||||
embeddings_model = user_config.embeddings.model
|
||||
embeddings_base_url = getattr(user_config.embeddings, "base_url", None)
|
||||
embeddings_base_url = apply_managed_embeddings_base_url(
|
||||
provider=embeddings_provider,
|
||||
base_url=getattr(user_config.embeddings, "base_url", None),
|
||||
)
|
||||
embeddings_endpoint = getattr(user_config.embeddings, "endpoint", None)
|
||||
embeddings_api_version = getattr(
|
||||
user_config.embeddings, "api_version", None
|
||||
|
|
|
|||
295
api/tests/test_ai_model_configuration_v2.py
Normal file
295
api/tests/test_ai_model_configuration_v2.py
Normal file
|
|
@ -0,0 +1,295 @@
|
|||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from api.schemas.ai_model_configuration import (
|
||||
DograhManagedAIModelConfiguration,
|
||||
OrganizationAIModelConfigurationV2,
|
||||
compile_ai_model_configuration_v2,
|
||||
)
|
||||
from api.schemas.user_configuration import UserConfiguration
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY,
|
||||
check_for_masked_keys_in_ai_model_configuration_v2,
|
||||
convert_legacy_ai_model_configuration_to_v2,
|
||||
mask_ai_model_configuration_v2,
|
||||
merge_ai_model_configuration_v2_secrets,
|
||||
migrate_workflow_configuration_model_override_to_v2,
|
||||
)
|
||||
from api.services.configuration.masking import mask_key
|
||||
from api.services.configuration.registry import (
|
||||
DeepgramSTTConfiguration,
|
||||
DograhLLMService,
|
||||
DograhSTTService,
|
||||
DograhTTSService,
|
||||
ElevenlabsTTSConfiguration,
|
||||
OpenAIEmbeddingsConfiguration,
|
||||
OpenAILLMService,
|
||||
)
|
||||
|
||||
|
||||
def test_dograh_v2_compiles_to_effective_managed_pipeline_with_embeddings():
|
||||
config = OrganizationAIModelConfigurationV2(
|
||||
mode="dograh",
|
||||
dograh=DograhManagedAIModelConfiguration(
|
||||
api_key="mps-secret",
|
||||
voice="default",
|
||||
speed=1.2,
|
||||
language="multi",
|
||||
),
|
||||
)
|
||||
|
||||
effective = compile_ai_model_configuration_v2(config)
|
||||
|
||||
assert effective.is_realtime is False
|
||||
assert effective.llm.provider == "dograh"
|
||||
assert effective.llm.model == "default"
|
||||
assert effective.tts.provider == "dograh"
|
||||
assert effective.tts.speed == 1.2
|
||||
assert effective.stt.provider == "dograh"
|
||||
assert effective.stt.language == "multi"
|
||||
assert effective.embeddings.provider == "dograh"
|
||||
assert effective.embeddings.model == "default"
|
||||
|
||||
|
||||
def test_dograh_v2_rejects_non_predefined_speed():
|
||||
with pytest.raises(ValidationError):
|
||||
OrganizationAIModelConfigurationV2(
|
||||
mode="dograh",
|
||||
dograh=DograhManagedAIModelConfiguration(
|
||||
api_key="mps-secret",
|
||||
speed=1.5,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def test_byok_v2_rejects_dograh_provider():
|
||||
with pytest.raises(ValidationError):
|
||||
OrganizationAIModelConfigurationV2.model_validate(
|
||||
{
|
||||
"mode": "byok",
|
||||
"byok": {
|
||||
"mode": "pipeline",
|
||||
"pipeline": {
|
||||
"llm": {
|
||||
"provider": "dograh",
|
||||
"api_key": "mps-secret",
|
||||
"model": "default",
|
||||
},
|
||||
"tts": {
|
||||
"provider": "dograh",
|
||||
"api_key": "mps-secret",
|
||||
"model": "default",
|
||||
"voice": "default",
|
||||
},
|
||||
"stt": {
|
||||
"provider": "dograh",
|
||||
"api_key": "mps-secret",
|
||||
"model": "default",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_masked_dograh_key_is_preserved_when_saving_same_mode():
|
||||
existing = OrganizationAIModelConfigurationV2(
|
||||
mode="dograh",
|
||||
dograh=DograhManagedAIModelConfiguration(api_key="mps-real-secret"),
|
||||
)
|
||||
incoming = OrganizationAIModelConfigurationV2(
|
||||
mode="dograh",
|
||||
dograh=DograhManagedAIModelConfiguration(api_key=mask_key("mps-real-secret")),
|
||||
)
|
||||
|
||||
merged = merge_ai_model_configuration_v2_secrets(incoming, existing)
|
||||
|
||||
assert merged.dograh.api_key == "mps-real-secret"
|
||||
check_for_masked_keys_in_ai_model_configuration_v2(merged)
|
||||
|
||||
|
||||
def test_masked_v2_configuration_masks_nested_service_keys():
|
||||
config = OrganizationAIModelConfigurationV2(
|
||||
mode="byok",
|
||||
byok={
|
||||
"mode": "pipeline",
|
||||
"pipeline": {
|
||||
"llm": {
|
||||
"provider": "openai",
|
||||
"api_key": "sk-real-secret",
|
||||
"model": "gpt-4.1",
|
||||
},
|
||||
"tts": {
|
||||
"provider": "elevenlabs",
|
||||
"api_key": "el-real-secret",
|
||||
"model": "eleven_flash_v2_5",
|
||||
"voice": "Rachel",
|
||||
},
|
||||
"stt": {
|
||||
"provider": "deepgram",
|
||||
"api_key": "dg-real-secret",
|
||||
"model": "nova-3-general",
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
masked = mask_ai_model_configuration_v2(config)
|
||||
|
||||
assert masked["byok"]["pipeline"]["llm"]["api_key"] == mask_key("sk-real-secret")
|
||||
assert masked["byok"]["pipeline"]["tts"]["api_key"] == mask_key("el-real-secret")
|
||||
assert masked["byok"]["pipeline"]["stt"]["api_key"] == mask_key("dg-real-secret")
|
||||
|
||||
|
||||
def test_legacy_all_dograh_pipeline_converts_to_dograh_v2():
|
||||
legacy = UserConfiguration(
|
||||
llm=DograhLLMService(
|
||||
provider="dograh",
|
||||
api_key=["mps-secret"],
|
||||
model="default",
|
||||
),
|
||||
tts=DograhTTSService(
|
||||
provider="dograh",
|
||||
api_key=["mps-secret"],
|
||||
model="default",
|
||||
voice="default",
|
||||
speed=1.0,
|
||||
),
|
||||
stt=DograhSTTService(
|
||||
provider="dograh",
|
||||
api_key=["mps-secret"],
|
||||
model="default",
|
||||
language="multi",
|
||||
),
|
||||
)
|
||||
|
||||
config = convert_legacy_ai_model_configuration_to_v2(legacy)
|
||||
|
||||
assert config.mode == "dograh"
|
||||
assert config.dograh.api_key == "mps-secret"
|
||||
|
||||
|
||||
def test_legacy_mixed_dograh_pipeline_converts_to_dograh_v2():
|
||||
legacy = UserConfiguration(
|
||||
llm=OpenAILLMService(
|
||||
provider="openai",
|
||||
api_key="sk-llm",
|
||||
model="gpt-4.1",
|
||||
),
|
||||
tts=DograhTTSService(
|
||||
provider="dograh",
|
||||
api_key="mps-tts",
|
||||
model="default",
|
||||
voice="default",
|
||||
),
|
||||
stt=DograhSTTService(
|
||||
provider="dograh",
|
||||
api_key="mps-stt",
|
||||
model="default",
|
||||
),
|
||||
embeddings=OpenAIEmbeddingsConfiguration(
|
||||
provider="openai",
|
||||
api_key="sk-emb",
|
||||
model="text-embedding-3-small",
|
||||
),
|
||||
)
|
||||
|
||||
config = convert_legacy_ai_model_configuration_to_v2(legacy)
|
||||
|
||||
assert config.mode == "dograh"
|
||||
assert config.dograh.api_key == "mps-tts"
|
||||
assert config.dograh.voice == "default"
|
||||
|
||||
|
||||
def test_legacy_byok_pipeline_converts_to_byok_v2():
|
||||
legacy = UserConfiguration(
|
||||
llm=OpenAILLMService(
|
||||
provider="openai",
|
||||
api_key="sk-llm",
|
||||
model="gpt-4.1",
|
||||
),
|
||||
tts=ElevenlabsTTSConfiguration(
|
||||
provider="elevenlabs",
|
||||
api_key="el-tts",
|
||||
model="eleven_flash_v2_5",
|
||||
voice="Rachel",
|
||||
),
|
||||
stt=DeepgramSTTConfiguration(
|
||||
provider="deepgram",
|
||||
api_key="dg-stt",
|
||||
model="nova-3-general",
|
||||
),
|
||||
embeddings=OpenAIEmbeddingsConfiguration(
|
||||
provider="openai",
|
||||
api_key="sk-emb",
|
||||
model="text-embedding-3-small",
|
||||
),
|
||||
)
|
||||
|
||||
config = convert_legacy_ai_model_configuration_to_v2(legacy)
|
||||
|
||||
assert config.mode == "byok"
|
||||
assert config.byok.mode == "pipeline"
|
||||
assert config.byok.pipeline.llm.provider == "openai"
|
||||
assert config.byok.pipeline.tts.provider == "elevenlabs"
|
||||
|
||||
|
||||
def test_workflow_model_override_migration_removes_v1_override_and_sets_v2():
|
||||
base = UserConfiguration(
|
||||
llm=OpenAILLMService(
|
||||
provider="openai",
|
||||
api_key="sk-llm",
|
||||
model="gpt-4.1",
|
||||
),
|
||||
tts=ElevenlabsTTSConfiguration(
|
||||
provider="elevenlabs",
|
||||
api_key="el-tts",
|
||||
model="eleven_flash_v2_5",
|
||||
voice="Rachel",
|
||||
),
|
||||
stt=DeepgramSTTConfiguration(
|
||||
provider="deepgram",
|
||||
api_key="dg-stt",
|
||||
model="nova-3-general",
|
||||
),
|
||||
)
|
||||
workflow_configurations = {
|
||||
"ambient_noise_configuration": {"enabled": False},
|
||||
"model_overrides": {
|
||||
"tts": {
|
||||
"provider": "dograh",
|
||||
"api_key": "mps-workflow",
|
||||
"model": "default",
|
||||
"voice": "default",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
migrated, changed = migrate_workflow_configuration_model_override_to_v2(
|
||||
workflow_configurations,
|
||||
base,
|
||||
)
|
||||
|
||||
assert changed is True
|
||||
assert "model_overrides" not in migrated
|
||||
assert migrated["ambient_noise_configuration"] == {"enabled": False}
|
||||
v2_override = migrated[WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY]
|
||||
assert v2_override["mode"] == "dograh"
|
||||
assert v2_override["dograh"]["api_key"] == "mps-workflow"
|
||||
|
||||
|
||||
def test_workflow_model_override_migration_removes_invalid_v1_override_marker():
|
||||
base = UserConfiguration()
|
||||
workflow_configurations = {
|
||||
"ambient_noise_configuration": {"enabled": False},
|
||||
"model_overrides": None,
|
||||
}
|
||||
|
||||
migrated, changed = migrate_workflow_configuration_model_override_to_v2(
|
||||
workflow_configurations,
|
||||
base,
|
||||
)
|
||||
|
||||
assert changed is True
|
||||
assert "model_overrides" not in migrated
|
||||
assert migrated["ambient_noise_configuration"] == {"enabled": False}
|
||||
|
|
@ -1,13 +1,25 @@
|
|||
|
||||
import ServiceConfiguration from "@/components/ServiceConfiguration";
|
||||
import ModelConfigurationV2 from "@/components/ModelConfigurationV2";
|
||||
import { SETTINGS_DOCUMENTATION_URLS } from "@/constants/documentation";
|
||||
|
||||
export default function ServiceConfigurationPage() {
|
||||
interface ServiceConfigurationPageProps {
|
||||
searchParams?: Promise<{
|
||||
action?: string | string[];
|
||||
}>;
|
||||
}
|
||||
|
||||
export default async function ServiceConfigurationPage({ searchParams }: ServiceConfigurationPageProps) {
|
||||
const params = searchParams ? await searchParams : {};
|
||||
const action = Array.isArray(params.action) ? params.action[0] : params.action;
|
||||
|
||||
return (
|
||||
<div className="min-h-screen bg-background">
|
||||
<div className="container mx-auto px-4 py-8">
|
||||
<div className="max-w-4xl mx-auto">
|
||||
<ServiceConfiguration docsUrl={SETTINGS_DOCUMENTATION_URLS.modelOverrides} />
|
||||
<ModelConfigurationV2
|
||||
docsUrl={SETTINGS_DOCUMENTATION_URLS.modelOverrides}
|
||||
initialAction={action}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -7,8 +7,22 @@ import { useParams, useRouter } from "next/navigation";
|
|||
import { useEffect, useMemo, useRef, useState } from "react";
|
||||
import { toast } from "sonner";
|
||||
|
||||
import { downloadWorkflowReportApiV1WorkflowWorkflowIdReportGet, getAmbientNoiseUploadUrlApiV1WorkflowAmbientNoiseUploadUrlPost, getWorkflowApiV1WorkflowFetchWorkflowIdGet } from "@/client/sdk.gen";
|
||||
import type { WorkflowResponse } from "@/client/types.gen";
|
||||
import {
|
||||
downloadWorkflowReportApiV1WorkflowWorkflowIdReportGet,
|
||||
getAmbientNoiseUploadUrlApiV1WorkflowAmbientNoiseUploadUrlPost,
|
||||
getModelConfigurationV2ApiV1OrganizationsModelConfigurationsV2Get,
|
||||
getModelConfigurationV2DefaultsApiV1OrganizationsModelConfigurationsV2DefaultsGet,
|
||||
getWorkflowApiV1WorkflowFetchWorkflowIdGet,
|
||||
} from "@/client/sdk.gen";
|
||||
import type {
|
||||
OrganizationAiModelConfigurationResponse,
|
||||
OrganizationAiModelConfigurationV2,
|
||||
WorkflowResponse,
|
||||
} from "@/client/types.gen";
|
||||
import {
|
||||
AIModelConfigurationV2Editor,
|
||||
type ModelConfigurationDefaultsV2,
|
||||
} from "@/components/AIModelConfigurationV2Editor";
|
||||
import { FlowEdge, FlowNode } from "@/components/flow/types";
|
||||
import { LLMConfigSelector } from "@/components/LLMConfigSelector";
|
||||
import { ServiceConfigurationForm } from "@/components/ServiceConfigurationForm";
|
||||
|
|
@ -26,6 +40,7 @@ import { Textarea } from "@/components/ui/textarea";
|
|||
import { SETTINGS_DOCUMENTATION_URLS } from "@/constants/documentation";
|
||||
import { UnsavedChangesProvider, useUnsavedChanges, useUnsavedChangesContext } from "@/context/UnsavedChangesContext";
|
||||
import { useAudioPlayback } from "@/hooks/useAudioPlayback";
|
||||
import { detailFromError } from "@/lib/apiError";
|
||||
import { useAuth } from "@/lib/auth";
|
||||
import logger from "@/lib/logger";
|
||||
import {
|
||||
|
|
@ -1040,6 +1055,182 @@ function AgentUuidSection({ workflowUuid }: { workflowUuid: string }) {
|
|||
);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Section: Model Overrides
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
function withoutModelConfigurationOverrides(configurations: WorkflowConfigurations): WorkflowConfigurations {
|
||||
const next = { ...configurations };
|
||||
delete next.model_overrides;
|
||||
delete next.model_configuration_v2_override;
|
||||
return next;
|
||||
}
|
||||
|
||||
function WorkflowModelOverridesSection({
|
||||
workflowConfigurations,
|
||||
workflowName,
|
||||
onSave,
|
||||
modelConfigurationDefaults,
|
||||
organizationModelConfiguration,
|
||||
modelConfigurationLoading,
|
||||
modelConfigurationError,
|
||||
}: {
|
||||
workflowConfigurations: WorkflowConfigurations;
|
||||
workflowName: string;
|
||||
onSave: (configurations: WorkflowConfigurations, workflowName: string) => Promise<void>;
|
||||
modelConfigurationDefaults: ModelConfigurationDefaultsV2 | null;
|
||||
organizationModelConfiguration: OrganizationAiModelConfigurationResponse | null;
|
||||
modelConfigurationLoading: boolean;
|
||||
modelConfigurationError: string | null;
|
||||
}) {
|
||||
const savedV2Override = workflowConfigurations.model_configuration_v2_override;
|
||||
const hasSavedModelOverride = Boolean(savedV2Override || workflowConfigurations.model_overrides);
|
||||
const [overrideEnabled, setOverrideEnabled] = useState(Boolean(savedV2Override));
|
||||
const [isRemovingOverride, setIsRemovingOverride] = useState(false);
|
||||
|
||||
useEffect(() => {
|
||||
setOverrideEnabled(Boolean(workflowConfigurations.model_configuration_v2_override));
|
||||
}, [workflowConfigurations.model_configuration_v2_override]);
|
||||
|
||||
const source = organizationModelConfiguration?.source || "empty";
|
||||
const isV2 = source === "organization_v2";
|
||||
|
||||
const saveLegacyOverrides = async (config: Record<string, unknown>) => {
|
||||
const nextConfigurations = withoutModelConfigurationOverrides(workflowConfigurations);
|
||||
const modelOverrides = config.model_overrides as WorkflowConfigurations["model_overrides"] | undefined;
|
||||
if (modelOverrides) {
|
||||
nextConfigurations.model_overrides = modelOverrides;
|
||||
}
|
||||
await onSave(nextConfigurations, workflowName);
|
||||
};
|
||||
|
||||
const saveV2Override = async (configuration: OrganizationAiModelConfigurationV2) => {
|
||||
const nextConfigurations = withoutModelConfigurationOverrides(workflowConfigurations);
|
||||
nextConfigurations.model_configuration_v2_override = configuration;
|
||||
await onSave(nextConfigurations, workflowName);
|
||||
toast.success("Model override saved");
|
||||
};
|
||||
|
||||
const removeV2Override = async () => {
|
||||
setIsRemovingOverride(true);
|
||||
try {
|
||||
await onSave(withoutModelConfigurationOverrides(workflowConfigurations), workflowName);
|
||||
setOverrideEnabled(false);
|
||||
toast.success("Using organization model configuration");
|
||||
} finally {
|
||||
setIsRemovingOverride(false);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<Card id="models">
|
||||
<CardHeader>
|
||||
<CardTitle className="flex items-center gap-2 text-base">
|
||||
<Brain className="h-4 w-4" />
|
||||
Model Overrides
|
||||
</CardTitle>
|
||||
<CardDescription>
|
||||
{isV2
|
||||
? "Override the full organization model configuration for this workflow."
|
||||
: "Override global model settings for this workflow. Toggle individual services to customize."}{" "}
|
||||
<a href={SETTINGS_DOCUMENTATION_URLS.modelOverrides} target="_blank" rel="noopener noreferrer" className="inline-flex items-center gap-0.5 underline">Learn more <ExternalLink className="h-3 w-3" /></a>
|
||||
</CardDescription>
|
||||
</CardHeader>
|
||||
<CardContent className="space-y-4">
|
||||
{modelConfigurationLoading && (
|
||||
<div className="flex items-center gap-2 rounded-md border p-4 text-sm text-muted-foreground">
|
||||
<Loader2 className="h-4 w-4 animate-spin" />
|
||||
Loading model configuration
|
||||
</div>
|
||||
)}
|
||||
|
||||
{modelConfigurationError && (
|
||||
<div className="rounded-md border border-destructive/40 bg-destructive/10 px-4 py-3 text-sm text-destructive">
|
||||
{modelConfigurationError}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{!modelConfigurationLoading && !modelConfigurationError && !isV2 && (
|
||||
<>
|
||||
{source === "legacy_user_v1" && (
|
||||
<div className="flex flex-col gap-3 rounded-md border bg-muted/30 p-4 sm:flex-row sm:items-center sm:justify-between">
|
||||
<p className="text-sm text-muted-foreground">
|
||||
This workflow is using legacy model overrides. Migrate organization model configuration to use v2 overrides.
|
||||
</p>
|
||||
<Button type="button" variant="outline" size="sm" asChild>
|
||||
<Link href="/model-configurations?action=migrate_to_v2">Migrate to v2</Link>
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
<ServiceConfigurationForm
|
||||
mode="override"
|
||||
currentOverrides={workflowConfigurations.model_overrides}
|
||||
submitLabel="Save Model Overrides"
|
||||
onSave={saveLegacyOverrides}
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
|
||||
{!modelConfigurationLoading && !modelConfigurationError && isV2 && modelConfigurationDefaults && organizationModelConfiguration && (
|
||||
<>
|
||||
<div className="flex items-center justify-between rounded-md border p-4">
|
||||
<div className="space-y-0.5">
|
||||
<Label htmlFor="workflow-model-v2-override" className="text-sm font-medium">
|
||||
Override for this workflow
|
||||
</Label>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
{overrideEnabled
|
||||
? "This workflow uses its own complete model configuration."
|
||||
: "This workflow uses the organization model configuration."}
|
||||
</p>
|
||||
</div>
|
||||
<Switch
|
||||
id="workflow-model-v2-override"
|
||||
checked={overrideEnabled}
|
||||
onCheckedChange={setOverrideEnabled}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{overrideEnabled ? (
|
||||
<AIModelConfigurationV2Editor
|
||||
defaults={modelConfigurationDefaults}
|
||||
configuration={
|
||||
(savedV2Override as OrganizationAiModelConfigurationV2 | undefined)
|
||||
|| (organizationModelConfiguration.configuration as OrganizationAiModelConfigurationV2 | null)
|
||||
}
|
||||
effectiveConfiguration={
|
||||
savedV2Override
|
||||
? null
|
||||
: organizationModelConfiguration.effective_configuration
|
||||
}
|
||||
submitLabel="Save Model Override"
|
||||
onSave={saveV2Override}
|
||||
/>
|
||||
) : (
|
||||
<div className="rounded-md border bg-muted/20 p-4">
|
||||
<p className="text-sm text-muted-foreground">
|
||||
Using organization model configuration.
|
||||
</p>
|
||||
{hasSavedModelOverride && (
|
||||
<Button
|
||||
type="button"
|
||||
variant="outline"
|
||||
className="mt-3"
|
||||
onClick={removeV2Override}
|
||||
disabled={isRemovingOverride}
|
||||
>
|
||||
{isRemovingOverride ? "Saving..." : "Save Organization Configuration"}
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</CardContent>
|
||||
</Card>
|
||||
);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Main Page
|
||||
// ---------------------------------------------------------------------------
|
||||
|
|
@ -1127,6 +1318,11 @@ function WorkflowSettingsInner({
|
|||
|
||||
const [isEmbedDialogOpen, setIsEmbedDialogOpen] = useState(false);
|
||||
const [activeSection, setActiveSection] = useState("general");
|
||||
const [modelConfigurationDefaults, setModelConfigurationDefaults] = useState<ModelConfigurationDefaultsV2 | null>(null);
|
||||
const [organizationModelConfiguration, setOrganizationModelConfiguration] = useState<OrganizationAiModelConfigurationResponse | null>(null);
|
||||
const [modelConfigurationLoading, setModelConfigurationLoading] = useState(true);
|
||||
const [modelConfigurationError, setModelConfigurationError] = useState<string | null>(null);
|
||||
const hasFetchedModelConfiguration = useRef(false);
|
||||
|
||||
const workflowId = workflow.id;
|
||||
|
||||
|
|
@ -1166,6 +1362,37 @@ function WorkflowSettingsInner({
|
|||
user,
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
if (hasFetchedModelConfiguration.current) return;
|
||||
hasFetchedModelConfiguration.current = true;
|
||||
|
||||
const loadModelConfiguration = async () => {
|
||||
setModelConfigurationLoading(true);
|
||||
setModelConfigurationError(null);
|
||||
const [defaultsResult, configurationResult] = await Promise.all([
|
||||
getModelConfigurationV2DefaultsApiV1OrganizationsModelConfigurationsV2DefaultsGet(),
|
||||
getModelConfigurationV2ApiV1OrganizationsModelConfigurationsV2Get(),
|
||||
]);
|
||||
|
||||
if (defaultsResult.error) {
|
||||
setModelConfigurationError(detailFromError(defaultsResult.error, "Failed to load model configuration defaults"));
|
||||
setModelConfigurationLoading(false);
|
||||
return;
|
||||
}
|
||||
if (configurationResult.error) {
|
||||
setModelConfigurationError(detailFromError(configurationResult.error, "Failed to load model configuration"));
|
||||
setModelConfigurationLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
setModelConfigurationDefaults(defaultsResult.data as ModelConfigurationDefaultsV2);
|
||||
setOrganizationModelConfiguration(configurationResult.data || null);
|
||||
setModelConfigurationLoading(false);
|
||||
};
|
||||
|
||||
loadModelConfiguration();
|
||||
}, []);
|
||||
|
||||
// Intersection observer for active sidebar link
|
||||
useEffect(() => {
|
||||
const ids = NAV_ITEMS.map((n) => n.id);
|
||||
|
|
@ -1218,37 +1445,15 @@ function WorkflowSettingsInner({
|
|||
onSave={saveWorkflowConfigurations}
|
||||
/>
|
||||
|
||||
{/* Model Overrides */}
|
||||
<Card id="models">
|
||||
<CardHeader>
|
||||
<CardTitle className="flex items-center gap-2 text-base">
|
||||
<Brain className="h-4 w-4" />
|
||||
Model Overrides
|
||||
</CardTitle>
|
||||
<CardDescription>
|
||||
Override global model settings for this workflow. Toggle individual services to
|
||||
customize.{" "}
|
||||
<a href={SETTINGS_DOCUMENTATION_URLS.modelOverrides} target="_blank" rel="noopener noreferrer" className="inline-flex items-center gap-0.5 underline">Learn more <ExternalLink className="h-3 w-3" /></a>
|
||||
</CardDescription>
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
<ServiceConfigurationForm
|
||||
mode="override"
|
||||
currentOverrides={workflowConfigurations.model_overrides}
|
||||
submitLabel="Save Model Overrides"
|
||||
onSave={async (config) => {
|
||||
await saveWorkflowConfigurations(
|
||||
{
|
||||
...workflowConfigurations,
|
||||
model_overrides:
|
||||
config.model_overrides as WorkflowConfigurations["model_overrides"],
|
||||
} as WorkflowConfigurations,
|
||||
workflowName,
|
||||
);
|
||||
}}
|
||||
/>
|
||||
</CardContent>
|
||||
</Card>
|
||||
<WorkflowModelOverridesSection
|
||||
workflowConfigurations={workflowConfigurations}
|
||||
workflowName={workflowName}
|
||||
onSave={saveWorkflowConfigurations}
|
||||
modelConfigurationDefaults={modelConfigurationDefaults}
|
||||
organizationModelConfiguration={organizationModelConfiguration}
|
||||
modelConfigurationLoading={modelConfigurationLoading}
|
||||
modelConfigurationError={modelConfigurationError}
|
||||
/>
|
||||
|
||||
{/* Template Variables */}
|
||||
<TemplateVariablesSection
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load diff
419
ui/src/components/AIModelConfigurationV2Editor.tsx
Normal file
419
ui/src/components/AIModelConfigurationV2Editor.tsx
Normal file
|
|
@ -0,0 +1,419 @@
|
|||
"use client";
|
||||
|
||||
import { KeyRound, Save } from "lucide-react";
|
||||
import { useEffect, useMemo, useState } from "react";
|
||||
|
||||
import type { OrganizationAiModelConfigurationV2 } from "@/client/types.gen";
|
||||
import {
|
||||
type ProviderSchema,
|
||||
type ServiceConfigurationDefaults,
|
||||
ServiceConfigurationForm,
|
||||
type ServiceSegment,
|
||||
} from "@/components/ServiceConfigurationForm";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { Label } from "@/components/ui/label";
|
||||
import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from "@/components/ui/select";
|
||||
import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs";
|
||||
import { LANGUAGE_DISPLAY_NAMES } from "@/constants/languages";
|
||||
|
||||
type ModelMode = "dograh" | "byok";
|
||||
|
||||
interface DograhDefaults {
|
||||
voices: string[];
|
||||
speeds: number[];
|
||||
languages: string[];
|
||||
defaults: {
|
||||
voice: string;
|
||||
speed: number;
|
||||
language: string;
|
||||
};
|
||||
}
|
||||
|
||||
export interface ModelConfigurationDefaultsV2 {
|
||||
dograh: DograhDefaults;
|
||||
byok: {
|
||||
pipeline: ServiceConfigurationDefaults;
|
||||
realtime: {
|
||||
realtime: Record<string, ProviderSchema>;
|
||||
llm: Record<string, ProviderSchema>;
|
||||
embeddings: Record<string, ProviderSchema>;
|
||||
default_providers: ServiceConfigurationDefaults["default_providers"];
|
||||
};
|
||||
};
|
||||
}
|
||||
|
||||
interface DograhFormState {
|
||||
api_key: string;
|
||||
voice: string;
|
||||
speed: number;
|
||||
language: string;
|
||||
}
|
||||
|
||||
interface AIModelConfigurationV2EditorProps {
|
||||
defaults: ModelConfigurationDefaultsV2;
|
||||
configuration?: OrganizationAiModelConfigurationV2 | Record<string, unknown> | null;
|
||||
effectiveConfiguration?: Record<string, unknown> | null;
|
||||
onSave: (configuration: OrganizationAiModelConfigurationV2) => Promise<void>;
|
||||
submitLabel?: string;
|
||||
}
|
||||
|
||||
function firstApiKey(value: unknown): string {
|
||||
if (Array.isArray(value)) return String(value[0] || "");
|
||||
return typeof value === "string" ? value : "";
|
||||
}
|
||||
|
||||
function asRecord(value: unknown): Record<string, unknown> | null {
|
||||
return value && typeof value === "object" && !Array.isArray(value)
|
||||
? value as Record<string, unknown>
|
||||
: null;
|
||||
}
|
||||
|
||||
function isDograhEffectiveConfig(config: Record<string, unknown> | null | undefined): boolean {
|
||||
if (!config || config.is_realtime) return false;
|
||||
const llm = asRecord(config.llm);
|
||||
const tts = asRecord(config.tts);
|
||||
const stt = asRecord(config.stt);
|
||||
return llm?.provider === "dograh" && tts?.provider === "dograh" && stt?.provider === "dograh";
|
||||
}
|
||||
|
||||
function byokDefaults(defaults: ModelConfigurationDefaultsV2): ServiceConfigurationDefaults {
|
||||
return {
|
||||
llm: defaults.byok.pipeline.llm,
|
||||
tts: defaults.byok.pipeline.tts,
|
||||
stt: defaults.byok.pipeline.stt,
|
||||
embeddings: defaults.byok.pipeline.embeddings,
|
||||
realtime: defaults.byok.realtime.realtime,
|
||||
default_providers: defaults.byok.pipeline.default_providers,
|
||||
};
|
||||
}
|
||||
|
||||
function byokConfigToLegacyShape(config: Record<string, unknown> | null): Record<string, unknown> | null {
|
||||
if (!config || config.mode !== "byok") return null;
|
||||
const byok = asRecord(config.byok);
|
||||
if (!byok) return null;
|
||||
|
||||
if (byok.mode === "realtime") {
|
||||
const realtime = asRecord(byok.realtime);
|
||||
return {
|
||||
is_realtime: true,
|
||||
realtime: realtime?.realtime,
|
||||
llm: realtime?.llm,
|
||||
embeddings: realtime?.embeddings,
|
||||
};
|
||||
}
|
||||
|
||||
const pipeline = asRecord(byok.pipeline);
|
||||
return {
|
||||
is_realtime: false,
|
||||
llm: pipeline?.llm,
|
||||
tts: pipeline?.tts,
|
||||
stt: pipeline?.stt,
|
||||
embeddings: pipeline?.embeddings,
|
||||
};
|
||||
}
|
||||
|
||||
function effectiveConfigToLegacyShape(config: Record<string, unknown> | null): Record<string, unknown> | null {
|
||||
if (!config) return null;
|
||||
return {
|
||||
is_realtime: Boolean(config.is_realtime),
|
||||
llm: config.llm,
|
||||
tts: config.tts,
|
||||
stt: config.stt,
|
||||
realtime: config.realtime,
|
||||
embeddings: config.embeddings,
|
||||
};
|
||||
}
|
||||
|
||||
function emptyByokInitialConfig(): Record<string, unknown> {
|
||||
return {
|
||||
is_realtime: false,
|
||||
};
|
||||
}
|
||||
|
||||
function getByokInitialConfig(
|
||||
configuration: Record<string, unknown> | null,
|
||||
effectiveConfiguration: Record<string, unknown> | null,
|
||||
): Record<string, unknown> {
|
||||
const byokConfiguration = byokConfigToLegacyShape(configuration);
|
||||
if (byokConfiguration) return byokConfiguration;
|
||||
|
||||
if (configuration?.mode === "dograh" || isDograhEffectiveConfig(effectiveConfiguration)) {
|
||||
return emptyByokInitialConfig();
|
||||
}
|
||||
|
||||
return effectiveConfigToLegacyShape(effectiveConfiguration) || emptyByokInitialConfig();
|
||||
}
|
||||
|
||||
function buildDograhState(
|
||||
defaults: ModelConfigurationDefaultsV2,
|
||||
configuration: Record<string, unknown> | null,
|
||||
effectiveConfiguration: Record<string, unknown> | null,
|
||||
): DograhFormState {
|
||||
const fallback = defaults.dograh.defaults;
|
||||
const configuredDograh = configuration?.mode === "dograh" ? asRecord(configuration.dograh) : null;
|
||||
if (configuredDograh) {
|
||||
return {
|
||||
api_key: String(configuredDograh.api_key || ""),
|
||||
voice: String(configuredDograh.voice || fallback.voice),
|
||||
speed: Number(configuredDograh.speed || fallback.speed),
|
||||
language: String(configuredDograh.language || fallback.language),
|
||||
};
|
||||
}
|
||||
|
||||
if (isDograhEffectiveConfig(effectiveConfiguration)) {
|
||||
const llm = asRecord(effectiveConfiguration?.llm);
|
||||
const tts = asRecord(effectiveConfiguration?.tts);
|
||||
const stt = asRecord(effectiveConfiguration?.stt);
|
||||
return {
|
||||
api_key: firstApiKey(llm?.api_key || tts?.api_key || stt?.api_key),
|
||||
voice: String(tts?.voice || fallback.voice),
|
||||
speed: Number(tts?.speed || fallback.speed),
|
||||
language: String(stt?.language || fallback.language),
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
api_key: "",
|
||||
voice: fallback.voice,
|
||||
speed: fallback.speed,
|
||||
language: fallback.language,
|
||||
};
|
||||
}
|
||||
|
||||
function preferredMode(
|
||||
configuration: Record<string, unknown> | null,
|
||||
effectiveConfiguration: Record<string, unknown> | null,
|
||||
): ModelMode {
|
||||
if (configuration?.mode === "dograh" || configuration?.mode === "byok") {
|
||||
return configuration.mode;
|
||||
}
|
||||
return isDograhEffectiveConfig(effectiveConfiguration) ? "dograh" : "byok";
|
||||
}
|
||||
|
||||
function hasRequiredApiKey(
|
||||
service: ServiceSegment,
|
||||
serviceConfiguration: Record<string, unknown>,
|
||||
defaults: ServiceConfigurationDefaults,
|
||||
): boolean {
|
||||
const provider = serviceConfiguration.provider as string | undefined;
|
||||
if (!provider) return false;
|
||||
const providerSchema = service === "realtime"
|
||||
? defaults.realtime?.[provider]
|
||||
: defaults[service as "llm" | "tts" | "stt" | "embeddings"]?.[provider];
|
||||
const requiresApiKey = providerSchema?.required?.includes("api_key") ?? false;
|
||||
if (!requiresApiKey) return true;
|
||||
|
||||
const apiKey = serviceConfiguration.api_key;
|
||||
if (Array.isArray(apiKey)) {
|
||||
return apiKey.some((key) => typeof key === "string" && key.trim().length > 0);
|
||||
}
|
||||
return typeof apiKey === "string" && apiKey.trim().length > 0;
|
||||
}
|
||||
|
||||
function requireByokService(
|
||||
config: Record<string, unknown>,
|
||||
service: ServiceSegment,
|
||||
defaults: ServiceConfigurationDefaults,
|
||||
): Record<string, unknown> {
|
||||
const serviceConfiguration = asRecord(config[service]);
|
||||
if (
|
||||
!serviceConfiguration
|
||||
|| !serviceConfiguration.provider
|
||||
|| serviceConfiguration.provider === "dograh"
|
||||
|| !hasRequiredApiKey(service, serviceConfiguration, defaults)
|
||||
) {
|
||||
throw new Error(`${service} configuration is required`);
|
||||
}
|
||||
return serviceConfiguration;
|
||||
}
|
||||
|
||||
function optionalByokService(config: Record<string, unknown>, service: ServiceSegment): Record<string, unknown> | undefined {
|
||||
const serviceConfiguration = asRecord(config[service]);
|
||||
if (!serviceConfiguration?.provider || serviceConfiguration.provider === "dograh") return undefined;
|
||||
return serviceConfiguration;
|
||||
}
|
||||
|
||||
export function AIModelConfigurationV2Editor({
|
||||
defaults,
|
||||
configuration,
|
||||
effectiveConfiguration,
|
||||
onSave,
|
||||
submitLabel = "Save Configuration",
|
||||
}: AIModelConfigurationV2EditorProps) {
|
||||
const defaultsForByok = useMemo(() => byokDefaults(defaults), [defaults]);
|
||||
const [mode, setMode] = useState<ModelMode>("dograh");
|
||||
const [dograh, setDograh] = useState<DograhFormState>(() => ({
|
||||
api_key: "",
|
||||
voice: defaults.dograh.defaults.voice,
|
||||
speed: defaults.dograh.defaults.speed,
|
||||
language: defaults.dograh.defaults.language,
|
||||
}));
|
||||
const [byokInitialConfig, setByokInitialConfig] = useState<Record<string, unknown> | null>(null);
|
||||
const [isSavingDograh, setIsSavingDograh] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
|
||||
useEffect(() => {
|
||||
const rawConfiguration = asRecord(configuration);
|
||||
const rawEffectiveConfiguration = asRecord(effectiveConfiguration);
|
||||
setMode(preferredMode(rawConfiguration, rawEffectiveConfiguration));
|
||||
setDograh(buildDograhState(defaults, rawConfiguration, rawEffectiveConfiguration));
|
||||
setByokInitialConfig(getByokInitialConfig(rawConfiguration, rawEffectiveConfiguration));
|
||||
}, [configuration, defaults, effectiveConfiguration]);
|
||||
|
||||
const saveDograhConfiguration = async () => {
|
||||
setIsSavingDograh(true);
|
||||
setError(null);
|
||||
try {
|
||||
await onSave({
|
||||
version: 2,
|
||||
mode: "dograh",
|
||||
dograh: {
|
||||
api_key: dograh.api_key.trim(),
|
||||
voice: dograh.voice,
|
||||
speed: dograh.speed,
|
||||
language: dograh.language,
|
||||
},
|
||||
});
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err.message : "Failed to save configuration");
|
||||
} finally {
|
||||
setIsSavingDograh(false);
|
||||
}
|
||||
};
|
||||
|
||||
const saveByokConfiguration = async (config: Record<string, unknown>) => {
|
||||
setError(null);
|
||||
const isRealtime = Boolean(config.is_realtime);
|
||||
const llm = requireByokService(config, "llm", defaultsForByok);
|
||||
const embeddings = optionalByokService(config, "embeddings");
|
||||
const body: OrganizationAiModelConfigurationV2 = {
|
||||
version: 2,
|
||||
mode: "byok",
|
||||
byok: isRealtime
|
||||
? {
|
||||
mode: "realtime",
|
||||
realtime: {
|
||||
realtime: requireByokService(config, "realtime", defaultsForByok) as never,
|
||||
llm: llm as never,
|
||||
...(embeddings ? { embeddings: embeddings as never } : {}),
|
||||
},
|
||||
}
|
||||
: {
|
||||
mode: "pipeline",
|
||||
pipeline: {
|
||||
llm: llm as never,
|
||||
tts: requireByokService(config, "tts", defaultsForByok) as never,
|
||||
stt: requireByokService(config, "stt", defaultsForByok) as never,
|
||||
...(embeddings ? { embeddings: embeddings as never } : {}),
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
await onSave(body);
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="space-y-6">
|
||||
{error && (
|
||||
<div className="rounded-md border border-destructive/40 bg-destructive/10 px-4 py-3 text-sm text-destructive">
|
||||
{error}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<Tabs value={mode} onValueChange={(value) => setMode(value as ModelMode)} className="space-y-6">
|
||||
<TabsList className="grid w-full grid-cols-2">
|
||||
<TabsTrigger value="dograh">Dograh</TabsTrigger>
|
||||
<TabsTrigger value="byok">BYOK</TabsTrigger>
|
||||
</TabsList>
|
||||
|
||||
<TabsContent value="dograh" className="mt-0">
|
||||
<div className="rounded-lg border p-5">
|
||||
<div className="grid gap-4 sm:grid-cols-2">
|
||||
<div className="space-y-2 sm:col-span-2">
|
||||
<Label htmlFor="dograh-api-key">API Key</Label>
|
||||
<div className="relative">
|
||||
<KeyRound className="pointer-events-none absolute left-3 top-1/2 h-4 w-4 -translate-y-1/2 text-muted-foreground" />
|
||||
<Input
|
||||
id="dograh-api-key"
|
||||
className="pl-9"
|
||||
value={dograh.api_key}
|
||||
onChange={(event) => setDograh({ ...dograh, api_key: event.target.value })}
|
||||
placeholder="Enter API key"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="space-y-2">
|
||||
<Label>Voice</Label>
|
||||
<Select value={dograh.voice} onValueChange={(voice) => setDograh({ ...dograh, voice })}>
|
||||
<SelectTrigger className="w-full">
|
||||
<SelectValue placeholder="Select voice" />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
{defaults.dograh.voices.map((voice) => (
|
||||
<SelectItem key={voice} value={voice}>
|
||||
{voice}
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</div>
|
||||
|
||||
<div className="space-y-2">
|
||||
<Label>Speed</Label>
|
||||
<Select
|
||||
value={String(dograh.speed)}
|
||||
onValueChange={(speed) => setDograh({ ...dograh, speed: Number(speed) })}
|
||||
>
|
||||
<SelectTrigger className="w-full">
|
||||
<SelectValue placeholder="Select speed" />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
{defaults.dograh.speeds.map((speed) => (
|
||||
<SelectItem key={speed} value={String(speed)}>
|
||||
{speed}x
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</div>
|
||||
|
||||
<div className="space-y-2 sm:col-span-2">
|
||||
<Label>Language</Label>
|
||||
<Select value={dograh.language} onValueChange={(language) => setDograh({ ...dograh, language })}>
|
||||
<SelectTrigger className="w-full">
|
||||
<SelectValue placeholder="Select language" />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
{defaults.dograh.languages.map((language) => (
|
||||
<SelectItem key={language} value={language}>
|
||||
{LANGUAGE_DISPLAY_NAMES[language] || language}
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<Button type="button" className="mt-6 w-full" onClick={saveDograhConfiguration} disabled={isSavingDograh}>
|
||||
<Save className="mr-2 h-4 w-4" />
|
||||
{isSavingDograh ? "Saving..." : submitLabel}
|
||||
</Button>
|
||||
</div>
|
||||
</TabsContent>
|
||||
|
||||
<TabsContent value="byok" className="mt-0">
|
||||
<ServiceConfigurationForm
|
||||
key={JSON.stringify(byokInitialConfig)}
|
||||
mode="global"
|
||||
configurationDefaults={defaultsForByok}
|
||||
initialConfig={byokInitialConfig}
|
||||
submitLabel={submitLabel}
|
||||
onSave={saveByokConfiguration}
|
||||
/>
|
||||
</TabsContent>
|
||||
</Tabs>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
400
ui/src/components/ModelConfigurationV2.tsx
Normal file
400
ui/src/components/ModelConfigurationV2.tsx
Normal file
|
|
@ -0,0 +1,400 @@
|
|||
"use client";
|
||||
|
||||
import { ExternalLink, RefreshCw, Save } from "lucide-react";
|
||||
import { useEffect, useId, useRef, useState } from "react";
|
||||
import TimezoneSelect, { type ITimezoneOption } from "react-timezone-select";
|
||||
|
||||
import {
|
||||
getModelConfigurationPreferencesApiV1OrganizationsModelConfigurationsPreferencesGet,
|
||||
getModelConfigurationV2ApiV1OrganizationsModelConfigurationsV2Get,
|
||||
getModelConfigurationV2DefaultsApiV1OrganizationsModelConfigurationsV2DefaultsGet,
|
||||
migrateModelConfigurationV2ApiV1OrganizationsModelConfigurationsV2MigratePost,
|
||||
saveModelConfigurationPreferencesApiV1OrganizationsModelConfigurationsPreferencesPut,
|
||||
saveModelConfigurationV2ApiV1OrganizationsModelConfigurationsV2Put,
|
||||
} from "@/client/sdk.gen";
|
||||
import type {
|
||||
OrganizationAiModelConfigurationPreferences,
|
||||
OrganizationAiModelConfigurationResponse,
|
||||
OrganizationAiModelConfigurationV2,
|
||||
} from "@/client/types.gen";
|
||||
import { AIModelConfigurationV2Editor, type ModelConfigurationDefaultsV2 } from "@/components/AIModelConfigurationV2Editor";
|
||||
import { ServiceConfigurationForm } from "@/components/ServiceConfigurationForm";
|
||||
import {
|
||||
AlertDialog,
|
||||
AlertDialogCancel,
|
||||
AlertDialogContent,
|
||||
AlertDialogDescription,
|
||||
AlertDialogFooter,
|
||||
AlertDialogHeader,
|
||||
AlertDialogTitle,
|
||||
} from "@/components/ui/alert-dialog";
|
||||
import { Badge } from "@/components/ui/badge";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { Label } from "@/components/ui/label";
|
||||
import { Skeleton } from "@/components/ui/skeleton";
|
||||
import { useUserConfig } from "@/context/UserConfigContext";
|
||||
import { detailFromError } from "@/lib/apiError";
|
||||
import { useAuth } from "@/lib/auth";
|
||||
|
||||
const emptyPreferences: OrganizationAiModelConfigurationPreferences = {
|
||||
test_phone_number: "",
|
||||
timezone: Intl.DateTimeFormat().resolvedOptions().timeZone || "UTC",
|
||||
};
|
||||
|
||||
const timezoneSelectStyles = {
|
||||
control: (base: Record<string, unknown>, state: { isFocused: boolean }) => ({
|
||||
...base,
|
||||
minHeight: "36px",
|
||||
fontSize: "14px",
|
||||
backgroundColor: "var(--background)",
|
||||
borderColor: state.isFocused ? "var(--ring)" : "var(--border)",
|
||||
boxShadow: state.isFocused ? "0 0 0 2px color-mix(in srgb, var(--ring) 20%, transparent)" : "none",
|
||||
"&:hover": { borderColor: "var(--border)" },
|
||||
}),
|
||||
menu: (base: Record<string, unknown>) => ({
|
||||
...base,
|
||||
zIndex: 9999,
|
||||
backgroundColor: "var(--popover)",
|
||||
border: "1px solid var(--border)",
|
||||
boxShadow: "0 4px 6px -1px rgb(0 0 0 / 0.1), 0 2px 4px -2px rgb(0 0 0 / 0.1)",
|
||||
}),
|
||||
menuList: (base: Record<string, unknown>) => ({
|
||||
...base,
|
||||
backgroundColor: "var(--popover)",
|
||||
padding: 0,
|
||||
}),
|
||||
option: (base: Record<string, unknown>, state: { isSelected: boolean; isFocused: boolean }) => ({
|
||||
...base,
|
||||
backgroundColor: state.isSelected ? "var(--accent)" : state.isFocused ? "var(--accent)" : "var(--popover)",
|
||||
color: "var(--foreground)",
|
||||
cursor: "pointer",
|
||||
"&:active": { backgroundColor: "var(--accent)" },
|
||||
}),
|
||||
singleValue: (base: Record<string, unknown>) => ({ ...base, color: "var(--foreground)" }),
|
||||
input: (base: Record<string, unknown>) => ({ ...base, color: "var(--foreground)" }),
|
||||
placeholder: (base: Record<string, unknown>) => ({ ...base, color: "var(--muted-foreground)" }),
|
||||
indicatorSeparator: (base: Record<string, unknown>) => ({ ...base, backgroundColor: "var(--border)" }),
|
||||
dropdownIndicator: (base: Record<string, unknown>) => ({
|
||||
...base,
|
||||
color: "var(--muted-foreground)",
|
||||
"&:hover": { color: "var(--foreground)" },
|
||||
}),
|
||||
};
|
||||
|
||||
function getTimezoneValue(tz: ITimezoneOption | string): string {
|
||||
return typeof tz === "string" ? tz : tz.value;
|
||||
}
|
||||
|
||||
export default function ModelConfigurationV2({
|
||||
docsUrl,
|
||||
initialAction,
|
||||
}: {
|
||||
docsUrl?: string;
|
||||
initialAction?: string;
|
||||
}) {
|
||||
const auth = useAuth();
|
||||
const { refreshConfig, saveUserConfig } = useUserConfig();
|
||||
const timezoneSelectId = useId();
|
||||
const hasFetched = useRef(false);
|
||||
const hasAppliedInitialMigrationAction = useRef(false);
|
||||
|
||||
const [defaults, setDefaults] = useState<ModelConfigurationDefaultsV2 | null>(null);
|
||||
const [response, setResponse] = useState<OrganizationAiModelConfigurationResponse | null>(null);
|
||||
const [preferences, setPreferences] = useState<OrganizationAiModelConfigurationPreferences>(emptyPreferences);
|
||||
const [timezone, setTimezone] = useState<ITimezoneOption | string>(emptyPreferences.timezone || "UTC");
|
||||
const [loading, setLoading] = useState(true);
|
||||
const [savingPreferences, setSavingPreferences] = useState(false);
|
||||
const [migrating, setMigrating] = useState(false);
|
||||
const [migrationDialogOpen, setMigrationDialogOpen] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const [notice, setNotice] = useState<string | null>(null);
|
||||
|
||||
const applyResponse = (nextResponse: OrganizationAiModelConfigurationResponse) => {
|
||||
setResponse(nextResponse);
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
if (auth.loading || !auth.user || hasFetched.current) return;
|
||||
hasFetched.current = true;
|
||||
|
||||
const load = async () => {
|
||||
setLoading(true);
|
||||
setError(null);
|
||||
const [defaultsResult, configResult, preferencesResult] = await Promise.all([
|
||||
getModelConfigurationV2DefaultsApiV1OrganizationsModelConfigurationsV2DefaultsGet(),
|
||||
getModelConfigurationV2ApiV1OrganizationsModelConfigurationsV2Get(),
|
||||
getModelConfigurationPreferencesApiV1OrganizationsModelConfigurationsPreferencesGet(),
|
||||
]);
|
||||
|
||||
if (defaultsResult.error) {
|
||||
setError(detailFromError(defaultsResult.error, "Failed to load model configuration defaults"));
|
||||
setLoading(false);
|
||||
return;
|
||||
}
|
||||
if (configResult.error) {
|
||||
setError(detailFromError(configResult.error, "Failed to load model configuration"));
|
||||
setLoading(false);
|
||||
return;
|
||||
}
|
||||
if (preferencesResult.error) {
|
||||
setError(detailFromError(preferencesResult.error, "Failed to load model configuration preferences"));
|
||||
setLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
const nextDefaults = defaultsResult.data as ModelConfigurationDefaultsV2;
|
||||
if (!nextDefaults || !configResult.data) {
|
||||
setError("Failed to load model configuration");
|
||||
setLoading(false);
|
||||
return;
|
||||
}
|
||||
setDefaults(nextDefaults);
|
||||
applyResponse(configResult.data);
|
||||
|
||||
const nextPreferences = preferencesResult.data || emptyPreferences;
|
||||
setPreferences({
|
||||
test_phone_number: nextPreferences.test_phone_number || "",
|
||||
timezone: nextPreferences.timezone || emptyPreferences.timezone,
|
||||
});
|
||||
setTimezone(nextPreferences.timezone || emptyPreferences.timezone || "UTC");
|
||||
setLoading(false);
|
||||
};
|
||||
|
||||
load();
|
||||
|
||||
}, [auth.loading, auth.user]);
|
||||
|
||||
useEffect(() => {
|
||||
if (hasAppliedInitialMigrationAction.current) return;
|
||||
if (initialAction !== "migrate_to_v2") return;
|
||||
if (loading || response?.source !== "legacy_user_v1") return;
|
||||
hasAppliedInitialMigrationAction.current = true;
|
||||
setMigrationDialogOpen(true);
|
||||
}, [initialAction, loading, response?.source]);
|
||||
|
||||
const saveConfiguration = async (configuration: OrganizationAiModelConfigurationV2) => {
|
||||
if (!defaults) return;
|
||||
setError(null);
|
||||
setNotice(null);
|
||||
|
||||
const result = await saveModelConfigurationV2ApiV1OrganizationsModelConfigurationsV2Put({
|
||||
body: configuration,
|
||||
});
|
||||
|
||||
if (result.error) {
|
||||
throw new Error(detailFromError(result.error, "Failed to save model configuration"));
|
||||
}
|
||||
if (!result.data) {
|
||||
throw new Error("Failed to save model configuration");
|
||||
}
|
||||
|
||||
applyResponse(result.data);
|
||||
await refreshConfig();
|
||||
setNotice("Model configuration saved");
|
||||
};
|
||||
|
||||
const savePreferences = async () => {
|
||||
setSavingPreferences(true);
|
||||
setError(null);
|
||||
setNotice(null);
|
||||
|
||||
const result = await saveModelConfigurationPreferencesApiV1OrganizationsModelConfigurationsPreferencesPut({
|
||||
body: {
|
||||
test_phone_number: preferences.test_phone_number || null,
|
||||
timezone: getTimezoneValue(timezone),
|
||||
},
|
||||
});
|
||||
|
||||
if (result.error) {
|
||||
setError(detailFromError(result.error, "Failed to save preferences"));
|
||||
} else if (!result.data) {
|
||||
setError("Failed to save preferences");
|
||||
} else {
|
||||
setPreferences(result.data);
|
||||
await refreshConfig();
|
||||
setNotice("Preferences saved");
|
||||
}
|
||||
setSavingPreferences(false);
|
||||
};
|
||||
|
||||
const migrateConfiguration = async () => {
|
||||
if (!defaults) return;
|
||||
setMigrating(true);
|
||||
setError(null);
|
||||
setNotice(null);
|
||||
|
||||
const result = await migrateModelConfigurationV2ApiV1OrganizationsModelConfigurationsV2MigratePost();
|
||||
if (result.error) {
|
||||
setError(detailFromError(result.error, "Failed to migrate model configuration"));
|
||||
} else if (!result.data) {
|
||||
setError("Failed to migrate model configuration");
|
||||
} else {
|
||||
applyResponse(result.data);
|
||||
await refreshConfig();
|
||||
setNotice("Configuration migrated to v2");
|
||||
setMigrationDialogOpen(false);
|
||||
}
|
||||
setMigrating(false);
|
||||
};
|
||||
|
||||
const migrationWarningDialog = (
|
||||
<AlertDialog open={migrationDialogOpen} onOpenChange={setMigrationDialogOpen}>
|
||||
<AlertDialogContent>
|
||||
<AlertDialogHeader>
|
||||
<AlertDialogTitle>Migrate model configuration to v2?</AlertDialogTitle>
|
||||
<AlertDialogDescription>
|
||||
Your configurations will be migrated to v2. After migration, check your global configuration and workflow model overrides, then run a test call to make sure everything is working.
|
||||
</AlertDialogDescription>
|
||||
</AlertDialogHeader>
|
||||
<AlertDialogFooter>
|
||||
<AlertDialogCancel disabled={migrating}>Cancel</AlertDialogCancel>
|
||||
<Button type="button" onClick={migrateConfiguration} disabled={migrating}>
|
||||
{migrating ? "Migrating..." : "Migrate to v2"}
|
||||
</Button>
|
||||
</AlertDialogFooter>
|
||||
</AlertDialogContent>
|
||||
</AlertDialog>
|
||||
);
|
||||
|
||||
if (loading) {
|
||||
return (
|
||||
<div className="w-full max-w-4xl mx-auto space-y-6">
|
||||
<Skeleton className="h-10 w-80" />
|
||||
<Skeleton className="h-28 w-full" />
|
||||
<Skeleton className="h-96 w-full" />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
const source = response?.source || "empty";
|
||||
|
||||
if (source !== "organization_v2") {
|
||||
return (
|
||||
<div className="w-full max-w-4xl mx-auto space-y-6">
|
||||
<div className="flex flex-col gap-3 sm:flex-row sm:items-start sm:justify-between">
|
||||
<div>
|
||||
<div className="flex items-center gap-2">
|
||||
<h1 className="text-3xl font-bold">AI Models Configuration</h1>
|
||||
<Badge variant="outline">
|
||||
{source === "legacy_user_v1" ? "legacy" : "v1"}
|
||||
</Badge>
|
||||
</div>
|
||||
<p className="mt-2 text-sm text-muted-foreground">
|
||||
Configure your AI model, voice, and transcription services.{" "}
|
||||
{docsUrl && (
|
||||
<a href={docsUrl} target="_blank" rel="noopener noreferrer" className="inline-flex items-center gap-0.5 underline">
|
||||
Learn more <ExternalLink className="h-3 w-3" />
|
||||
</a>
|
||||
)}
|
||||
</p>
|
||||
</div>
|
||||
{source === "legacy_user_v1" && (
|
||||
<Button type="button" variant="outline" onClick={() => setMigrationDialogOpen(true)} disabled={migrating}>
|
||||
<RefreshCw className="mr-2 h-4 w-4" />
|
||||
{migrating ? "Migrating..." : "Migrate to v2"}
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{error && (
|
||||
<div className="rounded-md border border-destructive/40 bg-destructive/10 px-4 py-3 text-sm text-destructive">
|
||||
{error}
|
||||
</div>
|
||||
)}
|
||||
{notice && (
|
||||
<div className="rounded-md border border-green-500/40 bg-green-500/10 px-4 py-3 text-sm text-green-700 dark:text-green-300">
|
||||
{notice}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<ServiceConfigurationForm
|
||||
mode="global"
|
||||
onSave={async (config) => {
|
||||
setError(null);
|
||||
setNotice(null);
|
||||
await saveUserConfig(config as Parameters<typeof saveUserConfig>[0]);
|
||||
await refreshConfig();
|
||||
if (defaults) {
|
||||
const configResult = await getModelConfigurationV2ApiV1OrganizationsModelConfigurationsV2Get();
|
||||
if (configResult.data) {
|
||||
applyResponse(configResult.data);
|
||||
}
|
||||
}
|
||||
setNotice("Configuration saved");
|
||||
}}
|
||||
/>
|
||||
{migrationWarningDialog}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="w-full max-w-4xl mx-auto space-y-6">
|
||||
<div className="flex flex-col gap-3 sm:flex-row sm:items-start sm:justify-between">
|
||||
<div>
|
||||
<h1 className="text-3xl font-bold">AI Models Configuration</h1>
|
||||
<p className="mt-2 text-sm text-muted-foreground">
|
||||
Organization-scoped model settings.{" "}
|
||||
{docsUrl && (
|
||||
<a href={docsUrl} target="_blank" rel="noopener noreferrer" className="inline-flex items-center gap-0.5 underline">
|
||||
Learn more <ExternalLink className="h-3 w-3" />
|
||||
</a>
|
||||
)}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{error && (
|
||||
<div className="rounded-md border border-destructive/40 bg-destructive/10 px-4 py-3 text-sm text-destructive">
|
||||
{error}
|
||||
</div>
|
||||
)}
|
||||
{notice && (
|
||||
<div className="rounded-md border border-green-500/40 bg-green-500/10 px-4 py-3 text-sm text-green-700 dark:text-green-300">
|
||||
{notice}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{defaults && response && (
|
||||
<AIModelConfigurationV2Editor
|
||||
defaults={defaults}
|
||||
configuration={response.configuration}
|
||||
effectiveConfiguration={response.effective_configuration}
|
||||
onSave={saveConfiguration}
|
||||
/>
|
||||
)}
|
||||
|
||||
<div className="rounded-lg border p-5">
|
||||
<div className="mb-4">
|
||||
<h2 className="text-base font-semibold">Preferences</h2>
|
||||
</div>
|
||||
<div className="grid gap-4 sm:grid-cols-2">
|
||||
<div className="space-y-2">
|
||||
<Label htmlFor="test-phone-number">Test Phone Number</Label>
|
||||
<Input
|
||||
id="test-phone-number"
|
||||
value={preferences.test_phone_number || ""}
|
||||
onChange={(event) => setPreferences({ ...preferences, test_phone_number: event.target.value })}
|
||||
placeholder="+15551234567"
|
||||
/>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<Label>Timezone</Label>
|
||||
<TimezoneSelect
|
||||
instanceId={timezoneSelectId}
|
||||
value={timezone}
|
||||
onChange={setTimezone}
|
||||
styles={timezoneSelectStyles}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<Button type="button" variant="outline" className="mt-5" onClick={savePreferences} disabled={savingPreferences}>
|
||||
<Save className="mr-2 h-4 w-4" />
|
||||
{savingPreferences ? "Saving..." : "Save Preferences"}
|
||||
</Button>
|
||||
</div>
|
||||
{migrationWarningDialog}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
|
@ -19,7 +19,7 @@ import { LANGUAGE_DISPLAY_NAMES } from "@/constants/languages";
|
|||
import { useUserConfig } from "@/context/UserConfigContext";
|
||||
import type { ModelOverrides } from "@/types/workflow-configurations";
|
||||
|
||||
type ServiceSegment = "llm" | "tts" | "stt" | "embeddings" | "realtime";
|
||||
export type ServiceSegment = "llm" | "tts" | "stt" | "embeddings" | "realtime";
|
||||
|
||||
interface SchemaProperty {
|
||||
type?: string;
|
||||
|
|
@ -35,7 +35,7 @@ interface SchemaProperty {
|
|||
docs_url?: string;
|
||||
}
|
||||
|
||||
interface ProviderSchema {
|
||||
export interface ProviderSchema {
|
||||
title?: string;
|
||||
description?: string;
|
||||
provider_docs_url?: string;
|
||||
|
|
@ -49,6 +49,15 @@ interface FormValues {
|
|||
[key: string]: string | number | boolean;
|
||||
}
|
||||
|
||||
export interface ServiceConfigurationDefaults {
|
||||
llm: Record<string, ProviderSchema>;
|
||||
tts: Record<string, ProviderSchema>;
|
||||
stt: Record<string, ProviderSchema>;
|
||||
embeddings: Record<string, ProviderSchema>;
|
||||
realtime?: Record<string, ProviderSchema>;
|
||||
default_providers: Partial<Record<ServiceSegment, string>>;
|
||||
}
|
||||
|
||||
const STANDARD_TABS: { key: ServiceSegment; label: string }[] = [
|
||||
{ key: "llm", label: "LLM" },
|
||||
{ key: "tts", label: "Voice" },
|
||||
|
|
@ -90,6 +99,8 @@ export interface ServiceConfigurationFormProps {
|
|||
onSave: (config: Record<string, unknown>) => Promise<void>;
|
||||
/** Text for the submit button. Defaults to "Save Configuration". */
|
||||
submitLabel?: string;
|
||||
configurationDefaults?: ServiceConfigurationDefaults | null;
|
||||
initialConfig?: Record<string, unknown> | null;
|
||||
}
|
||||
|
||||
function getProviderDisplayName(
|
||||
|
|
@ -117,6 +128,8 @@ export function ServiceConfigurationForm({
|
|||
currentOverrides,
|
||||
onSave,
|
||||
submitLabel,
|
||||
configurationDefaults,
|
||||
initialConfig,
|
||||
}: ServiceConfigurationFormProps) {
|
||||
const [apiError, setApiError] = useState<string | null>(null);
|
||||
const [isSaving, setIsSaving] = useState(false);
|
||||
|
|
@ -165,15 +178,16 @@ export function ServiceConfigurationForm({
|
|||
|
||||
// Build effective config source: overlay overrides onto global config
|
||||
const configSource = useMemo(() => {
|
||||
if (mode === 'global' || !currentOverrides) return userConfig;
|
||||
const baseConfig = initialConfig ?? userConfig;
|
||||
if (mode === 'global' || !currentOverrides) return baseConfig;
|
||||
// Merge overrides onto global config for form initialization
|
||||
const merged = { ...userConfig } as Record<string, unknown>;
|
||||
const merged = { ...baseConfig } as Record<string, unknown>;
|
||||
const overrideServices: (keyof ModelOverrides)[] = ["llm", "tts", "stt", "realtime"];
|
||||
for (const svc of overrideServices) {
|
||||
if (svc === "is_realtime") continue;
|
||||
const overrideVal = currentOverrides[svc];
|
||||
if (overrideVal && typeof overrideVal === "object") {
|
||||
const globalVal = (userConfig as Record<string, unknown> | null)?.[svc] as Record<string, unknown> | undefined;
|
||||
const globalVal = (baseConfig as Record<string, unknown> | null)?.[svc] as Record<string, unknown> | undefined;
|
||||
merged[svc] = { ...globalVal, ...overrideVal };
|
||||
}
|
||||
}
|
||||
|
|
@ -181,24 +195,35 @@ export function ServiceConfigurationForm({
|
|||
merged.is_realtime = currentOverrides.is_realtime;
|
||||
}
|
||||
return merged as typeof userConfig;
|
||||
}, [mode, userConfig, currentOverrides]);
|
||||
}, [mode, userConfig, currentOverrides, initialConfig]);
|
||||
|
||||
useEffect(() => {
|
||||
const fetchConfigurations = async () => {
|
||||
const response = await getDefaultConfigurationsApiV1UserConfigurationsDefaultsGet();
|
||||
if (!response.data) {
|
||||
console.error("Failed to fetch configurations");
|
||||
return;
|
||||
let defaultsData = configurationDefaults;
|
||||
if (!defaultsData) {
|
||||
const response = await getDefaultConfigurationsApiV1UserConfigurationsDefaultsGet();
|
||||
if (!response.data) {
|
||||
console.error("Failed to fetch configurations");
|
||||
return;
|
||||
}
|
||||
defaultsData = response.data as ServiceConfigurationDefaults;
|
||||
}
|
||||
|
||||
const data = response.data as Record<string, unknown>;
|
||||
const realtimeSchemas = (data.realtime || {}) as Record<string, ProviderSchema>;
|
||||
const realtimeSchemas = (defaultsData.realtime || {}) as Record<string, ProviderSchema>;
|
||||
const pickDefaultProvider = (
|
||||
service: ServiceSegment,
|
||||
schemaMap: Record<string, ProviderSchema>,
|
||||
) => {
|
||||
const preferred = defaultsData.default_providers?.[service];
|
||||
if (preferred && schemaMap[preferred]) return preferred;
|
||||
return Object.keys(schemaMap)[0] || "";
|
||||
};
|
||||
|
||||
setSchemas({
|
||||
llm: response.data.llm as Record<string, ProviderSchema>,
|
||||
tts: response.data.tts as Record<string, ProviderSchema>,
|
||||
stt: response.data.stt as Record<string, ProviderSchema>,
|
||||
embeddings: response.data.embeddings as Record<string, ProviderSchema>,
|
||||
llm: defaultsData.llm,
|
||||
tts: defaultsData.tts,
|
||||
stt: defaultsData.stt,
|
||||
embeddings: defaultsData.embeddings,
|
||||
realtime: realtimeSchemas,
|
||||
});
|
||||
|
||||
|
|
@ -210,10 +235,10 @@ export function ServiceConfigurationForm({
|
|||
|
||||
const defaultValues: Record<string, string | number | boolean> = {};
|
||||
const selectedProviders: Record<ServiceSegment, string> = {
|
||||
llm: response.data.default_providers.llm,
|
||||
tts: response.data.default_providers.tts,
|
||||
stt: response.data.default_providers.stt,
|
||||
embeddings: response.data.default_providers.embeddings,
|
||||
llm: pickDefaultProvider("llm", defaultsData.llm),
|
||||
tts: pickDefaultProvider("tts", defaultsData.tts),
|
||||
stt: pickDefaultProvider("stt", defaultsData.stt),
|
||||
embeddings: pickDefaultProvider("embeddings", defaultsData.embeddings),
|
||||
realtime: "",
|
||||
};
|
||||
|
||||
|
|
@ -237,7 +262,7 @@ export function ServiceConfigurationForm({
|
|||
|
||||
const schemaSource = service === "realtime"
|
||||
? realtimeSchemas
|
||||
: response.data![service as "llm" | "tts" | "stt" | "embeddings"] as Record<string, ProviderSchema> | undefined;
|
||||
: defaultsData[service as "llm" | "tts" | "stt" | "embeddings"] as Record<string, ProviderSchema> | undefined;
|
||||
|
||||
if (src?.provider) {
|
||||
Object.entries(src).forEach(([field, value]) => {
|
||||
|
|
@ -296,7 +321,7 @@ export function ServiceConfigurationForm({
|
|||
|
||||
// Detect custom inputs
|
||||
const detectedCustomInput: Record<string, boolean> = {};
|
||||
const allSchemas = { ...response.data, realtime: realtimeSchemas } as unknown as Record<string, Record<string, ProviderSchema>>;
|
||||
const allSchemas = { ...defaultsData, realtime: realtimeSchemas } as unknown as Record<string, Record<string, ProviderSchema>>;
|
||||
(["llm", "tts", "stt", "embeddings", "realtime"] as ServiceSegment[]).forEach(service => {
|
||||
const provider = selectedProviders[service];
|
||||
const providerSchema = allSchemas[service]?.[provider];
|
||||
|
|
@ -337,7 +362,7 @@ export function ServiceConfigurationForm({
|
|||
};
|
||||
fetchConfigurations();
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [reset, configSource]);
|
||||
}, [reset, configSource, configurationDefaults]);
|
||||
|
||||
// Reset voice when TTS model changes if the provider has model-dependent voice options
|
||||
const ttsModel = watch("tts_model");
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
import type { OrganizationAiModelConfigurationV2 } from "@/client/types.gen";
|
||||
|
||||
export interface AmbientNoiseConfiguration {
|
||||
enabled: boolean;
|
||||
volume: number;
|
||||
|
|
@ -64,6 +66,7 @@ export interface WorkflowConfigurations {
|
|||
voicemail_detection?: VoicemailDetectionConfiguration;
|
||||
context_compaction_enabled?: boolean; // Summarize context on node transitions to remove stale tool calls
|
||||
model_overrides?: ModelOverrides; // Per-workflow model configuration overrides
|
||||
model_configuration_v2_override?: OrganizationAiModelConfigurationV2; // Full v2 model configuration override
|
||||
[key: string]: unknown; // Allow additional properties for future configurations
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue