mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-22 08:38:13 +02:00
Merge remote-tracking branch 'origin/main' into feat/user-onboarding
This commit is contained in:
commit
093e888ce4
148 changed files with 10908 additions and 2815 deletions
484
api/services/configuration/ai_model_configuration.py
Normal file
484
api/services/configuration/ai_model_configuration.py
Normal file
|
|
@ -0,0 +1,484 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from api.constants import MPS_API_URL
|
||||
from api.db import db_client
|
||||
from api.db.models import WorkflowDefinitionModel, WorkflowModel
|
||||
from api.enums import OrganizationConfigurationKey
|
||||
from api.schemas.ai_model_configuration import (
|
||||
DOGRAH_DEFAULT_LANGUAGE,
|
||||
DOGRAH_DEFAULT_VOICE,
|
||||
DOGRAH_SPEED_OPTIONS,
|
||||
BYOKAIModelConfiguration,
|
||||
BYOKPipelineAIModelConfiguration,
|
||||
BYOKRealtimeAIModelConfiguration,
|
||||
DograhManagedAIModelConfiguration,
|
||||
EffectiveAIModelConfiguration,
|
||||
OrganizationAIModelConfigurationV2,
|
||||
compile_ai_model_configuration_v2,
|
||||
)
|
||||
from api.services.configuration.masking import (
|
||||
SERVICE_SECRET_FIELDS,
|
||||
contains_masked_key,
|
||||
mask_key,
|
||||
resolve_masked_api_keys,
|
||||
)
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from api.services.configuration.resolve import resolve_effective_config
|
||||
|
||||
AIModelConfigurationSource = Literal["organization_v2", "legacy_user_v1", "empty"]
|
||||
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY = "model_configuration_v2_override"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResolvedAIModelConfiguration:
|
||||
effective: EffectiveAIModelConfiguration
|
||||
source: AIModelConfigurationSource
|
||||
organization_configuration: OrganizationAIModelConfigurationV2 | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class WorkflowAIModelConfigurationMigrationResult:
|
||||
workflow_count: int = 0
|
||||
definition_count: int = 0
|
||||
workflow_ids: list[int] | None = None
|
||||
|
||||
|
||||
async def get_resolved_ai_model_configuration(
|
||||
*,
|
||||
user_id: int | None,
|
||||
organization_id: int | None,
|
||||
) -> ResolvedAIModelConfiguration:
|
||||
organization_configuration = await get_organization_ai_model_configuration_v2(
|
||||
organization_id
|
||||
)
|
||||
if organization_configuration is not None:
|
||||
return ResolvedAIModelConfiguration(
|
||||
effective=compile_ai_model_configuration_v2(organization_configuration),
|
||||
source="organization_v2",
|
||||
organization_configuration=organization_configuration,
|
||||
)
|
||||
|
||||
if user_id is None:
|
||||
return ResolvedAIModelConfiguration(
|
||||
effective=EffectiveAIModelConfiguration(),
|
||||
source="empty",
|
||||
)
|
||||
|
||||
legacy = await db_client.get_user_configurations(user_id)
|
||||
return ResolvedAIModelConfiguration(
|
||||
effective=legacy,
|
||||
source="legacy_user_v1" if _has_model_services(legacy) else "empty",
|
||||
)
|
||||
|
||||
|
||||
async def get_effective_ai_model_configuration_for_workflow(
|
||||
*,
|
||||
user_id: int | None,
|
||||
organization_id: int | None,
|
||||
workflow_configurations: dict | None,
|
||||
) -> EffectiveAIModelConfiguration:
|
||||
workflow_configurations = workflow_configurations or {}
|
||||
v2_override = workflow_configurations.get(
|
||||
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY
|
||||
)
|
||||
if v2_override:
|
||||
return compile_ai_model_configuration_v2(
|
||||
OrganizationAIModelConfigurationV2.model_validate(v2_override)
|
||||
)
|
||||
|
||||
resolved_config = await get_resolved_ai_model_configuration(
|
||||
user_id=user_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
return resolve_effective_config(
|
||||
resolved_config.effective,
|
||||
workflow_configurations.get("model_overrides"),
|
||||
)
|
||||
|
||||
|
||||
async def get_organization_ai_model_configuration_v2(
|
||||
organization_id: int | None,
|
||||
) -> OrganizationAIModelConfigurationV2 | None:
|
||||
if organization_id is None:
|
||||
return None
|
||||
row = await db_client.get_configuration(
|
||||
organization_id,
|
||||
OrganizationConfigurationKey.MODEL_CONFIGURATION_V2.value,
|
||||
)
|
||||
if row is None or not row.value:
|
||||
return None
|
||||
try:
|
||||
return OrganizationAIModelConfigurationV2.model_validate(row.value)
|
||||
except ValidationError as exc:
|
||||
logger.warning(
|
||||
"Invalid org AI model configuration v2 for organization "
|
||||
f"{organization_id}: {exc}. Falling back to legacy configuration."
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
async def upsert_organization_ai_model_configuration_v2(
|
||||
organization_id: int,
|
||||
configuration: OrganizationAIModelConfigurationV2,
|
||||
) -> OrganizationAIModelConfigurationV2:
|
||||
await db_client.upsert_configuration(
|
||||
organization_id,
|
||||
OrganizationConfigurationKey.MODEL_CONFIGURATION_V2.value,
|
||||
configuration.model_dump(mode="json", exclude_none=True),
|
||||
)
|
||||
return configuration
|
||||
|
||||
|
||||
async def migrate_workflow_model_configurations_to_v2(
|
||||
*,
|
||||
organization_id: int,
|
||||
fallback_user_config: EffectiveAIModelConfiguration,
|
||||
) -> WorkflowAIModelConfigurationMigrationResult:
|
||||
workflows = await _list_workflows_for_model_configuration_migration(organization_id)
|
||||
owner_configs: dict[int, EffectiveAIModelConfiguration] = {}
|
||||
workflow_updates: list[tuple[int, dict]] = []
|
||||
definition_updates: list[tuple[int, dict]] = []
|
||||
migrated_workflow_ids: set[int] = set()
|
||||
|
||||
for workflow in workflows:
|
||||
base_config = fallback_user_config
|
||||
if workflow.user_id is not None:
|
||||
if workflow.user_id not in owner_configs:
|
||||
owner_configs[
|
||||
workflow.user_id
|
||||
] = await db_client.get_user_configurations(workflow.user_id)
|
||||
base_config = owner_configs[workflow.user_id]
|
||||
|
||||
workflow_configs, workflow_changed = (
|
||||
migrate_workflow_configuration_model_override_to_v2(
|
||||
workflow.workflow_configurations,
|
||||
base_config,
|
||||
)
|
||||
)
|
||||
if workflow_changed:
|
||||
workflow_updates.append((workflow.id, workflow_configs))
|
||||
migrated_workflow_ids.add(workflow.id)
|
||||
|
||||
for definition in workflow.definitions:
|
||||
definition_configs, definition_changed = (
|
||||
migrate_workflow_configuration_model_override_to_v2(
|
||||
definition.workflow_configurations,
|
||||
base_config,
|
||||
)
|
||||
)
|
||||
if definition_changed:
|
||||
definition_updates.append((definition.id, definition_configs))
|
||||
migrated_workflow_ids.add(workflow.id)
|
||||
|
||||
if workflow_updates or definition_updates:
|
||||
async with db_client.async_session() as session:
|
||||
for workflow_id, workflow_configs in workflow_updates:
|
||||
await session.execute(
|
||||
update(WorkflowModel)
|
||||
.where(WorkflowModel.id == workflow_id)
|
||||
.values(workflow_configurations=workflow_configs)
|
||||
)
|
||||
for definition_id, definition_configs in definition_updates:
|
||||
await session.execute(
|
||||
update(WorkflowDefinitionModel)
|
||||
.where(WorkflowDefinitionModel.id == definition_id)
|
||||
.values(workflow_configurations=definition_configs)
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
return WorkflowAIModelConfigurationMigrationResult(
|
||||
workflow_count=len(migrated_workflow_ids),
|
||||
definition_count=len(definition_updates),
|
||||
workflow_ids=sorted(migrated_workflow_ids),
|
||||
)
|
||||
|
||||
|
||||
def migrate_workflow_configuration_model_override_to_v2(
|
||||
workflow_configurations: dict | None,
|
||||
base_config: EffectiveAIModelConfiguration,
|
||||
) -> tuple[dict, bool]:
|
||||
if not isinstance(workflow_configurations, dict):
|
||||
return {}, False
|
||||
|
||||
migrated = copy.deepcopy(workflow_configurations)
|
||||
model_overrides = migrated.get("model_overrides")
|
||||
existing_v2_override = migrated.get(WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY)
|
||||
if not isinstance(model_overrides, dict):
|
||||
if "model_overrides" in migrated:
|
||||
migrated.pop("model_overrides", None)
|
||||
return migrated, True
|
||||
return migrated, False
|
||||
|
||||
if not existing_v2_override:
|
||||
effective = resolve_effective_config(base_config, model_overrides)
|
||||
v2_override = convert_legacy_ai_model_configuration_to_v2(effective)
|
||||
migrated[WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY] = v2_override.model_dump(
|
||||
mode="json", exclude_none=True
|
||||
)
|
||||
migrated.pop("model_overrides", None)
|
||||
return migrated, True
|
||||
|
||||
|
||||
def merge_ai_model_configuration_v2_secrets(
|
||||
incoming: OrganizationAIModelConfigurationV2,
|
||||
existing: OrganizationAIModelConfigurationV2 | None,
|
||||
) -> OrganizationAIModelConfigurationV2:
|
||||
if existing is None:
|
||||
return incoming
|
||||
|
||||
incoming_dict = incoming.model_dump(mode="json", exclude_none=True)
|
||||
existing_dict = existing.model_dump(mode="json", exclude_none=True)
|
||||
|
||||
if incoming_dict.get("mode") == "dograh" and existing_dict.get("mode") == "dograh":
|
||||
incoming_dograh = incoming_dict.get("dograh") or {}
|
||||
existing_dograh = existing_dict.get("dograh") or {}
|
||||
incoming_key = incoming_dograh.get("api_key")
|
||||
existing_key = existing_dograh.get("api_key")
|
||||
if incoming_key and existing_key and contains_masked_key(incoming_key):
|
||||
incoming_dograh["api_key"] = resolve_masked_api_keys(
|
||||
incoming_key,
|
||||
existing_key,
|
||||
)
|
||||
|
||||
if incoming_dict.get("mode") == "byok" and existing_dict.get("mode") == "byok":
|
||||
_merge_byok_secret_fields(incoming_dict.get("byok"), existing_dict.get("byok"))
|
||||
|
||||
return OrganizationAIModelConfigurationV2.model_validate(incoming_dict)
|
||||
|
||||
|
||||
def check_for_masked_keys_in_ai_model_configuration_v2(
|
||||
configuration: OrganizationAIModelConfigurationV2,
|
||||
) -> None:
|
||||
data = configuration.model_dump(mode="json", exclude_none=True)
|
||||
_raise_if_masked_secret(data)
|
||||
|
||||
|
||||
def mask_ai_model_configuration_v2(
|
||||
configuration: OrganizationAIModelConfigurationV2 | None,
|
||||
) -> dict | None:
|
||||
if configuration is None:
|
||||
return None
|
||||
data = configuration.model_dump(mode="json", exclude_none=True)
|
||||
_mask_secret_fields(data)
|
||||
return data
|
||||
|
||||
|
||||
def convert_legacy_ai_model_configuration_to_v2(
|
||||
configuration: EffectiveAIModelConfiguration,
|
||||
) -> OrganizationAIModelConfigurationV2:
|
||||
dograh_key = _first_dograh_api_key(configuration)
|
||||
if dograh_key:
|
||||
return _convert_any_dograh_legacy_configuration(configuration, dograh_key)
|
||||
|
||||
if configuration.is_realtime:
|
||||
if configuration.realtime is None or configuration.llm is None:
|
||||
raise ValueError("Realtime legacy configuration is incomplete")
|
||||
return OrganizationAIModelConfigurationV2(
|
||||
mode="byok",
|
||||
byok=BYOKAIModelConfiguration(
|
||||
mode="realtime",
|
||||
realtime=BYOKRealtimeAIModelConfiguration(
|
||||
realtime=configuration.realtime,
|
||||
llm=configuration.llm,
|
||||
embeddings=configuration.embeddings,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
if (
|
||||
configuration.llm is None
|
||||
or configuration.tts is None
|
||||
or configuration.stt is None
|
||||
):
|
||||
raise ValueError("Pipeline legacy configuration is incomplete")
|
||||
return OrganizationAIModelConfigurationV2(
|
||||
mode="byok",
|
||||
byok=BYOKAIModelConfiguration(
|
||||
mode="pipeline",
|
||||
pipeline=BYOKPipelineAIModelConfiguration(
|
||||
llm=configuration.llm,
|
||||
tts=configuration.tts,
|
||||
stt=configuration.stt,
|
||||
embeddings=configuration.embeddings,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def dograh_embeddings_base_url() -> str:
|
||||
return f"{MPS_API_URL}/api/v1/llm"
|
||||
|
||||
|
||||
def apply_managed_embeddings_base_url(
|
||||
*,
|
||||
provider: str | None,
|
||||
base_url: str | None,
|
||||
) -> str | None:
|
||||
if provider == ServiceProviders.DOGRAH.value or provider == ServiceProviders.DOGRAH:
|
||||
return dograh_embeddings_base_url()
|
||||
return base_url
|
||||
|
||||
|
||||
def _merge_byok_secret_fields(incoming_byok: dict | None, existing_byok: dict | None):
|
||||
if not isinstance(incoming_byok, dict) or not isinstance(existing_byok, dict):
|
||||
return
|
||||
incoming_mode = incoming_byok.get("mode")
|
||||
existing_mode = existing_byok.get("mode")
|
||||
if incoming_mode != existing_mode:
|
||||
return
|
||||
section_names = (
|
||||
("llm", "tts", "stt", "embeddings")
|
||||
if incoming_mode == "pipeline"
|
||||
else ("realtime", "llm", "embeddings")
|
||||
)
|
||||
incoming_container = incoming_byok.get(incoming_mode)
|
||||
existing_container = existing_byok.get(existing_mode)
|
||||
if not isinstance(incoming_container, dict) or not isinstance(
|
||||
existing_container, dict
|
||||
):
|
||||
return
|
||||
for section_name in section_names:
|
||||
incoming_section = incoming_container.get(section_name)
|
||||
existing_section = existing_container.get(section_name)
|
||||
if isinstance(incoming_section, dict) and isinstance(existing_section, dict):
|
||||
_merge_service_secret_fields(incoming_section, existing_section)
|
||||
|
||||
|
||||
async def _list_workflows_for_model_configuration_migration(
|
||||
organization_id: int,
|
||||
) -> list[WorkflowModel]:
|
||||
async with db_client.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(WorkflowModel)
|
||||
.options(selectinload(WorkflowModel.definitions))
|
||||
.where(WorkflowModel.organization_id == organization_id)
|
||||
)
|
||||
return list(result.scalars().unique().all())
|
||||
|
||||
|
||||
def _merge_service_secret_fields(incoming: dict, existing: dict):
|
||||
if (
|
||||
incoming.get("provider") is not None
|
||||
and existing.get("provider") is not None
|
||||
and incoming.get("provider") != existing.get("provider")
|
||||
):
|
||||
return
|
||||
for secret_field in SERVICE_SECRET_FIELDS:
|
||||
if secret_field not in existing:
|
||||
continue
|
||||
incoming_secret = incoming.get(secret_field)
|
||||
existing_secret = existing[secret_field]
|
||||
if incoming_secret is None:
|
||||
incoming[secret_field] = existing_secret
|
||||
elif contains_masked_key(incoming_secret):
|
||||
incoming[secret_field] = resolve_masked_api_keys(
|
||||
incoming_secret,
|
||||
existing_secret,
|
||||
)
|
||||
|
||||
|
||||
def _raise_if_masked_secret(value):
|
||||
if isinstance(value, dict):
|
||||
for key, nested in value.items():
|
||||
if key in SERVICE_SECRET_FIELDS and contains_masked_key(nested):
|
||||
raise ValueError(
|
||||
f"The {key} appears to be masked. Please provide the actual "
|
||||
"value, not the masked value."
|
||||
)
|
||||
_raise_if_masked_secret(nested)
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
_raise_if_masked_secret(item)
|
||||
|
||||
|
||||
def _mask_secret_fields(value):
|
||||
if isinstance(value, dict):
|
||||
for key, nested in list(value.items()):
|
||||
if key in SERVICE_SECRET_FIELDS and nested:
|
||||
value[key] = _mask_secret_value(nested)
|
||||
else:
|
||||
_mask_secret_fields(nested)
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
_mask_secret_fields(item)
|
||||
|
||||
|
||||
def _mask_secret_value(value):
|
||||
if isinstance(value, list):
|
||||
return [mask_key(item) for item in value]
|
||||
return mask_key(value)
|
||||
|
||||
|
||||
def _has_model_services(configuration: EffectiveAIModelConfiguration) -> bool:
|
||||
return any(
|
||||
service is not None
|
||||
for service in (
|
||||
configuration.llm,
|
||||
configuration.tts,
|
||||
configuration.stt,
|
||||
configuration.embeddings,
|
||||
configuration.realtime,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _convert_any_dograh_legacy_configuration(
|
||||
configuration: EffectiveAIModelConfiguration,
|
||||
dograh_key: str,
|
||||
) -> OrganizationAIModelConfigurationV2:
|
||||
speed = getattr(configuration.tts, "speed", 1.0)
|
||||
if speed not in DOGRAH_SPEED_OPTIONS:
|
||||
speed = 1.0
|
||||
return OrganizationAIModelConfigurationV2(
|
||||
mode="dograh",
|
||||
dograh=DograhManagedAIModelConfiguration(
|
||||
api_key=dograh_key,
|
||||
voice=getattr(configuration.tts, "voice", DOGRAH_DEFAULT_VOICE)
|
||||
or DOGRAH_DEFAULT_VOICE,
|
||||
speed=speed,
|
||||
language=getattr(configuration.stt, "language", DOGRAH_DEFAULT_LANGUAGE)
|
||||
or DOGRAH_DEFAULT_LANGUAGE,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _first_dograh_api_key(configuration: EffectiveAIModelConfiguration) -> str | None:
|
||||
for service in (
|
||||
configuration.llm,
|
||||
configuration.tts,
|
||||
configuration.stt,
|
||||
configuration.embeddings,
|
||||
configuration.realtime,
|
||||
):
|
||||
if service is None or _provider(service) != ServiceProviders.DOGRAH:
|
||||
continue
|
||||
try:
|
||||
return _single_api_key(service)
|
||||
except ValueError:
|
||||
continue
|
||||
return None
|
||||
|
||||
|
||||
def _provider(service):
|
||||
return getattr(service, "provider", None)
|
||||
|
||||
|
||||
def _single_api_key(service) -> str:
|
||||
if hasattr(service, "get_all_api_keys"):
|
||||
keys = service.get_all_api_keys()
|
||||
if len(keys) != 1:
|
||||
raise ValueError("Expected exactly one API key")
|
||||
return keys[0]
|
||||
key = getattr(service, "api_key", None)
|
||||
if not key:
|
||||
raise ValueError("Expected an API key")
|
||||
return key
|
||||
|
|
@ -8,8 +8,8 @@ from groq import Groq
|
|||
# from pyneuphonic import Neuphonic
|
||||
# except ImportError:
|
||||
# Neuphonic = None
|
||||
from api.schemas.user_configuration import (
|
||||
UserConfiguration,
|
||||
from api.schemas.ai_model_configuration import (
|
||||
EffectiveAIModelConfiguration,
|
||||
)
|
||||
from api.services.configuration.registry import ServiceConfig, ServiceProviders
|
||||
from api.services.mps_service_key_client import mps_service_key_client
|
||||
|
|
@ -64,7 +64,7 @@ class UserConfigurationValidator:
|
|||
|
||||
async def validate(
|
||||
self,
|
||||
configuration: UserConfiguration,
|
||||
configuration: EffectiveAIModelConfiguration,
|
||||
organization_id: Optional[int] = None,
|
||||
created_by: Optional[str] = None,
|
||||
) -> APIKeyStatusResponse:
|
||||
|
|
@ -75,21 +75,21 @@ class UserConfigurationValidator:
|
|||
status_list = []
|
||||
|
||||
status_list.extend(self._validate_service(configuration.llm, "llm"))
|
||||
status_list.extend(self._validate_service(configuration.stt, "stt"))
|
||||
status_list.extend(self._validate_service(configuration.tts, "tts"))
|
||||
# Embeddings is optional - only validate if configured
|
||||
status_list.extend(
|
||||
self._validate_service(
|
||||
configuration.embeddings, "embeddings", required=False
|
||||
)
|
||||
)
|
||||
# Realtime is optional - only validate if is_realtime is enabled
|
||||
if configuration.is_realtime:
|
||||
status_list.extend(
|
||||
self._validate_service(
|
||||
configuration.realtime, "realtime", required=True
|
||||
)
|
||||
)
|
||||
else:
|
||||
status_list.extend(self._validate_service(configuration.stt, "stt"))
|
||||
status_list.extend(self._validate_service(configuration.tts, "tts"))
|
||||
# Embeddings is optional - only validate if configured
|
||||
status_list.extend(
|
||||
self._validate_service(
|
||||
configuration.embeddings, "embeddings", required=False
|
||||
)
|
||||
)
|
||||
|
||||
if status_list:
|
||||
raise ValueError(status_list)
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ The rules are simple:
|
|||
import copy
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from api.schemas.user_configuration import UserConfiguration
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.configuration.registry import ServiceConfig
|
||||
from api.services.integrations import get_node_secret_fields
|
||||
|
||||
|
|
@ -31,7 +31,7 @@ def contains_masked_key(value: str | list[str] | None) -> bool:
|
|||
return any(MASK_MARKER in k for k in keys)
|
||||
|
||||
|
||||
def check_for_masked_keys(config: "UserConfiguration") -> None:
|
||||
def check_for_masked_keys(config: "EffectiveAIModelConfiguration") -> None:
|
||||
"""Raise ValueError if any service in *config* still has a masked secret."""
|
||||
for field in ("llm", "tts", "stt", "embeddings", "realtime"):
|
||||
service = getattr(config, field, None)
|
||||
|
|
@ -111,7 +111,7 @@ def resolve_masked_api_keys(
|
|||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# High-level helpers for UserConfiguration objects
|
||||
# High-level helpers for EffectiveAIModelConfiguration objects
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
|
|
@ -129,7 +129,7 @@ def _mask_service(service_cfg: Optional[ServiceConfig]) -> Optional[Dict[str, An
|
|||
return data
|
||||
|
||||
|
||||
def mask_user_config(config: UserConfiguration) -> Dict[str, Any]:
|
||||
def mask_user_config(config: EffectiveAIModelConfiguration) -> Dict[str, Any]:
|
||||
"""Return a JSON-serialisable dict of *config* with every api_key masked."""
|
||||
|
||||
return {
|
||||
|
|
@ -155,21 +155,35 @@ def mask_workflow_configurations(config: Optional[Dict]) -> Optional[Dict]:
|
|||
|
||||
masked = copy.deepcopy(config)
|
||||
model_overrides = masked.get("model_overrides")
|
||||
if not isinstance(model_overrides, dict):
|
||||
return masked
|
||||
if isinstance(model_overrides, dict):
|
||||
for section in MODEL_OVERRIDE_FIELDS:
|
||||
override = model_overrides.get(section)
|
||||
if not isinstance(override, dict):
|
||||
continue
|
||||
for secret_field in SERVICE_SECRET_FIELDS:
|
||||
raw = override.get(secret_field)
|
||||
if raw:
|
||||
override[secret_field] = _mask_secret_value(raw)
|
||||
|
||||
for section in MODEL_OVERRIDE_FIELDS:
|
||||
override = model_overrides.get(section)
|
||||
if not isinstance(override, dict):
|
||||
continue
|
||||
for secret_field in SERVICE_SECRET_FIELDS:
|
||||
raw = override.get(secret_field)
|
||||
if raw:
|
||||
override[secret_field] = _mask_secret_value(raw)
|
||||
v2_override = masked.get("model_configuration_v2_override")
|
||||
if isinstance(v2_override, dict):
|
||||
_mask_nested_service_secrets(v2_override)
|
||||
|
||||
return masked
|
||||
|
||||
|
||||
def _mask_nested_service_secrets(value):
|
||||
if isinstance(value, dict):
|
||||
for key, nested in list(value.items()):
|
||||
if key in SERVICE_SECRET_FIELDS and nested:
|
||||
value[key] = _mask_secret_value(nested)
|
||||
else:
|
||||
_mask_nested_service_secrets(nested)
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
_mask_nested_service_secrets(item)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Workflow definition helpers – mask / merge node API keys
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ stored, while honouring masked API keys.
|
|||
import copy
|
||||
from typing import Dict
|
||||
|
||||
from api.schemas.user_configuration import UserConfiguration
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.configuration.masking import (
|
||||
MODEL_OVERRIDE_FIELDS,
|
||||
SERVICE_SECRET_FIELDS,
|
||||
|
|
@ -66,9 +66,9 @@ def _merge_service_secret_fields(
|
|||
|
||||
|
||||
def merge_user_configurations(
|
||||
existing: UserConfiguration, incoming_partial: Dict[str, dict]
|
||||
) -> UserConfiguration:
|
||||
"""Merge *incoming_partial* onto *existing* and return a new UserConfiguration.
|
||||
existing: EffectiveAIModelConfiguration, incoming_partial: Dict[str, dict]
|
||||
) -> EffectiveAIModelConfiguration:
|
||||
"""Merge *incoming_partial* onto *existing* and return a new EffectiveAIModelConfiguration.
|
||||
|
||||
*incoming_partial* is the body of the PUT request (already `model_dump()`ed or
|
||||
extracted via Pydantic `model_dump`).
|
||||
|
|
@ -113,14 +113,14 @@ def merge_user_configurations(
|
|||
if "timezone" in incoming_partial:
|
||||
merged["timezone"] = incoming_partial["timezone"]
|
||||
|
||||
# Onboarding gate flags — overwrite only when supplied (set once on submit/skip).
|
||||
# Onboarding gate flags: overwrite only when supplied.
|
||||
if "onboarding_completed_at" in incoming_partial:
|
||||
merged["onboarding_completed_at"] = incoming_partial["onboarding_completed_at"]
|
||||
|
||||
if "onboarding_skipped" in incoming_partial:
|
||||
merged["onboarding_skipped"] = incoming_partial["onboarding_skipped"]
|
||||
|
||||
return UserConfiguration.model_validate(merged)
|
||||
return EffectiveAIModelConfiguration.model_validate(merged)
|
||||
|
||||
|
||||
def merge_workflow_configuration_secrets(
|
||||
|
|
|
|||
|
|
@ -911,7 +911,7 @@ class DograhTTSService(BaseTTSConfiguration):
|
|||
speed: float = Field(default=1.0, ge=0.5, le=2.0, description="Speed of the voice.")
|
||||
|
||||
|
||||
CARTESIA_TTS_MODELS = ["sonic-3"]
|
||||
CARTESIA_TTS_MODELS = ["sonic-3.5", "sonic-3"]
|
||||
|
||||
|
||||
@register_tts
|
||||
|
|
@ -919,7 +919,7 @@ class CartesiaTTSConfiguration(BaseTTSConfiguration):
|
|||
model_config = CARTESIA_PROVIDER_MODEL_CONFIG
|
||||
provider: Literal[ServiceProviders.CARTESIA] = ServiceProviders.CARTESIA
|
||||
model: str = Field(
|
||||
default="sonic-3",
|
||||
default="sonic-3.5",
|
||||
description="Cartesia TTS model.",
|
||||
json_schema_extra={"examples": CARTESIA_TTS_MODELS},
|
||||
)
|
||||
|
|
@ -1472,11 +1472,26 @@ class AzureOpenAIEmbeddingsConfiguration(BaseEmbeddingsConfiguration):
|
|||
)
|
||||
|
||||
|
||||
DOGRAH_EMBEDDING_MODELS = ["default"]
|
||||
|
||||
|
||||
@register_embeddings
|
||||
class DograhEmbeddingsConfiguration(BaseEmbeddingsConfiguration):
|
||||
model_config = DOGRAH_PROVIDER_MODEL_CONFIG
|
||||
provider: Literal[ServiceProviders.DOGRAH] = ServiceProviders.DOGRAH
|
||||
model: str = Field(
|
||||
default="default",
|
||||
description="Dograh-managed embedding model.",
|
||||
json_schema_extra={"examples": DOGRAH_EMBEDDING_MODELS},
|
||||
)
|
||||
|
||||
|
||||
EmbeddingsConfig = Annotated[
|
||||
Union[
|
||||
OpenAIEmbeddingsConfiguration,
|
||||
OpenRouterEmbeddingsConfiguration,
|
||||
AzureOpenAIEmbeddingsConfiguration,
|
||||
DograhEmbeddingsConfiguration,
|
||||
],
|
||||
Field(discriminator="provider"),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -4,13 +4,13 @@ from __future__ import annotations
|
|||
|
||||
import copy
|
||||
|
||||
from api.schemas.user_configuration import UserConfiguration
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.configuration.registry import (
|
||||
REGISTRY,
|
||||
ServiceType,
|
||||
)
|
||||
|
||||
# Maps override key → (UserConfiguration field, ServiceType for registry lookup)
|
||||
# Maps override key → (EffectiveAIModelConfiguration field, ServiceType for registry lookup)
|
||||
_SECTION_MAP: dict[str, ServiceType] = {
|
||||
"llm": ServiceType.LLM,
|
||||
"tts": ServiceType.TTS,
|
||||
|
|
@ -36,7 +36,7 @@ _SECRET_FIELDS = ("api_key", "credentials", "aws_access_key", "aws_secret_key")
|
|||
|
||||
def enrich_overrides_with_api_keys(
|
||||
model_overrides: dict,
|
||||
user_config: UserConfiguration,
|
||||
user_config: EffectiveAIModelConfiguration,
|
||||
) -> dict:
|
||||
"""Copy API keys from the global config into model_overrides where missing.
|
||||
|
||||
|
|
@ -74,9 +74,9 @@ def enrich_overrides_with_api_keys(
|
|||
|
||||
|
||||
def resolve_effective_config(
|
||||
user_config: UserConfiguration,
|
||||
user_config: EffectiveAIModelConfiguration,
|
||||
model_overrides: dict | None,
|
||||
) -> UserConfiguration:
|
||||
) -> EffectiveAIModelConfiguration:
|
||||
"""Deep-merge workflow model_overrides onto global user config.
|
||||
|
||||
- If model_overrides is None or empty, returns a copy of user_config unchanged.
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue