mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-10 08:05:22 +02:00
feat: add config v2 to simplify billing (#428)
* feat: add model config v2 * chore: centralize user org selection * chore: move preferences to platform settings * fix: decouple org preference and ai model preferences
This commit is contained in:
parent
49e68b49d5
commit
cdbd06c8d9
42 changed files with 5135 additions and 264 deletions
|
|
@ -1,7 +1,7 @@
|
|||
from typing import Annotated, Optional
|
||||
|
||||
import httpx
|
||||
from fastapi import Header, HTTPException, Query, WebSocket
|
||||
from fastapi import Depends, Header, HTTPException, Query, WebSocket
|
||||
from loguru import logger
|
||||
from pydantic import ValidationError
|
||||
|
||||
|
|
@ -119,6 +119,19 @@ async def get_user(
|
|||
await db_client.update_user_configuration(
|
||||
user_model.id, mps_config
|
||||
)
|
||||
from api.enums import OrganizationConfigurationKey
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
convert_legacy_ai_model_configuration_to_v2,
|
||||
)
|
||||
|
||||
model_config_v2 = convert_legacy_ai_model_configuration_to_v2(
|
||||
mps_config
|
||||
)
|
||||
await db_client.upsert_configuration(
|
||||
organization.id,
|
||||
OrganizationConfigurationKey.MODEL_CONFIGURATION_V2.value,
|
||||
model_config_v2.model_dump(mode="json", exclude_none=True),
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
raise HTTPException(
|
||||
|
|
@ -129,6 +142,14 @@ async def get_user(
|
|||
return user_model
|
||||
|
||||
|
||||
async def get_user_with_selected_organization(
|
||||
user: Annotated[UserModel, Depends(get_user)],
|
||||
) -> UserModel:
|
||||
if not user.selected_organization_id:
|
||||
raise HTTPException(status_code=400, detail="No organization selected")
|
||||
return user
|
||||
|
||||
|
||||
async def _handle_oss_auth(authorization: str | None) -> UserModel:
|
||||
"""
|
||||
Handle authentication for OSS deployment mode.
|
||||
|
|
|
|||
484
api/services/configuration/ai_model_configuration.py
Normal file
484
api/services/configuration/ai_model_configuration.py
Normal file
|
|
@ -0,0 +1,484 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from api.constants import MPS_API_URL
|
||||
from api.db import db_client
|
||||
from api.db.models import WorkflowDefinitionModel, WorkflowModel
|
||||
from api.enums import OrganizationConfigurationKey
|
||||
from api.schemas.ai_model_configuration import (
|
||||
DOGRAH_DEFAULT_LANGUAGE,
|
||||
DOGRAH_DEFAULT_VOICE,
|
||||
DOGRAH_SPEED_OPTIONS,
|
||||
BYOKAIModelConfiguration,
|
||||
BYOKPipelineAIModelConfiguration,
|
||||
BYOKRealtimeAIModelConfiguration,
|
||||
DograhManagedAIModelConfiguration,
|
||||
OrganizationAIModelConfigurationV2,
|
||||
compile_ai_model_configuration_v2,
|
||||
)
|
||||
from api.schemas.user_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.configuration.masking import (
|
||||
SERVICE_SECRET_FIELDS,
|
||||
contains_masked_key,
|
||||
mask_key,
|
||||
resolve_masked_api_keys,
|
||||
)
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from api.services.configuration.resolve import resolve_effective_config
|
||||
|
||||
AIModelConfigurationSource = Literal["organization_v2", "legacy_user_v1", "empty"]
|
||||
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY = "model_configuration_v2_override"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResolvedAIModelConfiguration:
|
||||
effective: EffectiveAIModelConfiguration
|
||||
source: AIModelConfigurationSource
|
||||
organization_configuration: OrganizationAIModelConfigurationV2 | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class WorkflowAIModelConfigurationMigrationResult:
|
||||
workflow_count: int = 0
|
||||
definition_count: int = 0
|
||||
workflow_ids: list[int] | None = None
|
||||
|
||||
|
||||
async def get_resolved_ai_model_configuration(
|
||||
*,
|
||||
user_id: int | None,
|
||||
organization_id: int | None,
|
||||
) -> ResolvedAIModelConfiguration:
|
||||
organization_configuration = await get_organization_ai_model_configuration_v2(
|
||||
organization_id
|
||||
)
|
||||
if organization_configuration is not None:
|
||||
return ResolvedAIModelConfiguration(
|
||||
effective=compile_ai_model_configuration_v2(organization_configuration),
|
||||
source="organization_v2",
|
||||
organization_configuration=organization_configuration,
|
||||
)
|
||||
|
||||
if user_id is None:
|
||||
return ResolvedAIModelConfiguration(
|
||||
effective=EffectiveAIModelConfiguration(),
|
||||
source="empty",
|
||||
)
|
||||
|
||||
legacy = await db_client.get_user_configurations(user_id)
|
||||
return ResolvedAIModelConfiguration(
|
||||
effective=legacy,
|
||||
source="legacy_user_v1" if _has_model_services(legacy) else "empty",
|
||||
)
|
||||
|
||||
|
||||
async def get_effective_ai_model_configuration_for_workflow(
|
||||
*,
|
||||
user_id: int | None,
|
||||
organization_id: int | None,
|
||||
workflow_configurations: dict | None,
|
||||
) -> EffectiveAIModelConfiguration:
|
||||
workflow_configurations = workflow_configurations or {}
|
||||
v2_override = workflow_configurations.get(
|
||||
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY
|
||||
)
|
||||
if v2_override:
|
||||
return compile_ai_model_configuration_v2(
|
||||
OrganizationAIModelConfigurationV2.model_validate(v2_override)
|
||||
)
|
||||
|
||||
resolved_config = await get_resolved_ai_model_configuration(
|
||||
user_id=user_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
return resolve_effective_config(
|
||||
resolved_config.effective,
|
||||
workflow_configurations.get("model_overrides"),
|
||||
)
|
||||
|
||||
|
||||
async def get_organization_ai_model_configuration_v2(
|
||||
organization_id: int | None,
|
||||
) -> OrganizationAIModelConfigurationV2 | None:
|
||||
if organization_id is None:
|
||||
return None
|
||||
row = await db_client.get_configuration(
|
||||
organization_id,
|
||||
OrganizationConfigurationKey.MODEL_CONFIGURATION_V2.value,
|
||||
)
|
||||
if row is None or not row.value:
|
||||
return None
|
||||
try:
|
||||
return OrganizationAIModelConfigurationV2.model_validate(row.value)
|
||||
except ValidationError as exc:
|
||||
logger.warning(
|
||||
"Invalid org AI model configuration v2 for organization "
|
||||
f"{organization_id}: {exc}. Falling back to legacy configuration."
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
async def upsert_organization_ai_model_configuration_v2(
|
||||
organization_id: int,
|
||||
configuration: OrganizationAIModelConfigurationV2,
|
||||
) -> OrganizationAIModelConfigurationV2:
|
||||
await db_client.upsert_configuration(
|
||||
organization_id,
|
||||
OrganizationConfigurationKey.MODEL_CONFIGURATION_V2.value,
|
||||
configuration.model_dump(mode="json", exclude_none=True),
|
||||
)
|
||||
return configuration
|
||||
|
||||
|
||||
async def migrate_workflow_model_configurations_to_v2(
|
||||
*,
|
||||
organization_id: int,
|
||||
fallback_user_config: EffectiveAIModelConfiguration,
|
||||
) -> WorkflowAIModelConfigurationMigrationResult:
|
||||
workflows = await _list_workflows_for_model_configuration_migration(organization_id)
|
||||
owner_configs: dict[int, EffectiveAIModelConfiguration] = {}
|
||||
workflow_updates: list[tuple[int, dict]] = []
|
||||
definition_updates: list[tuple[int, dict]] = []
|
||||
migrated_workflow_ids: set[int] = set()
|
||||
|
||||
for workflow in workflows:
|
||||
base_config = fallback_user_config
|
||||
if workflow.user_id is not None:
|
||||
if workflow.user_id not in owner_configs:
|
||||
owner_configs[
|
||||
workflow.user_id
|
||||
] = await db_client.get_user_configurations(workflow.user_id)
|
||||
base_config = owner_configs[workflow.user_id]
|
||||
|
||||
workflow_configs, workflow_changed = (
|
||||
migrate_workflow_configuration_model_override_to_v2(
|
||||
workflow.workflow_configurations,
|
||||
base_config,
|
||||
)
|
||||
)
|
||||
if workflow_changed:
|
||||
workflow_updates.append((workflow.id, workflow_configs))
|
||||
migrated_workflow_ids.add(workflow.id)
|
||||
|
||||
for definition in workflow.definitions:
|
||||
definition_configs, definition_changed = (
|
||||
migrate_workflow_configuration_model_override_to_v2(
|
||||
definition.workflow_configurations,
|
||||
base_config,
|
||||
)
|
||||
)
|
||||
if definition_changed:
|
||||
definition_updates.append((definition.id, definition_configs))
|
||||
migrated_workflow_ids.add(workflow.id)
|
||||
|
||||
if workflow_updates or definition_updates:
|
||||
async with db_client.async_session() as session:
|
||||
for workflow_id, workflow_configs in workflow_updates:
|
||||
await session.execute(
|
||||
update(WorkflowModel)
|
||||
.where(WorkflowModel.id == workflow_id)
|
||||
.values(workflow_configurations=workflow_configs)
|
||||
)
|
||||
for definition_id, definition_configs in definition_updates:
|
||||
await session.execute(
|
||||
update(WorkflowDefinitionModel)
|
||||
.where(WorkflowDefinitionModel.id == definition_id)
|
||||
.values(workflow_configurations=definition_configs)
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
return WorkflowAIModelConfigurationMigrationResult(
|
||||
workflow_count=len(migrated_workflow_ids),
|
||||
definition_count=len(definition_updates),
|
||||
workflow_ids=sorted(migrated_workflow_ids),
|
||||
)
|
||||
|
||||
|
||||
def migrate_workflow_configuration_model_override_to_v2(
|
||||
workflow_configurations: dict | None,
|
||||
base_config: EffectiveAIModelConfiguration,
|
||||
) -> tuple[dict, bool]:
|
||||
if not isinstance(workflow_configurations, dict):
|
||||
return {}, False
|
||||
|
||||
migrated = copy.deepcopy(workflow_configurations)
|
||||
model_overrides = migrated.get("model_overrides")
|
||||
existing_v2_override = migrated.get(WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY)
|
||||
if not isinstance(model_overrides, dict):
|
||||
if "model_overrides" in migrated:
|
||||
migrated.pop("model_overrides", None)
|
||||
return migrated, True
|
||||
return migrated, False
|
||||
|
||||
if not existing_v2_override:
|
||||
effective = resolve_effective_config(base_config, model_overrides)
|
||||
v2_override = convert_legacy_ai_model_configuration_to_v2(effective)
|
||||
migrated[WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY] = v2_override.model_dump(
|
||||
mode="json", exclude_none=True
|
||||
)
|
||||
migrated.pop("model_overrides", None)
|
||||
return migrated, True
|
||||
|
||||
|
||||
def merge_ai_model_configuration_v2_secrets(
|
||||
incoming: OrganizationAIModelConfigurationV2,
|
||||
existing: OrganizationAIModelConfigurationV2 | None,
|
||||
) -> OrganizationAIModelConfigurationV2:
|
||||
if existing is None:
|
||||
return incoming
|
||||
|
||||
incoming_dict = incoming.model_dump(mode="json", exclude_none=True)
|
||||
existing_dict = existing.model_dump(mode="json", exclude_none=True)
|
||||
|
||||
if incoming_dict.get("mode") == "dograh" and existing_dict.get("mode") == "dograh":
|
||||
incoming_dograh = incoming_dict.get("dograh") or {}
|
||||
existing_dograh = existing_dict.get("dograh") or {}
|
||||
incoming_key = incoming_dograh.get("api_key")
|
||||
existing_key = existing_dograh.get("api_key")
|
||||
if incoming_key and existing_key and contains_masked_key(incoming_key):
|
||||
incoming_dograh["api_key"] = resolve_masked_api_keys(
|
||||
incoming_key,
|
||||
existing_key,
|
||||
)
|
||||
|
||||
if incoming_dict.get("mode") == "byok" and existing_dict.get("mode") == "byok":
|
||||
_merge_byok_secret_fields(incoming_dict.get("byok"), existing_dict.get("byok"))
|
||||
|
||||
return OrganizationAIModelConfigurationV2.model_validate(incoming_dict)
|
||||
|
||||
|
||||
def check_for_masked_keys_in_ai_model_configuration_v2(
|
||||
configuration: OrganizationAIModelConfigurationV2,
|
||||
) -> None:
|
||||
data = configuration.model_dump(mode="json", exclude_none=True)
|
||||
_raise_if_masked_secret(data)
|
||||
|
||||
|
||||
def mask_ai_model_configuration_v2(
|
||||
configuration: OrganizationAIModelConfigurationV2 | None,
|
||||
) -> dict | None:
|
||||
if configuration is None:
|
||||
return None
|
||||
data = configuration.model_dump(mode="json", exclude_none=True)
|
||||
_mask_secret_fields(data)
|
||||
return data
|
||||
|
||||
|
||||
def convert_legacy_ai_model_configuration_to_v2(
|
||||
configuration: EffectiveAIModelConfiguration,
|
||||
) -> OrganizationAIModelConfigurationV2:
|
||||
dograh_key = _first_dograh_api_key(configuration)
|
||||
if dograh_key:
|
||||
return _convert_any_dograh_legacy_configuration(configuration, dograh_key)
|
||||
|
||||
if configuration.is_realtime:
|
||||
if configuration.realtime is None or configuration.llm is None:
|
||||
raise ValueError("Realtime legacy configuration is incomplete")
|
||||
return OrganizationAIModelConfigurationV2(
|
||||
mode="byok",
|
||||
byok=BYOKAIModelConfiguration(
|
||||
mode="realtime",
|
||||
realtime=BYOKRealtimeAIModelConfiguration(
|
||||
realtime=configuration.realtime,
|
||||
llm=configuration.llm,
|
||||
embeddings=configuration.embeddings,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
if (
|
||||
configuration.llm is None
|
||||
or configuration.tts is None
|
||||
or configuration.stt is None
|
||||
):
|
||||
raise ValueError("Pipeline legacy configuration is incomplete")
|
||||
return OrganizationAIModelConfigurationV2(
|
||||
mode="byok",
|
||||
byok=BYOKAIModelConfiguration(
|
||||
mode="pipeline",
|
||||
pipeline=BYOKPipelineAIModelConfiguration(
|
||||
llm=configuration.llm,
|
||||
tts=configuration.tts,
|
||||
stt=configuration.stt,
|
||||
embeddings=configuration.embeddings,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def dograh_embeddings_base_url() -> str:
|
||||
return f"{MPS_API_URL}/api/v1/llm"
|
||||
|
||||
|
||||
def apply_managed_embeddings_base_url(
|
||||
*,
|
||||
provider: str | None,
|
||||
base_url: str | None,
|
||||
) -> str | None:
|
||||
if provider == ServiceProviders.DOGRAH.value or provider == ServiceProviders.DOGRAH:
|
||||
return dograh_embeddings_base_url()
|
||||
return base_url
|
||||
|
||||
|
||||
def _merge_byok_secret_fields(incoming_byok: dict | None, existing_byok: dict | None):
|
||||
if not isinstance(incoming_byok, dict) or not isinstance(existing_byok, dict):
|
||||
return
|
||||
incoming_mode = incoming_byok.get("mode")
|
||||
existing_mode = existing_byok.get("mode")
|
||||
if incoming_mode != existing_mode:
|
||||
return
|
||||
section_names = (
|
||||
("llm", "tts", "stt", "embeddings")
|
||||
if incoming_mode == "pipeline"
|
||||
else ("realtime", "llm", "embeddings")
|
||||
)
|
||||
incoming_container = incoming_byok.get(incoming_mode)
|
||||
existing_container = existing_byok.get(existing_mode)
|
||||
if not isinstance(incoming_container, dict) or not isinstance(
|
||||
existing_container, dict
|
||||
):
|
||||
return
|
||||
for section_name in section_names:
|
||||
incoming_section = incoming_container.get(section_name)
|
||||
existing_section = existing_container.get(section_name)
|
||||
if isinstance(incoming_section, dict) and isinstance(existing_section, dict):
|
||||
_merge_service_secret_fields(incoming_section, existing_section)
|
||||
|
||||
|
||||
async def _list_workflows_for_model_configuration_migration(
|
||||
organization_id: int,
|
||||
) -> list[WorkflowModel]:
|
||||
async with db_client.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(WorkflowModel)
|
||||
.options(selectinload(WorkflowModel.definitions))
|
||||
.where(WorkflowModel.organization_id == organization_id)
|
||||
)
|
||||
return list(result.scalars().unique().all())
|
||||
|
||||
|
||||
def _merge_service_secret_fields(incoming: dict, existing: dict):
|
||||
if (
|
||||
incoming.get("provider") is not None
|
||||
and existing.get("provider") is not None
|
||||
and incoming.get("provider") != existing.get("provider")
|
||||
):
|
||||
return
|
||||
for secret_field in SERVICE_SECRET_FIELDS:
|
||||
if secret_field not in existing:
|
||||
continue
|
||||
incoming_secret = incoming.get(secret_field)
|
||||
existing_secret = existing[secret_field]
|
||||
if incoming_secret is None:
|
||||
incoming[secret_field] = existing_secret
|
||||
elif contains_masked_key(incoming_secret):
|
||||
incoming[secret_field] = resolve_masked_api_keys(
|
||||
incoming_secret,
|
||||
existing_secret,
|
||||
)
|
||||
|
||||
|
||||
def _raise_if_masked_secret(value):
|
||||
if isinstance(value, dict):
|
||||
for key, nested in value.items():
|
||||
if key in SERVICE_SECRET_FIELDS and contains_masked_key(nested):
|
||||
raise ValueError(
|
||||
f"The {key} appears to be masked. Please provide the actual "
|
||||
"value, not the masked value."
|
||||
)
|
||||
_raise_if_masked_secret(nested)
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
_raise_if_masked_secret(item)
|
||||
|
||||
|
||||
def _mask_secret_fields(value):
|
||||
if isinstance(value, dict):
|
||||
for key, nested in list(value.items()):
|
||||
if key in SERVICE_SECRET_FIELDS and nested:
|
||||
value[key] = _mask_secret_value(nested)
|
||||
else:
|
||||
_mask_secret_fields(nested)
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
_mask_secret_fields(item)
|
||||
|
||||
|
||||
def _mask_secret_value(value):
|
||||
if isinstance(value, list):
|
||||
return [mask_key(item) for item in value]
|
||||
return mask_key(value)
|
||||
|
||||
|
||||
def _has_model_services(configuration: EffectiveAIModelConfiguration) -> bool:
|
||||
return any(
|
||||
service is not None
|
||||
for service in (
|
||||
configuration.llm,
|
||||
configuration.tts,
|
||||
configuration.stt,
|
||||
configuration.embeddings,
|
||||
configuration.realtime,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _convert_any_dograh_legacy_configuration(
|
||||
configuration: EffectiveAIModelConfiguration,
|
||||
dograh_key: str,
|
||||
) -> OrganizationAIModelConfigurationV2:
|
||||
speed = getattr(configuration.tts, "speed", 1.0)
|
||||
if speed not in DOGRAH_SPEED_OPTIONS:
|
||||
speed = 1.0
|
||||
return OrganizationAIModelConfigurationV2(
|
||||
mode="dograh",
|
||||
dograh=DograhManagedAIModelConfiguration(
|
||||
api_key=dograh_key,
|
||||
voice=getattr(configuration.tts, "voice", DOGRAH_DEFAULT_VOICE)
|
||||
or DOGRAH_DEFAULT_VOICE,
|
||||
speed=speed,
|
||||
language=getattr(configuration.stt, "language", DOGRAH_DEFAULT_LANGUAGE)
|
||||
or DOGRAH_DEFAULT_LANGUAGE,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _first_dograh_api_key(configuration: EffectiveAIModelConfiguration) -> str | None:
|
||||
for service in (
|
||||
configuration.llm,
|
||||
configuration.tts,
|
||||
configuration.stt,
|
||||
configuration.embeddings,
|
||||
configuration.realtime,
|
||||
):
|
||||
if service is None or _provider(service) != ServiceProviders.DOGRAH:
|
||||
continue
|
||||
try:
|
||||
return _single_api_key(service)
|
||||
except ValueError:
|
||||
continue
|
||||
return None
|
||||
|
||||
|
||||
def _provider(service):
|
||||
return getattr(service, "provider", None)
|
||||
|
||||
|
||||
def _single_api_key(service) -> str:
|
||||
if hasattr(service, "get_all_api_keys"):
|
||||
keys = service.get_all_api_keys()
|
||||
if len(keys) != 1:
|
||||
raise ValueError("Expected exactly one API key")
|
||||
return keys[0]
|
||||
key = getattr(service, "api_key", None)
|
||||
if not key:
|
||||
raise ValueError("Expected an API key")
|
||||
return key
|
||||
|
|
@ -151,21 +151,35 @@ def mask_workflow_configurations(config: Optional[Dict]) -> Optional[Dict]:
|
|||
|
||||
masked = copy.deepcopy(config)
|
||||
model_overrides = masked.get("model_overrides")
|
||||
if not isinstance(model_overrides, dict):
|
||||
return masked
|
||||
if isinstance(model_overrides, dict):
|
||||
for section in MODEL_OVERRIDE_FIELDS:
|
||||
override = model_overrides.get(section)
|
||||
if not isinstance(override, dict):
|
||||
continue
|
||||
for secret_field in SERVICE_SECRET_FIELDS:
|
||||
raw = override.get(secret_field)
|
||||
if raw:
|
||||
override[secret_field] = _mask_secret_value(raw)
|
||||
|
||||
for section in MODEL_OVERRIDE_FIELDS:
|
||||
override = model_overrides.get(section)
|
||||
if not isinstance(override, dict):
|
||||
continue
|
||||
for secret_field in SERVICE_SECRET_FIELDS:
|
||||
raw = override.get(secret_field)
|
||||
if raw:
|
||||
override[secret_field] = _mask_secret_value(raw)
|
||||
v2_override = masked.get("model_configuration_v2_override")
|
||||
if isinstance(v2_override, dict):
|
||||
_mask_nested_service_secrets(v2_override)
|
||||
|
||||
return masked
|
||||
|
||||
|
||||
def _mask_nested_service_secrets(value):
|
||||
if isinstance(value, dict):
|
||||
for key, nested in list(value.items()):
|
||||
if key in SERVICE_SECRET_FIELDS and nested:
|
||||
value[key] = _mask_secret_value(nested)
|
||||
else:
|
||||
_mask_nested_service_secrets(nested)
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
_mask_nested_service_secrets(item)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Workflow definition helpers – mask / merge node API keys
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -1472,11 +1472,26 @@ class AzureOpenAIEmbeddingsConfiguration(BaseEmbeddingsConfiguration):
|
|||
)
|
||||
|
||||
|
||||
DOGRAH_EMBEDDING_MODELS = ["default"]
|
||||
|
||||
|
||||
@register_embeddings
|
||||
class DograhEmbeddingsConfiguration(BaseEmbeddingsConfiguration):
|
||||
model_config = DOGRAH_PROVIDER_MODEL_CONFIG
|
||||
provider: Literal[ServiceProviders.DOGRAH] = ServiceProviders.DOGRAH
|
||||
model: str = Field(
|
||||
default="default",
|
||||
description="Dograh-managed embedding model.",
|
||||
json_schema_extra={"examples": DOGRAH_EMBEDDING_MODELS},
|
||||
)
|
||||
|
||||
|
||||
EmbeddingsConfig = Annotated[
|
||||
Union[
|
||||
OpenAIEmbeddingsConfiguration,
|
||||
OpenRouterEmbeddingsConfiguration,
|
||||
AzureOpenAIEmbeddingsConfiguration,
|
||||
DograhEmbeddingsConfiguration,
|
||||
],
|
||||
Field(discriminator="provider"),
|
||||
]
|
||||
|
|
|
|||
62
api/services/organization_preferences.py
Normal file
62
api/services/organization_preferences.py
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
from inspect import isawaitable
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import ValidationError
|
||||
|
||||
from api.db import db_client
|
||||
from api.enums import OrganizationConfigurationKey
|
||||
from api.schemas.organization_preferences import OrganizationPreferences
|
||||
|
||||
|
||||
async def get_organization_preferences(
|
||||
organization_id: int | None,
|
||||
db=None,
|
||||
) -> OrganizationPreferences:
|
||||
if organization_id is None:
|
||||
return OrganizationPreferences()
|
||||
|
||||
db = db or db_client
|
||||
row = await _get_configuration(
|
||||
db,
|
||||
organization_id,
|
||||
OrganizationConfigurationKey.ORGANIZATION_PREFERENCES.value,
|
||||
)
|
||||
if row is None:
|
||||
row = await _get_configuration(
|
||||
db,
|
||||
organization_id,
|
||||
OrganizationConfigurationKey.MODEL_CONFIGURATION_PREFERENCES.value,
|
||||
)
|
||||
return _parse_preferences(row.value if row is not None else None, organization_id)
|
||||
|
||||
|
||||
async def upsert_organization_preferences(
|
||||
organization_id: int,
|
||||
preferences: OrganizationPreferences,
|
||||
) -> OrganizationPreferences:
|
||||
await db_client.upsert_configuration(
|
||||
organization_id,
|
||||
OrganizationConfigurationKey.ORGANIZATION_PREFERENCES.value,
|
||||
preferences.model_dump(mode="json", exclude_none=True),
|
||||
)
|
||||
return preferences
|
||||
|
||||
|
||||
async def _get_configuration(db, organization_id: int, key: str):
|
||||
row = db.get_configuration(organization_id, key)
|
||||
if isawaitable(row):
|
||||
row = await row
|
||||
return row
|
||||
|
||||
|
||||
def _parse_preferences(value, organization_id: int) -> OrganizationPreferences:
|
||||
if not value or not isinstance(value, dict):
|
||||
return OrganizationPreferences()
|
||||
try:
|
||||
return OrganizationPreferences.model_validate(value)
|
||||
except ValidationError as exc:
|
||||
logger.warning(
|
||||
"Invalid organization preferences for organization "
|
||||
f"{organization_id}: {exc}. Returning defaults."
|
||||
)
|
||||
return OrganizationPreferences()
|
||||
|
|
@ -195,14 +195,17 @@ async def run_pipeline_telephony(
|
|||
# Resolve effective user config here so the transport can tune its
|
||||
# bot-stopped-speaking fallback based on is_realtime; pass the resolved
|
||||
# values into _run_pipeline so it doesn't fetch them again.
|
||||
from api.services.configuration.resolve import resolve_effective_config
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
get_effective_ai_model_configuration_for_workflow,
|
||||
)
|
||||
|
||||
user_config = await db_client.get_user_configurations(user_id)
|
||||
run_configs = (
|
||||
(workflow_run.definition.workflow_configurations or {}) if workflow_run else {}
|
||||
)
|
||||
user_config = resolve_effective_config(
|
||||
user_config, run_configs.get("model_overrides")
|
||||
user_config = await get_effective_ai_model_configuration_for_workflow(
|
||||
user_id=user_id,
|
||||
organization_id=workflow.organization_id if workflow else None,
|
||||
workflow_configurations=run_configs,
|
||||
)
|
||||
is_realtime = bool(user_config.is_realtime and user_config.realtime is not None)
|
||||
|
||||
|
|
@ -272,15 +275,18 @@ async def run_pipeline_smallwebrtc(
|
|||
# Resolve workflow_run + effective user_config here so the transport can
|
||||
# tune its bot-stopped-speaking fallback based on is_realtime. _run_pipeline
|
||||
# reuses these via kwargs so we don't fetch twice.
|
||||
from api.services.configuration.resolve import resolve_effective_config
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
get_effective_ai_model_configuration_for_workflow,
|
||||
)
|
||||
|
||||
workflow_run = await db_client.get_workflow_run(workflow_run_id, user_id)
|
||||
user_config = await db_client.get_user_configurations(user_id)
|
||||
run_configs = (
|
||||
(workflow_run.definition.workflow_configurations or {}) if workflow_run else {}
|
||||
)
|
||||
user_config = resolve_effective_config(
|
||||
user_config, run_configs.get("model_overrides")
|
||||
user_config = await get_effective_ai_model_configuration_for_workflow(
|
||||
user_id=user_id,
|
||||
organization_id=workflow.organization_id if workflow else None,
|
||||
workflow_configurations=run_configs,
|
||||
)
|
||||
is_realtime = bool(user_config.is_realtime and user_config.realtime is not None)
|
||||
|
||||
|
|
@ -380,11 +386,14 @@ async def _run_pipeline(
|
|||
# Resolve model overrides from the version onto global user config (skip
|
||||
# when the caller already resolved it).
|
||||
if resolved_user_config is None:
|
||||
from api.services.configuration.resolve import resolve_effective_config
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
get_effective_ai_model_configuration_for_workflow,
|
||||
)
|
||||
|
||||
user_config = await db_client.get_user_configurations(user_id)
|
||||
user_config = resolve_effective_config(
|
||||
user_config, run_configs.get("model_overrides")
|
||||
user_config = await get_effective_ai_model_configuration_for_workflow(
|
||||
user_id=user_id,
|
||||
organization_id=workflow.organization_id,
|
||||
workflow_configurations=run_configs,
|
||||
)
|
||||
else:
|
||||
user_config = resolved_user_config
|
||||
|
|
@ -508,10 +517,17 @@ async def _run_pipeline(
|
|||
embeddings_endpoint = None
|
||||
embeddings_api_version = None
|
||||
if user_config and user_config.embeddings:
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
apply_managed_embeddings_base_url,
|
||||
)
|
||||
|
||||
embeddings_api_key = user_config.embeddings.api_key
|
||||
embeddings_model = user_config.embeddings.model
|
||||
embeddings_provider = getattr(user_config.embeddings, "provider", None)
|
||||
embeddings_base_url = getattr(user_config.embeddings, "base_url", None)
|
||||
embeddings_base_url = apply_managed_embeddings_base_url(
|
||||
provider=embeddings_provider,
|
||||
base_url=getattr(user_config.embeddings, "base_url", None),
|
||||
)
|
||||
embeddings_endpoint = getattr(user_config.embeddings, "endpoint", None)
|
||||
embeddings_api_version = getattr(user_config.embeddings, "api_version", None)
|
||||
|
||||
|
|
|
|||
|
|
@ -10,8 +10,10 @@ from loguru import logger
|
|||
|
||||
from api.db import db_client
|
||||
from api.db.models import UserModel
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
get_effective_ai_model_configuration_for_workflow,
|
||||
)
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from api.services.configuration.resolve import resolve_effective_config
|
||||
from api.services.mps_service_key_client import mps_service_key_client
|
||||
|
||||
|
||||
|
|
@ -48,17 +50,20 @@ async def check_dograh_quota(
|
|||
if quota is insufficient.
|
||||
"""
|
||||
try:
|
||||
# Get user configurations
|
||||
user_config = await db_client.get_user_configurations(user.id)
|
||||
organization_id = user.selected_organization_id
|
||||
workflow_configurations = None
|
||||
|
||||
if workflow_id is not None:
|
||||
workflow = await db_client.get_workflow_by_id(workflow_id)
|
||||
if workflow:
|
||||
model_overrides = (workflow.workflow_configurations or {}).get(
|
||||
"model_overrides"
|
||||
)
|
||||
if model_overrides:
|
||||
user_config = resolve_effective_config(user_config, model_overrides)
|
||||
organization_id = workflow.organization_id
|
||||
workflow_configurations = workflow.workflow_configurations
|
||||
|
||||
user_config = await get_effective_ai_model_configuration_for_workflow(
|
||||
user_id=user.id,
|
||||
organization_id=organization_id,
|
||||
workflow_configurations=workflow_configurations,
|
||||
)
|
||||
|
||||
# Check if user is using any Dograh service
|
||||
using_dograh = False
|
||||
|
|
@ -76,6 +81,13 @@ async def check_dograh_quota(
|
|||
using_dograh = True
|
||||
dograh_api_keys.add(user_config.tts.api_key)
|
||||
|
||||
if (
|
||||
user_config.embeddings
|
||||
and user_config.embeddings.provider == ServiceProviders.DOGRAH
|
||||
):
|
||||
using_dograh = True
|
||||
dograh_api_keys.add(user_config.embeddings.api_key)
|
||||
|
||||
# If not using Dograh, quota check passes
|
||||
if not using_dograh:
|
||||
return QuotaCheckResult(has_quota=True)
|
||||
|
|
@ -84,7 +96,9 @@ async def check_dograh_quota(
|
|||
for api_key in dograh_api_keys:
|
||||
try:
|
||||
usage = await mps_service_key_client.check_service_key_usage(
|
||||
api_key, created_by=user.provider_id
|
||||
api_key,
|
||||
organization_id=organization_id,
|
||||
created_by=user.provider_id,
|
||||
)
|
||||
remaining = usage.get("remaining_credits", 0.0)
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@
|
|||
|
||||
import random
|
||||
|
||||
from api.db import db_client
|
||||
from api.db.models import WorkflowRunModel
|
||||
from api.services.workflow.dto import QANodeData
|
||||
|
||||
|
|
@ -54,7 +53,27 @@ async def resolve_user_llm_config(
|
|||
|
||||
llm_config: dict = {}
|
||||
if user_id:
|
||||
user_configuration = await db_client.get_user_configurations(user_id)
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
get_effective_ai_model_configuration_for_workflow,
|
||||
)
|
||||
|
||||
workflow_configurations = {}
|
||||
if workflow_run.definition:
|
||||
workflow_configurations = (
|
||||
workflow_run.definition.workflow_configurations or {}
|
||||
)
|
||||
elif workflow_run.workflow:
|
||||
workflow_configurations = (
|
||||
workflow_run.workflow.workflow_configurations or {}
|
||||
)
|
||||
|
||||
user_configuration = await get_effective_ai_model_configuration_for_workflow(
|
||||
user_id=user_id,
|
||||
organization_id=workflow_run.workflow.organization_id
|
||||
if workflow_run.workflow
|
||||
else None,
|
||||
workflow_configurations=workflow_configurations,
|
||||
)
|
||||
llm_config = user_configuration.model_dump(exclude_none=True).get("llm", {})
|
||||
|
||||
provider = llm_config.get("provider", "openai")
|
||||
|
|
|
|||
|
|
@ -32,7 +32,6 @@ from pipecat.utils.run_context import set_current_org_id
|
|||
|
||||
from api.db import db_client
|
||||
from api.enums import WorkflowRunMode, WorkflowRunState
|
||||
from api.services.configuration.resolve import resolve_effective_config
|
||||
from api.services.pipecat.audio_config import create_audio_config
|
||||
from api.services.pipecat.pipeline_builder import create_pipeline_task
|
||||
from api.services.pipecat.pipeline_metrics_aggregator import (
|
||||
|
|
@ -410,9 +409,14 @@ async def execute_text_chat_pending_turn(
|
|||
run_definition = workflow_run.definition
|
||||
run_configs = run_definition.workflow_configurations or {}
|
||||
|
||||
user_config = await db_client.get_user_configurations(workflow_run.workflow.user.id)
|
||||
user_config = resolve_effective_config(
|
||||
user_config, run_configs.get("model_overrides")
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
get_effective_ai_model_configuration_for_workflow,
|
||||
)
|
||||
|
||||
user_config = await get_effective_ai_model_configuration_for_workflow(
|
||||
user_id=workflow_run.workflow.user.id,
|
||||
organization_id=workflow.organization_id,
|
||||
workflow_configurations=run_configs,
|
||||
)
|
||||
if user_config.llm is None:
|
||||
raise ValueError("Text chat requires an LLM configuration")
|
||||
|
|
@ -466,9 +470,17 @@ async def execute_text_chat_pending_turn(
|
|||
embeddings_model = None
|
||||
embeddings_base_url = None
|
||||
if user_config.embeddings:
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
apply_managed_embeddings_base_url,
|
||||
)
|
||||
|
||||
embeddings_api_key = user_config.embeddings.api_key
|
||||
embeddings_model = user_config.embeddings.model
|
||||
embeddings_base_url = getattr(user_config.embeddings, "base_url", None)
|
||||
embeddings_provider = getattr(user_config.embeddings, "provider", None)
|
||||
embeddings_base_url = apply_managed_embeddings_base_url(
|
||||
provider=embeddings_provider,
|
||||
base_url=getattr(user_config.embeddings, "base_url", None),
|
||||
)
|
||||
|
||||
has_recordings = await db_client.has_active_recordings(workflow.organization_id)
|
||||
context_compaction_enabled = (workflow.workflow_configurations or {}).get(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue