mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-10 08:05:22 +02:00
feat: add config v2 to simplify billing (#428)
* feat: add model config v2 * chore: centralize user org selection * chore: move preferences to platform settings * fix: decouple org preference and ai model preferences
This commit is contained in:
parent
49e68b49d5
commit
cdbd06c8d9
42 changed files with 5135 additions and 264 deletions
|
|
@ -10,6 +10,7 @@ from sqlalchemy.orm import joinedload
|
|||
from api.db.base_client import BaseDBClient
|
||||
from api.db.filters import apply_workflow_run_filters
|
||||
from api.db.models import (
|
||||
OrganizationConfigurationModel,
|
||||
OrganizationModel,
|
||||
OrganizationUsageCycleModel,
|
||||
UserConfigurationModel,
|
||||
|
|
@ -17,6 +18,7 @@ from api.db.models import (
|
|||
WorkflowModel,
|
||||
WorkflowRunModel,
|
||||
)
|
||||
from api.enums import OrganizationConfigurationKey
|
||||
from api.schemas.user_configuration import UserConfiguration
|
||||
|
||||
|
||||
|
|
@ -440,8 +442,29 @@ class OrganizationUsageClient(BaseDBClient):
|
|||
"""Get daily usage breakdown for an organization with pricing."""
|
||||
|
||||
async with self.async_session() as session:
|
||||
# Get user timezone if user_id is provided
|
||||
# Get org timezone preference first, then fall back to legacy user config.
|
||||
user_timezone = "UTC" # Default timezone
|
||||
pref_result = await session.execute(
|
||||
select(OrganizationConfigurationModel).where(
|
||||
OrganizationConfigurationModel.organization_id == organization_id,
|
||||
OrganizationConfigurationModel.key.in_(
|
||||
[
|
||||
OrganizationConfigurationKey.ORGANIZATION_PREFERENCES.value,
|
||||
OrganizationConfigurationKey.MODEL_CONFIGURATION_PREFERENCES.value,
|
||||
]
|
||||
),
|
||||
)
|
||||
)
|
||||
pref_rows = pref_result.scalars().all()
|
||||
pref_by_key = {pref.key: pref for pref in pref_rows}
|
||||
pref_obj = pref_by_key.get(
|
||||
OrganizationConfigurationKey.ORGANIZATION_PREFERENCES.value
|
||||
) or pref_by_key.get(
|
||||
OrganizationConfigurationKey.MODEL_CONFIGURATION_PREFERENCES.value
|
||||
)
|
||||
if pref_obj and pref_obj.value:
|
||||
user_timezone = pref_obj.value.get("timezone") or user_timezone
|
||||
|
||||
if user_id:
|
||||
config_result = await session.execute(
|
||||
select(UserConfigurationModel).where(
|
||||
|
|
@ -453,7 +476,7 @@ class OrganizationUsageClient(BaseDBClient):
|
|||
user_config = UserConfiguration.model_validate(
|
||||
config_obj.configuration
|
||||
)
|
||||
if user_config.timezone:
|
||||
if user_config.timezone and user_timezone == "UTC":
|
||||
user_timezone = user_config.timezone
|
||||
|
||||
# Validate timezone string
|
||||
|
|
|
|||
|
|
@ -89,6 +89,11 @@ class OrganizationConfigurationKey(Enum):
|
|||
LANGFUSE_CREDENTIALS = (
|
||||
"LANGFUSE_CREDENTIALS" # Org-level Langfuse tracing credentials
|
||||
)
|
||||
MODEL_CONFIGURATION_V2 = (
|
||||
"MODEL_CONFIGURATION_V2" # Org-level v2 AI model configuration
|
||||
)
|
||||
ORGANIZATION_PREFERENCES = "ORGANIZATION_PREFERENCES" # Org-level defaults such as timezone/test call number
|
||||
MODEL_CONFIGURATION_PREFERENCES = "MODEL_CONFIGURATION_PREFERENCES" # Deprecated; read fallback for old org preferences
|
||||
|
||||
|
||||
class WorkflowStatus(Enum):
|
||||
|
|
|
|||
|
|
@ -3,9 +3,12 @@ from loguru import logger
|
|||
|
||||
from api.db import db_client
|
||||
from api.db.models import UserModel
|
||||
from api.enums import PostHogEvent
|
||||
from api.enums import OrganizationConfigurationKey, PostHogEvent
|
||||
from api.schemas.auth import AuthResponse, LoginRequest, SignupRequest, UserResponse
|
||||
from api.services.auth.depends import create_user_configuration_with_mps_key, get_user
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
convert_legacy_ai_model_configuration_to_v2,
|
||||
)
|
||||
from api.services.posthog_client import capture_event
|
||||
from api.utils.auth import create_jwt_token, hash_password, verify_password
|
||||
|
||||
|
|
@ -47,6 +50,12 @@ async def signup(request: SignupRequest):
|
|||
)
|
||||
if mps_config:
|
||||
await db_client.update_user_configuration(user.id, mps_config)
|
||||
model_config_v2 = convert_legacy_ai_model_configuration_to_v2(mps_config)
|
||||
await db_client.upsert_configuration(
|
||||
organization.id,
|
||||
OrganizationConfigurationKey.MODEL_CONFIGURATION_V2.value,
|
||||
model_config_v2.model_dump(mode="json", exclude_none=True),
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to create default configuration for OSS user", exc_info=True
|
||||
|
|
|
|||
|
|
@ -369,6 +369,10 @@ async def search_chunks(
|
|||
|
||||
try:
|
||||
# Import here to avoid circular dependency
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
apply_managed_embeddings_base_url,
|
||||
get_resolved_ai_model_configuration,
|
||||
)
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from api.services.gen_ai import (
|
||||
AzureOpenAIEmbeddingService,
|
||||
|
|
@ -376,10 +380,15 @@ async def search_chunks(
|
|||
)
|
||||
|
||||
# Try to get user's embeddings configuration
|
||||
user_config = await db_client.get_user_configurations(user.id)
|
||||
resolved_config = await get_resolved_ai_model_configuration(
|
||||
user_id=user.id,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
user_config = resolved_config.effective
|
||||
embeddings_api_key = None
|
||||
embeddings_model = None
|
||||
embeddings_provider = None
|
||||
embeddings_base_url = None
|
||||
embeddings_endpoint = None
|
||||
embeddings_api_version = None
|
||||
|
||||
|
|
@ -388,6 +397,10 @@ async def search_chunks(
|
|||
embeddings_model = user_config.embeddings.model
|
||||
embeddings_provider = getattr(user_config.embeddings, "provider", None)
|
||||
embeddings_endpoint = getattr(user_config.embeddings, "endpoint", None)
|
||||
embeddings_base_url = apply_managed_embeddings_base_url(
|
||||
provider=embeddings_provider,
|
||||
base_url=getattr(user_config.embeddings, "base_url", None),
|
||||
)
|
||||
embeddings_api_version = getattr(
|
||||
user_config.embeddings, "api_version", None
|
||||
)
|
||||
|
|
@ -406,9 +419,7 @@ async def search_chunks(
|
|||
db_client=db_client,
|
||||
api_key=embeddings_api_key,
|
||||
model_id=embeddings_model or "text-embedding-3-small",
|
||||
base_url=getattr(user_config.embeddings, "base_url", None)
|
||||
if user_config.embeddings
|
||||
else None,
|
||||
base_url=embeddings_base_url,
|
||||
)
|
||||
|
||||
# Perform search
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
|
@ -10,6 +10,14 @@ from api.db import db_client
|
|||
from api.db.models import UserModel
|
||||
from api.db.telephony_configuration_client import TelephonyConfigurationInUseError
|
||||
from api.enums import OrganizationConfigurationKey, PostHogEvent
|
||||
from api.schemas.ai_model_configuration import (
|
||||
DOGRAH_DEFAULT_LANGUAGE,
|
||||
DOGRAH_DEFAULT_VOICE,
|
||||
DOGRAH_SPEED_OPTIONS,
|
||||
OrganizationAIModelConfigurationResponse,
|
||||
OrganizationAIModelConfigurationV2,
|
||||
)
|
||||
from api.schemas.organization_preferences import OrganizationPreferences
|
||||
from api.schemas.telephony_config import (
|
||||
TelephonyConfigRequest,
|
||||
TelephonyConfigurationCreateRequest,
|
||||
|
|
@ -26,8 +34,31 @@ from api.schemas.telephony_phone_number import (
|
|||
PhoneNumberUpdateRequest,
|
||||
ProviderSyncStatus,
|
||||
)
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.configuration.masking import is_mask_of, mask_key
|
||||
from api.services.auth.depends import get_user, get_user_with_selected_organization
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
check_for_masked_keys_in_ai_model_configuration_v2,
|
||||
compile_ai_model_configuration_v2,
|
||||
convert_legacy_ai_model_configuration_to_v2,
|
||||
get_organization_ai_model_configuration_v2,
|
||||
get_resolved_ai_model_configuration,
|
||||
mask_ai_model_configuration_v2,
|
||||
merge_ai_model_configuration_v2_secrets,
|
||||
migrate_workflow_model_configurations_to_v2,
|
||||
upsert_organization_ai_model_configuration_v2,
|
||||
)
|
||||
from api.services.configuration.check_validity import UserConfigurationValidator
|
||||
from api.services.configuration.defaults import DEFAULT_SERVICE_PROVIDERS
|
||||
from api.services.configuration.masking import is_mask_of, mask_key, mask_user_config
|
||||
from api.services.configuration.registry import (
|
||||
DOGRAH_STT_LANGUAGES,
|
||||
REGISTRY,
|
||||
ServiceProviders,
|
||||
ServiceType,
|
||||
)
|
||||
from api.services.organization_preferences import (
|
||||
get_organization_preferences,
|
||||
upsert_organization_preferences,
|
||||
)
|
||||
from api.services.posthog_client import capture_event
|
||||
from api.services.telephony import registry as telephony_registry
|
||||
from api.services.telephony.factory import get_telephony_provider_by_id
|
||||
|
|
@ -159,6 +190,222 @@ async def get_telephony_config_warnings(user: UserModel = Depends(get_user)):
|
|||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AI model configurations v2
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _byok_provider_schemas(service_type: ServiceType) -> dict[str, dict]:
|
||||
return {
|
||||
provider: model_cls.model_json_schema()
|
||||
for provider, model_cls in REGISTRY[service_type].items()
|
||||
if provider != ServiceProviders.DOGRAH.value
|
||||
}
|
||||
|
||||
|
||||
async def _model_configuration_v2_response(
|
||||
*,
|
||||
user: UserModel,
|
||||
configuration: OrganizationAIModelConfigurationV2 | None = None,
|
||||
) -> OrganizationAIModelConfigurationResponse:
|
||||
resolved = await get_resolved_ai_model_configuration(
|
||||
user_id=user.id,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
raw_configuration = (
|
||||
configuration
|
||||
if configuration is not None
|
||||
else resolved.organization_configuration
|
||||
)
|
||||
return OrganizationAIModelConfigurationResponse(
|
||||
configuration=mask_ai_model_configuration_v2(raw_configuration),
|
||||
effective_configuration=mask_user_config(resolved.effective),
|
||||
source=resolved.source,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/model-configurations/v2/defaults")
|
||||
async def get_model_configuration_v2_defaults(
|
||||
user: UserModel = Depends(get_user_with_selected_organization),
|
||||
):
|
||||
byok_default_providers = {
|
||||
service: provider
|
||||
for service, provider in DEFAULT_SERVICE_PROVIDERS.items()
|
||||
if provider != ServiceProviders.DOGRAH.value
|
||||
}
|
||||
return {
|
||||
"dograh": {
|
||||
"voices": [DOGRAH_DEFAULT_VOICE],
|
||||
"speeds": list(DOGRAH_SPEED_OPTIONS),
|
||||
"languages": DOGRAH_STT_LANGUAGES,
|
||||
"defaults": {
|
||||
"voice": DOGRAH_DEFAULT_VOICE,
|
||||
"speed": 1.0,
|
||||
"language": DOGRAH_DEFAULT_LANGUAGE,
|
||||
},
|
||||
},
|
||||
"byok": {
|
||||
"pipeline": {
|
||||
"llm": _byok_provider_schemas(ServiceType.LLM),
|
||||
"tts": _byok_provider_schemas(ServiceType.TTS),
|
||||
"stt": _byok_provider_schemas(ServiceType.STT),
|
||||
"embeddings": _byok_provider_schemas(ServiceType.EMBEDDINGS),
|
||||
"default_providers": byok_default_providers,
|
||||
},
|
||||
"realtime": {
|
||||
"realtime": _byok_provider_schemas(ServiceType.REALTIME),
|
||||
"llm": _byok_provider_schemas(ServiceType.LLM),
|
||||
"embeddings": _byok_provider_schemas(ServiceType.EMBEDDINGS),
|
||||
"default_providers": byok_default_providers,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@router.get(
|
||||
"/model-configurations/v2",
|
||||
response_model=OrganizationAIModelConfigurationResponse,
|
||||
)
|
||||
async def get_model_configuration_v2(
|
||||
user: UserModel = Depends(get_user_with_selected_organization),
|
||||
):
|
||||
return await _model_configuration_v2_response(user=user)
|
||||
|
||||
|
||||
@router.put(
|
||||
"/model-configurations/v2",
|
||||
response_model=OrganizationAIModelConfigurationResponse,
|
||||
)
|
||||
async def save_model_configuration_v2(
|
||||
request: OrganizationAIModelConfigurationV2,
|
||||
user: UserModel = Depends(get_user_with_selected_organization),
|
||||
):
|
||||
organization_id = user.selected_organization_id
|
||||
existing = await get_organization_ai_model_configuration_v2(organization_id)
|
||||
configuration = merge_ai_model_configuration_v2_secrets(request, existing)
|
||||
try:
|
||||
check_for_masked_keys_in_ai_model_configuration_v2(configuration)
|
||||
effective = compile_ai_model_configuration_v2(configuration)
|
||||
await UserConfigurationValidator().validate(
|
||||
effective,
|
||||
organization_id=organization_id,
|
||||
created_by=user.provider_id,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=422, detail=exc.args[0])
|
||||
|
||||
await upsert_organization_ai_model_configuration_v2(
|
||||
organization_id,
|
||||
configuration,
|
||||
)
|
||||
return await _model_configuration_v2_response(
|
||||
user=user,
|
||||
configuration=configuration,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/model-configurations/v2/migration-preview")
|
||||
async def preview_model_configuration_v2_migration(
|
||||
user: UserModel = Depends(get_user_with_selected_organization),
|
||||
):
|
||||
legacy = await db_client.get_user_configurations(user.id)
|
||||
try:
|
||||
configuration = convert_legacy_ai_model_configuration_to_v2(legacy)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=422, detail=str(exc))
|
||||
return {
|
||||
"configuration": mask_ai_model_configuration_v2(configuration),
|
||||
"effective_configuration": mask_user_config(
|
||||
compile_ai_model_configuration_v2(configuration)
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@router.post(
|
||||
"/model-configurations/v2/migrate",
|
||||
response_model=OrganizationAIModelConfigurationResponse,
|
||||
)
|
||||
async def migrate_model_configuration_v2(
|
||||
force: bool = Query(default=False),
|
||||
user: UserModel = Depends(get_user_with_selected_organization),
|
||||
):
|
||||
organization_id = user.selected_organization_id
|
||||
existing = await get_organization_ai_model_configuration_v2(organization_id)
|
||||
if existing is not None and not force:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="Organization already has a v2 model configuration",
|
||||
)
|
||||
|
||||
legacy = await db_client.get_user_configurations(user.id)
|
||||
try:
|
||||
configuration = convert_legacy_ai_model_configuration_to_v2(legacy)
|
||||
effective = compile_ai_model_configuration_v2(configuration)
|
||||
await UserConfigurationValidator().validate(
|
||||
effective,
|
||||
organization_id=organization_id,
|
||||
created_by=user.provider_id,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=422, detail=exc.args[0])
|
||||
|
||||
await upsert_organization_ai_model_configuration_v2(
|
||||
organization_id,
|
||||
configuration,
|
||||
)
|
||||
await migrate_workflow_model_configurations_to_v2(
|
||||
organization_id=organization_id,
|
||||
fallback_user_config=legacy,
|
||||
)
|
||||
return await _model_configuration_v2_response(
|
||||
user=user,
|
||||
configuration=configuration,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/preferences", response_model=OrganizationPreferences)
|
||||
async def get_preferences(
|
||||
user: UserModel = Depends(get_user_with_selected_organization),
|
||||
):
|
||||
organization_id = user.selected_organization_id
|
||||
return await get_organization_preferences(organization_id)
|
||||
|
||||
|
||||
@router.put("/preferences", response_model=OrganizationPreferences)
|
||||
async def save_preferences(
|
||||
request: OrganizationPreferences,
|
||||
user: UserModel = Depends(get_user_with_selected_organization),
|
||||
):
|
||||
organization_id = user.selected_organization_id
|
||||
return await upsert_organization_preferences(
|
||||
organization_id,
|
||||
request,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/model-configurations/preferences",
|
||||
response_model=OrganizationPreferences,
|
||||
include_in_schema=False,
|
||||
)
|
||||
async def get_model_configuration_preferences_legacy(
|
||||
user: UserModel = Depends(get_user_with_selected_organization),
|
||||
):
|
||||
return await get_preferences(user=user)
|
||||
|
||||
|
||||
@router.put(
|
||||
"/model-configurations/preferences",
|
||||
response_model=OrganizationPreferences,
|
||||
include_in_schema=False,
|
||||
)
|
||||
async def save_model_configuration_preferences_legacy(
|
||||
request: OrganizationPreferences,
|
||||
user: UserModel = Depends(get_user_with_selected_organization),
|
||||
):
|
||||
return await save_preferences(request=request, user=user)
|
||||
|
||||
|
||||
def preserve_masked_fields(provider: str, request_dict: dict, existing: dict):
|
||||
"""If the client re-submitted a masked sensitive field, restore the original."""
|
||||
for field_name in _sensitive_fields(provider):
|
||||
|
|
|
|||
|
|
@ -53,7 +53,7 @@ class InitiateCallRequest(BaseModel):
|
|||
workflow_run_id: int | None = None
|
||||
phone_number: str | None = None
|
||||
# Optional explicit telephony config to use for the test call. If omitted,
|
||||
# falls back to the user's per-user default (when set), then the org default.
|
||||
# falls back to the org default.
|
||||
telephony_configuration_id: int | None = None
|
||||
# Optional caller-ID phone number to dial out from. Must belong to the
|
||||
# resolved telephony configuration; otherwise the provider picks one.
|
||||
|
|
@ -82,7 +82,12 @@ async def initiate_call(
|
|||
"""Initiate a call using the configured telephony provider from web browser. This is
|
||||
supposed to be a test call method for the draft version of the agent."""
|
||||
|
||||
user_configuration = await db_client.get_user_configurations(user.id)
|
||||
from api.services.organization_preferences import get_organization_preferences
|
||||
|
||||
preferences = await get_organization_preferences(
|
||||
user.selected_organization_id,
|
||||
db=db_client,
|
||||
)
|
||||
|
||||
# Resolve which telephony config to use: explicit request value, otherwise
|
||||
# the org's default outbound config.
|
||||
|
|
@ -116,13 +121,12 @@ async def initiate_call(
|
|||
detail="telephony_not_configured",
|
||||
)
|
||||
|
||||
phone_number = request.phone_number or user_configuration.test_phone_number
|
||||
phone_number = request.phone_number or preferences.test_phone_number
|
||||
|
||||
if not phone_number:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Phone number must be provided in request or set in user "
|
||||
"configuration",
|
||||
detail="Phone number must be provided in request or set in organization preferences",
|
||||
)
|
||||
|
||||
workflow = await db_client.get_workflow(
|
||||
|
|
|
|||
|
|
@ -10,6 +10,9 @@ from api.db.models import (
|
|||
UserModel,
|
||||
)
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
get_resolved_ai_model_configuration,
|
||||
)
|
||||
from api.services.configuration.check_validity import (
|
||||
APIKeyStatusResponse,
|
||||
UserConfigurationValidator,
|
||||
|
|
@ -19,6 +22,10 @@ from api.services.configuration.masking import check_for_masked_keys, mask_user_
|
|||
from api.services.configuration.merge import merge_user_configurations
|
||||
from api.services.configuration.registry import REGISTRY, ServiceType
|
||||
from api.services.mps_service_key_client import mps_service_key_client
|
||||
from api.services.organization_preferences import (
|
||||
get_organization_preferences,
|
||||
upsert_organization_preferences,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/user")
|
||||
|
||||
|
|
@ -91,8 +98,17 @@ class UserConfigurationRequestResponseSchema(BaseModel):
|
|||
async def get_user_configurations(
|
||||
user: UserModel = Depends(get_user),
|
||||
) -> UserConfigurationRequestResponseSchema:
|
||||
user_configurations = await db_client.get_user_configurations(user.id)
|
||||
masked_config = mask_user_config(user_configurations)
|
||||
resolved_config = await get_resolved_ai_model_configuration(
|
||||
user_id=user.id,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
masked_config = mask_user_config(resolved_config.effective)
|
||||
if user.selected_organization_id:
|
||||
preferences = await get_organization_preferences(user.selected_organization_id)
|
||||
if preferences.test_phone_number is not None:
|
||||
masked_config["test_phone_number"] = preferences.test_phone_number
|
||||
if preferences.timezone is not None:
|
||||
masked_config["timezone"] = preferences.timezone
|
||||
|
||||
# Add organization pricing info if available
|
||||
if user.selected_organization_id:
|
||||
|
|
@ -118,34 +134,61 @@ async def update_user_configurations(
|
|||
|
||||
# Remove organization_pricing from incoming dict as it's read-only
|
||||
incoming_dict.pop("organization_pricing", None)
|
||||
preferences_update = {
|
||||
key: incoming_dict.pop(key)
|
||||
for key in ("test_phone_number", "timezone")
|
||||
if key in incoming_dict
|
||||
}
|
||||
|
||||
# Merge via helper
|
||||
try:
|
||||
user_configurations = merge_user_configurations(existing_config, incoming_dict)
|
||||
except ValidationError as e:
|
||||
raise HTTPException(status_code=422, detail=str(e))
|
||||
if incoming_dict:
|
||||
# Merge via helper
|
||||
try:
|
||||
user_configurations = merge_user_configurations(
|
||||
existing_config, incoming_dict
|
||||
)
|
||||
except ValidationError as e:
|
||||
raise HTTPException(status_code=422, detail=str(e))
|
||||
|
||||
try:
|
||||
check_for_masked_keys(user_configurations)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
try:
|
||||
check_for_masked_keys(user_configurations)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
try:
|
||||
validator = UserConfigurationValidator()
|
||||
await validator.validate(
|
||||
user_configurations,
|
||||
organization_id=user.selected_organization_id,
|
||||
created_by=user.provider_id,
|
||||
try:
|
||||
validator = UserConfigurationValidator()
|
||||
await validator.validate(
|
||||
user_configurations,
|
||||
organization_id=user.selected_organization_id,
|
||||
created_by=user.provider_id,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=422, detail=e.args[0])
|
||||
|
||||
user_configurations = await db_client.update_user_configuration(
|
||||
user.id, user_configurations
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=422, detail=e.args[0])
|
||||
else:
|
||||
user_configurations = existing_config
|
||||
|
||||
user_configurations = await db_client.update_user_configuration(
|
||||
user.id, user_configurations
|
||||
)
|
||||
if user.selected_organization_id and preferences_update:
|
||||
preferences = await get_organization_preferences(user.selected_organization_id)
|
||||
if "test_phone_number" in preferences_update:
|
||||
preferences.test_phone_number = preferences_update["test_phone_number"]
|
||||
if "timezone" in preferences_update:
|
||||
preferences.timezone = preferences_update["timezone"]
|
||||
await upsert_organization_preferences(
|
||||
user.selected_organization_id,
|
||||
preferences,
|
||||
)
|
||||
|
||||
# Return masked version of updated config
|
||||
masked_config = mask_user_config(user_configurations)
|
||||
if user.selected_organization_id:
|
||||
preferences = await get_organization_preferences(user.selected_organization_id)
|
||||
if preferences.test_phone_number is not None:
|
||||
masked_config["test_phone_number"] = preferences.test_phone_number
|
||||
if preferences.timezone is not None:
|
||||
masked_config["timezone"] = preferences.timezone
|
||||
|
||||
# Add organization pricing info if available
|
||||
if user.selected_organization_id:
|
||||
|
|
@ -165,7 +208,11 @@ async def validate_user_configurations(
|
|||
validity_ttl_seconds: int = Query(default=60, ge=0, le=86400),
|
||||
user: UserModel = Depends(get_user),
|
||||
) -> APIKeyStatusResponse:
|
||||
configurations = await db_client.get_user_configurations(user.id)
|
||||
resolved_config = await get_resolved_ai_model_configuration(
|
||||
user_id=user.id,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
configurations = resolved_config.effective
|
||||
|
||||
if (
|
||||
configurations.last_validated_at
|
||||
|
|
|
|||
|
|
@ -16,9 +16,18 @@ from api.db.agent_trigger_client import TriggerPathConflictError
|
|||
from api.db.models import UserModel
|
||||
from api.db.workflow_template_client import WorkflowTemplateClient
|
||||
from api.enums import CallType, PostHogEvent, StorageBackend
|
||||
from api.schemas.ai_model_configuration import OrganizationAIModelConfigurationV2
|
||||
from api.schemas.workflow import WorkflowRunResponseSchema
|
||||
from api.sdk_expose import sdk_expose
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY,
|
||||
check_for_masked_keys_in_ai_model_configuration_v2,
|
||||
compile_ai_model_configuration_v2,
|
||||
convert_legacy_ai_model_configuration_to_v2,
|
||||
get_resolved_ai_model_configuration,
|
||||
merge_ai_model_configuration_v2_secrets,
|
||||
)
|
||||
from api.services.configuration.check_validity import UserConfigurationValidator
|
||||
from api.services.configuration.masking import (
|
||||
mask_workflow_configurations,
|
||||
|
|
@ -955,12 +964,74 @@ async def update_workflow(
|
|||
existing_def,
|
||||
)
|
||||
|
||||
# Validate model_overrides: resolve onto global config, then
|
||||
# run the same validator used by the user-configurations endpoint.
|
||||
# Also stamp the current global API key into the override so the override
|
||||
# remains functional if the global config later switches to a different provider.
|
||||
# Validate model overrides. v2 uses a complete workflow-level model
|
||||
# configuration; legacy v1 uses partial service overlays.
|
||||
workflow_configurations = request.workflow_configurations
|
||||
if workflow_configurations and workflow_configurations.get("model_overrides"):
|
||||
if workflow_configurations and workflow_configurations.get(
|
||||
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY
|
||||
):
|
||||
existing_workflow = await db_client.get_workflow(
|
||||
workflow_id, organization_id=user.selected_organization_id
|
||||
)
|
||||
if existing_workflow is None:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Workflow with id {workflow_id} not found"
|
||||
)
|
||||
existing_draft = await db_client.get_draft_version(workflow_id)
|
||||
existing_configs = (
|
||||
existing_draft.workflow_configurations
|
||||
if existing_draft
|
||||
else existing_workflow.released_definition.workflow_configurations
|
||||
)
|
||||
existing_v2_override = (existing_configs or {}).get(
|
||||
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY
|
||||
)
|
||||
try:
|
||||
incoming_v2_override = (
|
||||
OrganizationAIModelConfigurationV2.model_validate(
|
||||
workflow_configurations[
|
||||
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY
|
||||
]
|
||||
)
|
||||
)
|
||||
existing_v2_override_config = (
|
||||
OrganizationAIModelConfigurationV2.model_validate(
|
||||
existing_v2_override
|
||||
)
|
||||
if existing_v2_override
|
||||
else None
|
||||
)
|
||||
v2_override = merge_ai_model_configuration_v2_secrets(
|
||||
incoming_v2_override,
|
||||
existing_v2_override_config,
|
||||
)
|
||||
if existing_v2_override_config is None:
|
||||
resolved_config = await get_resolved_ai_model_configuration(
|
||||
user_id=user.id,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
v2_override = merge_ai_model_configuration_v2_secrets(
|
||||
v2_override,
|
||||
resolved_config.organization_configuration,
|
||||
)
|
||||
check_for_masked_keys_in_ai_model_configuration_v2(v2_override)
|
||||
effective = compile_ai_model_configuration_v2(v2_override)
|
||||
await UserConfigurationValidator().validate(
|
||||
effective,
|
||||
organization_id=user.selected_organization_id,
|
||||
created_by=user.provider_id,
|
||||
)
|
||||
except (ValidationError, ValueError) as e:
|
||||
raise HTTPException(status_code=422, detail=str(e))
|
||||
workflow_configurations = {
|
||||
**workflow_configurations,
|
||||
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY: v2_override.model_dump(
|
||||
mode="json",
|
||||
exclude_none=True,
|
||||
),
|
||||
}
|
||||
workflow_configurations.pop("model_overrides", None)
|
||||
elif workflow_configurations and workflow_configurations.get("model_overrides"):
|
||||
existing_workflow = await db_client.get_workflow(
|
||||
workflow_id, organization_id=user.selected_organization_id
|
||||
)
|
||||
|
|
@ -978,24 +1049,46 @@ async def update_workflow(
|
|||
workflow_configurations,
|
||||
existing_configs,
|
||||
)
|
||||
user_config = await db_client.get_user_configurations(user.id)
|
||||
resolved_config = await get_resolved_ai_model_configuration(
|
||||
user_id=user.id,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
user_config = resolved_config.effective
|
||||
try:
|
||||
enriched_overrides = enrich_overrides_with_api_keys(
|
||||
workflow_configurations["model_overrides"],
|
||||
user_config,
|
||||
)
|
||||
effective = resolve_effective_config(user_config, enriched_overrides)
|
||||
await UserConfigurationValidator().validate(
|
||||
effective,
|
||||
organization_id=user.selected_organization_id,
|
||||
created_by=user.provider_id,
|
||||
)
|
||||
if resolved_config.source == "organization_v2":
|
||||
v2_override = convert_legacy_ai_model_configuration_to_v2(effective)
|
||||
await UserConfigurationValidator().validate(
|
||||
compile_ai_model_configuration_v2(v2_override),
|
||||
organization_id=user.selected_organization_id,
|
||||
created_by=user.provider_id,
|
||||
)
|
||||
else:
|
||||
await UserConfigurationValidator().validate(
|
||||
effective,
|
||||
organization_id=user.selected_organization_id,
|
||||
created_by=user.provider_id,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=422, detail=str(e))
|
||||
workflow_configurations = {
|
||||
**workflow_configurations,
|
||||
"model_overrides": enriched_overrides,
|
||||
}
|
||||
if resolved_config.source == "organization_v2":
|
||||
workflow_configurations = {
|
||||
**workflow_configurations,
|
||||
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY: v2_override.model_dump(
|
||||
mode="json",
|
||||
exclude_none=True,
|
||||
),
|
||||
}
|
||||
workflow_configurations.pop("model_overrides", None)
|
||||
else:
|
||||
workflow_configurations = {
|
||||
**workflow_configurations,
|
||||
"model_overrides": enriched_overrides,
|
||||
}
|
||||
|
||||
# Reject upfront if any new trigger path collides with another
|
||||
# workflow's trigger — keeps the workflow record from
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from pydantic import BaseModel, Field
|
|||
from api.db import db_client
|
||||
from api.db.models import UserModel, WorkflowRunTextSessionModel
|
||||
from api.enums import WorkflowRunMode
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.auth.depends import get_user_with_selected_organization
|
||||
from api.services.quota_service import check_dograh_quota
|
||||
from api.services.workflow.text_chat_session_service import (
|
||||
TextChatPendingTurnLostError,
|
||||
|
|
@ -96,12 +96,6 @@ def _revision_conflict_detail(e: Any) -> dict[str, Any]:
|
|||
}
|
||||
|
||||
|
||||
def _require_selected_organization_id(user: UserModel) -> int:
|
||||
if user.selected_organization_id is None:
|
||||
raise HTTPException(status_code=403, detail="Organization context is required")
|
||||
return user.selected_organization_id
|
||||
|
||||
|
||||
async def _ensure_text_chat_quota(user: UserModel, workflow_id: int) -> None:
|
||||
quota_result = await check_dograh_quota(user, workflow_id=workflow_id)
|
||||
if not quota_result.has_quota:
|
||||
|
|
@ -114,9 +108,8 @@ async def _load_text_session_or_404(
|
|||
user: UserModel,
|
||||
) -> WorkflowRunTextSessionModel:
|
||||
set_current_run_id(run_id)
|
||||
organization_id = _require_selected_organization_id(user)
|
||||
text_session = await db_client.get_workflow_run_text_session(
|
||||
run_id, organization_id=organization_id
|
||||
run_id, organization_id=user.selected_organization_id
|
||||
)
|
||||
if not text_session or not text_session.workflow_run:
|
||||
raise HTTPException(status_code=404, detail="Text chat session not found")
|
||||
|
|
@ -158,9 +151,8 @@ async def _execute_pending_turn_response(
|
|||
async def create_text_chat_session(
|
||||
workflow_id: int,
|
||||
request: CreateTextChatSessionRequest,
|
||||
user: UserModel = Depends(get_user),
|
||||
user: UserModel = Depends(get_user_with_selected_organization),
|
||||
) -> WorkflowRunTextSessionResponse:
|
||||
organization_id = _require_selected_organization_id(user)
|
||||
await _ensure_text_chat_quota(user, workflow_id)
|
||||
|
||||
session_name = request.name or f"WR-TEXT-{uuid4().hex[:6].upper()}"
|
||||
|
|
@ -172,7 +164,7 @@ async def create_text_chat_session(
|
|||
user_id=user.id,
|
||||
initial_context=request.initial_context,
|
||||
use_draft=True,
|
||||
organization_id=organization_id,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
|
@ -220,7 +212,7 @@ async def create_text_chat_session(
|
|||
async def get_text_chat_session(
|
||||
workflow_id: int,
|
||||
run_id: int,
|
||||
user: UserModel = Depends(get_user),
|
||||
user: UserModel = Depends(get_user_with_selected_organization),
|
||||
) -> WorkflowRunTextSessionResponse:
|
||||
text_session = await _load_text_session_or_404(workflow_id, run_id, user)
|
||||
return _build_response(text_session)
|
||||
|
|
@ -234,7 +226,7 @@ async def append_text_chat_message(
|
|||
workflow_id: int,
|
||||
run_id: int,
|
||||
request: AppendTextChatMessageRequest,
|
||||
user: UserModel = Depends(get_user),
|
||||
user: UserModel = Depends(get_user_with_selected_organization),
|
||||
) -> WorkflowRunTextSessionResponse:
|
||||
text_session = await _load_text_session_or_404(workflow_id, run_id, user)
|
||||
await _ensure_text_chat_quota(user, workflow_id)
|
||||
|
|
@ -264,7 +256,7 @@ async def rewind_text_chat_session(
|
|||
workflow_id: int,
|
||||
run_id: int,
|
||||
request: RewindTextChatSessionRequest,
|
||||
user: UserModel = Depends(get_user),
|
||||
user: UserModel = Depends(get_user_with_selected_organization),
|
||||
) -> WorkflowRunTextSessionResponse:
|
||||
text_session = await _load_text_session_or_404(workflow_id, run_id, user)
|
||||
try:
|
||||
|
|
|
|||
170
api/schemas/ai_model_configuration.py
Normal file
170
api/schemas/ai_model_configuration.py
Normal file
|
|
@ -0,0 +1,170 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from api.schemas.user_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.configuration.registry import (
|
||||
DograhEmbeddingsConfiguration,
|
||||
DograhLLMService,
|
||||
DograhSTTService,
|
||||
DograhTTSService,
|
||||
EmbeddingsConfig,
|
||||
LLMConfig,
|
||||
RealtimeConfig,
|
||||
ServiceProviders,
|
||||
STTConfig,
|
||||
TTSConfig,
|
||||
)
|
||||
|
||||
DOGRAH_SPEED_OPTIONS: tuple[float, ...] = (0.8, 1.0, 1.2)
|
||||
DOGRAH_DEFAULT_VOICE = "default"
|
||||
DOGRAH_DEFAULT_LANGUAGE = "multi"
|
||||
|
||||
|
||||
class DograhManagedAIModelConfiguration(BaseModel):
|
||||
api_key: str
|
||||
voice: str = DOGRAH_DEFAULT_VOICE
|
||||
speed: float = Field(default=1.0)
|
||||
language: str = DOGRAH_DEFAULT_LANGUAGE
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_speed(self):
|
||||
if self.speed not in DOGRAH_SPEED_OPTIONS:
|
||||
allowed = ", ".join(str(speed) for speed in DOGRAH_SPEED_OPTIONS)
|
||||
raise ValueError(f"Dograh speed must be one of: {allowed}")
|
||||
return self
|
||||
|
||||
|
||||
class BYOKPipelineAIModelConfiguration(BaseModel):
|
||||
llm: LLMConfig
|
||||
tts: TTSConfig
|
||||
stt: STTConfig
|
||||
embeddings: EmbeddingsConfig | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def reject_dograh_providers(self):
|
||||
_reject_dograh_provider("llm", self.llm)
|
||||
_reject_dograh_provider("tts", self.tts)
|
||||
_reject_dograh_provider("stt", self.stt)
|
||||
_reject_dograh_provider("embeddings", self.embeddings)
|
||||
return self
|
||||
|
||||
|
||||
class BYOKRealtimeAIModelConfiguration(BaseModel):
|
||||
realtime: RealtimeConfig
|
||||
llm: LLMConfig
|
||||
embeddings: EmbeddingsConfig | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def reject_dograh_providers(self):
|
||||
_reject_dograh_provider("llm", self.llm)
|
||||
_reject_dograh_provider("embeddings", self.embeddings)
|
||||
return self
|
||||
|
||||
|
||||
class BYOKAIModelConfiguration(BaseModel):
|
||||
mode: Literal["pipeline", "realtime"]
|
||||
pipeline: BYOKPipelineAIModelConfiguration | None = None
|
||||
realtime: BYOKRealtimeAIModelConfiguration | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_selected_mode(self):
|
||||
if self.mode == "pipeline" and self.pipeline is None:
|
||||
raise ValueError("byok.pipeline is required when byok.mode is pipeline")
|
||||
if self.mode == "realtime" and self.realtime is None:
|
||||
raise ValueError("byok.realtime is required when byok.mode is realtime")
|
||||
return self
|
||||
|
||||
|
||||
class OrganizationAIModelConfigurationV2(BaseModel):
|
||||
version: Literal[2] = 2
|
||||
mode: Literal["dograh", "byok"]
|
||||
dograh: DograhManagedAIModelConfiguration | None = None
|
||||
byok: BYOKAIModelConfiguration | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_selected_mode(self):
|
||||
if self.mode == "dograh" and self.dograh is None:
|
||||
raise ValueError("dograh configuration is required when mode is dograh")
|
||||
if self.mode == "byok" and self.byok is None:
|
||||
raise ValueError("byok configuration is required when mode is byok")
|
||||
return self
|
||||
|
||||
|
||||
class OrganizationAIModelConfigurationResponse(BaseModel):
|
||||
configuration: dict | None
|
||||
effective_configuration: dict
|
||||
source: Literal["organization_v2", "legacy_user_v1", "empty"]
|
||||
|
||||
|
||||
def compile_ai_model_configuration_v2(
|
||||
configuration: OrganizationAIModelConfigurationV2,
|
||||
) -> EffectiveAIModelConfiguration:
|
||||
if configuration.mode == "dograh":
|
||||
if configuration.dograh is None:
|
||||
raise ValueError("dograh configuration is required")
|
||||
return _compile_dograh_configuration(configuration.dograh)
|
||||
|
||||
if configuration.byok is None:
|
||||
raise ValueError("byok configuration is required")
|
||||
if configuration.byok.mode == "pipeline":
|
||||
if configuration.byok.pipeline is None:
|
||||
raise ValueError("byok.pipeline is required")
|
||||
pipeline = configuration.byok.pipeline
|
||||
return EffectiveAIModelConfiguration(
|
||||
llm=pipeline.llm,
|
||||
tts=pipeline.tts,
|
||||
stt=pipeline.stt,
|
||||
embeddings=pipeline.embeddings,
|
||||
is_realtime=False,
|
||||
)
|
||||
|
||||
if configuration.byok.realtime is None:
|
||||
raise ValueError("byok.realtime is required")
|
||||
realtime = configuration.byok.realtime
|
||||
return EffectiveAIModelConfiguration(
|
||||
llm=realtime.llm,
|
||||
realtime=realtime.realtime,
|
||||
embeddings=realtime.embeddings,
|
||||
is_realtime=True,
|
||||
)
|
||||
|
||||
|
||||
def _compile_dograh_configuration(
|
||||
configuration: DograhManagedAIModelConfiguration,
|
||||
) -> EffectiveAIModelConfiguration:
|
||||
return EffectiveAIModelConfiguration(
|
||||
llm=DograhLLMService(
|
||||
provider=ServiceProviders.DOGRAH,
|
||||
api_key=configuration.api_key,
|
||||
model="default",
|
||||
),
|
||||
tts=DograhTTSService(
|
||||
provider=ServiceProviders.DOGRAH,
|
||||
api_key=configuration.api_key,
|
||||
model="default",
|
||||
voice=configuration.voice,
|
||||
speed=configuration.speed,
|
||||
),
|
||||
stt=DograhSTTService(
|
||||
provider=ServiceProviders.DOGRAH,
|
||||
api_key=configuration.api_key,
|
||||
model="default",
|
||||
language=configuration.language,
|
||||
),
|
||||
embeddings=DograhEmbeddingsConfiguration(
|
||||
provider=ServiceProviders.DOGRAH,
|
||||
api_key=configuration.api_key,
|
||||
model="default",
|
||||
),
|
||||
is_realtime=False,
|
||||
)
|
||||
|
||||
|
||||
def _reject_dograh_provider(section: str, service) -> None:
|
||||
if service is None:
|
||||
return
|
||||
if getattr(service, "provider", None) == ServiceProviders.DOGRAH:
|
||||
raise ValueError(f"BYOK {section} cannot use Dograh provider")
|
||||
6
api/schemas/organization_preferences.py
Normal file
6
api/schemas/organization_preferences.py
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class OrganizationPreferences(BaseModel):
|
||||
test_phone_number: str | None = None
|
||||
timezone: str | None = None
|
||||
|
|
@ -11,7 +11,7 @@ from api.services.configuration.registry import (
|
|||
)
|
||||
|
||||
|
||||
class UserConfiguration(BaseModel):
|
||||
class EffectiveAIModelConfiguration(BaseModel):
|
||||
llm: LLMConfig | None = None
|
||||
stt: STTConfig | None = None
|
||||
tts: TTSConfig | None = None
|
||||
|
|
@ -31,3 +31,7 @@ class UserConfiguration(BaseModel):
|
|||
if isinstance(realtime, dict) and not realtime.get("api_key"):
|
||||
data.pop("realtime", None)
|
||||
return data
|
||||
|
||||
|
||||
# Backward-compatible alias for legacy persistence and existing call sites.
|
||||
UserConfiguration = EffectiveAIModelConfiguration
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from typing import Annotated, Optional
|
||||
|
||||
import httpx
|
||||
from fastapi import Header, HTTPException, Query, WebSocket
|
||||
from fastapi import Depends, Header, HTTPException, Query, WebSocket
|
||||
from loguru import logger
|
||||
from pydantic import ValidationError
|
||||
|
||||
|
|
@ -119,6 +119,19 @@ async def get_user(
|
|||
await db_client.update_user_configuration(
|
||||
user_model.id, mps_config
|
||||
)
|
||||
from api.enums import OrganizationConfigurationKey
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
convert_legacy_ai_model_configuration_to_v2,
|
||||
)
|
||||
|
||||
model_config_v2 = convert_legacy_ai_model_configuration_to_v2(
|
||||
mps_config
|
||||
)
|
||||
await db_client.upsert_configuration(
|
||||
organization.id,
|
||||
OrganizationConfigurationKey.MODEL_CONFIGURATION_V2.value,
|
||||
model_config_v2.model_dump(mode="json", exclude_none=True),
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
raise HTTPException(
|
||||
|
|
@ -129,6 +142,14 @@ async def get_user(
|
|||
return user_model
|
||||
|
||||
|
||||
async def get_user_with_selected_organization(
|
||||
user: Annotated[UserModel, Depends(get_user)],
|
||||
) -> UserModel:
|
||||
if not user.selected_organization_id:
|
||||
raise HTTPException(status_code=400, detail="No organization selected")
|
||||
return user
|
||||
|
||||
|
||||
async def _handle_oss_auth(authorization: str | None) -> UserModel:
|
||||
"""
|
||||
Handle authentication for OSS deployment mode.
|
||||
|
|
|
|||
484
api/services/configuration/ai_model_configuration.py
Normal file
484
api/services/configuration/ai_model_configuration.py
Normal file
|
|
@ -0,0 +1,484 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from api.constants import MPS_API_URL
|
||||
from api.db import db_client
|
||||
from api.db.models import WorkflowDefinitionModel, WorkflowModel
|
||||
from api.enums import OrganizationConfigurationKey
|
||||
from api.schemas.ai_model_configuration import (
|
||||
DOGRAH_DEFAULT_LANGUAGE,
|
||||
DOGRAH_DEFAULT_VOICE,
|
||||
DOGRAH_SPEED_OPTIONS,
|
||||
BYOKAIModelConfiguration,
|
||||
BYOKPipelineAIModelConfiguration,
|
||||
BYOKRealtimeAIModelConfiguration,
|
||||
DograhManagedAIModelConfiguration,
|
||||
OrganizationAIModelConfigurationV2,
|
||||
compile_ai_model_configuration_v2,
|
||||
)
|
||||
from api.schemas.user_configuration import EffectiveAIModelConfiguration
|
||||
from api.services.configuration.masking import (
|
||||
SERVICE_SECRET_FIELDS,
|
||||
contains_masked_key,
|
||||
mask_key,
|
||||
resolve_masked_api_keys,
|
||||
)
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from api.services.configuration.resolve import resolve_effective_config
|
||||
|
||||
AIModelConfigurationSource = Literal["organization_v2", "legacy_user_v1", "empty"]
|
||||
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY = "model_configuration_v2_override"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResolvedAIModelConfiguration:
|
||||
effective: EffectiveAIModelConfiguration
|
||||
source: AIModelConfigurationSource
|
||||
organization_configuration: OrganizationAIModelConfigurationV2 | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class WorkflowAIModelConfigurationMigrationResult:
|
||||
workflow_count: int = 0
|
||||
definition_count: int = 0
|
||||
workflow_ids: list[int] | None = None
|
||||
|
||||
|
||||
async def get_resolved_ai_model_configuration(
|
||||
*,
|
||||
user_id: int | None,
|
||||
organization_id: int | None,
|
||||
) -> ResolvedAIModelConfiguration:
|
||||
organization_configuration = await get_organization_ai_model_configuration_v2(
|
||||
organization_id
|
||||
)
|
||||
if organization_configuration is not None:
|
||||
return ResolvedAIModelConfiguration(
|
||||
effective=compile_ai_model_configuration_v2(organization_configuration),
|
||||
source="organization_v2",
|
||||
organization_configuration=organization_configuration,
|
||||
)
|
||||
|
||||
if user_id is None:
|
||||
return ResolvedAIModelConfiguration(
|
||||
effective=EffectiveAIModelConfiguration(),
|
||||
source="empty",
|
||||
)
|
||||
|
||||
legacy = await db_client.get_user_configurations(user_id)
|
||||
return ResolvedAIModelConfiguration(
|
||||
effective=legacy,
|
||||
source="legacy_user_v1" if _has_model_services(legacy) else "empty",
|
||||
)
|
||||
|
||||
|
||||
async def get_effective_ai_model_configuration_for_workflow(
|
||||
*,
|
||||
user_id: int | None,
|
||||
organization_id: int | None,
|
||||
workflow_configurations: dict | None,
|
||||
) -> EffectiveAIModelConfiguration:
|
||||
workflow_configurations = workflow_configurations or {}
|
||||
v2_override = workflow_configurations.get(
|
||||
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY
|
||||
)
|
||||
if v2_override:
|
||||
return compile_ai_model_configuration_v2(
|
||||
OrganizationAIModelConfigurationV2.model_validate(v2_override)
|
||||
)
|
||||
|
||||
resolved_config = await get_resolved_ai_model_configuration(
|
||||
user_id=user_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
return resolve_effective_config(
|
||||
resolved_config.effective,
|
||||
workflow_configurations.get("model_overrides"),
|
||||
)
|
||||
|
||||
|
||||
async def get_organization_ai_model_configuration_v2(
|
||||
organization_id: int | None,
|
||||
) -> OrganizationAIModelConfigurationV2 | None:
|
||||
if organization_id is None:
|
||||
return None
|
||||
row = await db_client.get_configuration(
|
||||
organization_id,
|
||||
OrganizationConfigurationKey.MODEL_CONFIGURATION_V2.value,
|
||||
)
|
||||
if row is None or not row.value:
|
||||
return None
|
||||
try:
|
||||
return OrganizationAIModelConfigurationV2.model_validate(row.value)
|
||||
except ValidationError as exc:
|
||||
logger.warning(
|
||||
"Invalid org AI model configuration v2 for organization "
|
||||
f"{organization_id}: {exc}. Falling back to legacy configuration."
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
async def upsert_organization_ai_model_configuration_v2(
|
||||
organization_id: int,
|
||||
configuration: OrganizationAIModelConfigurationV2,
|
||||
) -> OrganizationAIModelConfigurationV2:
|
||||
await db_client.upsert_configuration(
|
||||
organization_id,
|
||||
OrganizationConfigurationKey.MODEL_CONFIGURATION_V2.value,
|
||||
configuration.model_dump(mode="json", exclude_none=True),
|
||||
)
|
||||
return configuration
|
||||
|
||||
|
||||
async def migrate_workflow_model_configurations_to_v2(
|
||||
*,
|
||||
organization_id: int,
|
||||
fallback_user_config: EffectiveAIModelConfiguration,
|
||||
) -> WorkflowAIModelConfigurationMigrationResult:
|
||||
workflows = await _list_workflows_for_model_configuration_migration(organization_id)
|
||||
owner_configs: dict[int, EffectiveAIModelConfiguration] = {}
|
||||
workflow_updates: list[tuple[int, dict]] = []
|
||||
definition_updates: list[tuple[int, dict]] = []
|
||||
migrated_workflow_ids: set[int] = set()
|
||||
|
||||
for workflow in workflows:
|
||||
base_config = fallback_user_config
|
||||
if workflow.user_id is not None:
|
||||
if workflow.user_id not in owner_configs:
|
||||
owner_configs[
|
||||
workflow.user_id
|
||||
] = await db_client.get_user_configurations(workflow.user_id)
|
||||
base_config = owner_configs[workflow.user_id]
|
||||
|
||||
workflow_configs, workflow_changed = (
|
||||
migrate_workflow_configuration_model_override_to_v2(
|
||||
workflow.workflow_configurations,
|
||||
base_config,
|
||||
)
|
||||
)
|
||||
if workflow_changed:
|
||||
workflow_updates.append((workflow.id, workflow_configs))
|
||||
migrated_workflow_ids.add(workflow.id)
|
||||
|
||||
for definition in workflow.definitions:
|
||||
definition_configs, definition_changed = (
|
||||
migrate_workflow_configuration_model_override_to_v2(
|
||||
definition.workflow_configurations,
|
||||
base_config,
|
||||
)
|
||||
)
|
||||
if definition_changed:
|
||||
definition_updates.append((definition.id, definition_configs))
|
||||
migrated_workflow_ids.add(workflow.id)
|
||||
|
||||
if workflow_updates or definition_updates:
|
||||
async with db_client.async_session() as session:
|
||||
for workflow_id, workflow_configs in workflow_updates:
|
||||
await session.execute(
|
||||
update(WorkflowModel)
|
||||
.where(WorkflowModel.id == workflow_id)
|
||||
.values(workflow_configurations=workflow_configs)
|
||||
)
|
||||
for definition_id, definition_configs in definition_updates:
|
||||
await session.execute(
|
||||
update(WorkflowDefinitionModel)
|
||||
.where(WorkflowDefinitionModel.id == definition_id)
|
||||
.values(workflow_configurations=definition_configs)
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
return WorkflowAIModelConfigurationMigrationResult(
|
||||
workflow_count=len(migrated_workflow_ids),
|
||||
definition_count=len(definition_updates),
|
||||
workflow_ids=sorted(migrated_workflow_ids),
|
||||
)
|
||||
|
||||
|
||||
def migrate_workflow_configuration_model_override_to_v2(
|
||||
workflow_configurations: dict | None,
|
||||
base_config: EffectiveAIModelConfiguration,
|
||||
) -> tuple[dict, bool]:
|
||||
if not isinstance(workflow_configurations, dict):
|
||||
return {}, False
|
||||
|
||||
migrated = copy.deepcopy(workflow_configurations)
|
||||
model_overrides = migrated.get("model_overrides")
|
||||
existing_v2_override = migrated.get(WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY)
|
||||
if not isinstance(model_overrides, dict):
|
||||
if "model_overrides" in migrated:
|
||||
migrated.pop("model_overrides", None)
|
||||
return migrated, True
|
||||
return migrated, False
|
||||
|
||||
if not existing_v2_override:
|
||||
effective = resolve_effective_config(base_config, model_overrides)
|
||||
v2_override = convert_legacy_ai_model_configuration_to_v2(effective)
|
||||
migrated[WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY] = v2_override.model_dump(
|
||||
mode="json", exclude_none=True
|
||||
)
|
||||
migrated.pop("model_overrides", None)
|
||||
return migrated, True
|
||||
|
||||
|
||||
def merge_ai_model_configuration_v2_secrets(
|
||||
incoming: OrganizationAIModelConfigurationV2,
|
||||
existing: OrganizationAIModelConfigurationV2 | None,
|
||||
) -> OrganizationAIModelConfigurationV2:
|
||||
if existing is None:
|
||||
return incoming
|
||||
|
||||
incoming_dict = incoming.model_dump(mode="json", exclude_none=True)
|
||||
existing_dict = existing.model_dump(mode="json", exclude_none=True)
|
||||
|
||||
if incoming_dict.get("mode") == "dograh" and existing_dict.get("mode") == "dograh":
|
||||
incoming_dograh = incoming_dict.get("dograh") or {}
|
||||
existing_dograh = existing_dict.get("dograh") or {}
|
||||
incoming_key = incoming_dograh.get("api_key")
|
||||
existing_key = existing_dograh.get("api_key")
|
||||
if incoming_key and existing_key and contains_masked_key(incoming_key):
|
||||
incoming_dograh["api_key"] = resolve_masked_api_keys(
|
||||
incoming_key,
|
||||
existing_key,
|
||||
)
|
||||
|
||||
if incoming_dict.get("mode") == "byok" and existing_dict.get("mode") == "byok":
|
||||
_merge_byok_secret_fields(incoming_dict.get("byok"), existing_dict.get("byok"))
|
||||
|
||||
return OrganizationAIModelConfigurationV2.model_validate(incoming_dict)
|
||||
|
||||
|
||||
def check_for_masked_keys_in_ai_model_configuration_v2(
|
||||
configuration: OrganizationAIModelConfigurationV2,
|
||||
) -> None:
|
||||
data = configuration.model_dump(mode="json", exclude_none=True)
|
||||
_raise_if_masked_secret(data)
|
||||
|
||||
|
||||
def mask_ai_model_configuration_v2(
|
||||
configuration: OrganizationAIModelConfigurationV2 | None,
|
||||
) -> dict | None:
|
||||
if configuration is None:
|
||||
return None
|
||||
data = configuration.model_dump(mode="json", exclude_none=True)
|
||||
_mask_secret_fields(data)
|
||||
return data
|
||||
|
||||
|
||||
def convert_legacy_ai_model_configuration_to_v2(
|
||||
configuration: EffectiveAIModelConfiguration,
|
||||
) -> OrganizationAIModelConfigurationV2:
|
||||
dograh_key = _first_dograh_api_key(configuration)
|
||||
if dograh_key:
|
||||
return _convert_any_dograh_legacy_configuration(configuration, dograh_key)
|
||||
|
||||
if configuration.is_realtime:
|
||||
if configuration.realtime is None or configuration.llm is None:
|
||||
raise ValueError("Realtime legacy configuration is incomplete")
|
||||
return OrganizationAIModelConfigurationV2(
|
||||
mode="byok",
|
||||
byok=BYOKAIModelConfiguration(
|
||||
mode="realtime",
|
||||
realtime=BYOKRealtimeAIModelConfiguration(
|
||||
realtime=configuration.realtime,
|
||||
llm=configuration.llm,
|
||||
embeddings=configuration.embeddings,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
if (
|
||||
configuration.llm is None
|
||||
or configuration.tts is None
|
||||
or configuration.stt is None
|
||||
):
|
||||
raise ValueError("Pipeline legacy configuration is incomplete")
|
||||
return OrganizationAIModelConfigurationV2(
|
||||
mode="byok",
|
||||
byok=BYOKAIModelConfiguration(
|
||||
mode="pipeline",
|
||||
pipeline=BYOKPipelineAIModelConfiguration(
|
||||
llm=configuration.llm,
|
||||
tts=configuration.tts,
|
||||
stt=configuration.stt,
|
||||
embeddings=configuration.embeddings,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def dograh_embeddings_base_url() -> str:
|
||||
return f"{MPS_API_URL}/api/v1/llm"
|
||||
|
||||
|
||||
def apply_managed_embeddings_base_url(
|
||||
*,
|
||||
provider: str | None,
|
||||
base_url: str | None,
|
||||
) -> str | None:
|
||||
if provider == ServiceProviders.DOGRAH.value or provider == ServiceProviders.DOGRAH:
|
||||
return dograh_embeddings_base_url()
|
||||
return base_url
|
||||
|
||||
|
||||
def _merge_byok_secret_fields(incoming_byok: dict | None, existing_byok: dict | None):
|
||||
if not isinstance(incoming_byok, dict) or not isinstance(existing_byok, dict):
|
||||
return
|
||||
incoming_mode = incoming_byok.get("mode")
|
||||
existing_mode = existing_byok.get("mode")
|
||||
if incoming_mode != existing_mode:
|
||||
return
|
||||
section_names = (
|
||||
("llm", "tts", "stt", "embeddings")
|
||||
if incoming_mode == "pipeline"
|
||||
else ("realtime", "llm", "embeddings")
|
||||
)
|
||||
incoming_container = incoming_byok.get(incoming_mode)
|
||||
existing_container = existing_byok.get(existing_mode)
|
||||
if not isinstance(incoming_container, dict) or not isinstance(
|
||||
existing_container, dict
|
||||
):
|
||||
return
|
||||
for section_name in section_names:
|
||||
incoming_section = incoming_container.get(section_name)
|
||||
existing_section = existing_container.get(section_name)
|
||||
if isinstance(incoming_section, dict) and isinstance(existing_section, dict):
|
||||
_merge_service_secret_fields(incoming_section, existing_section)
|
||||
|
||||
|
||||
async def _list_workflows_for_model_configuration_migration(
|
||||
organization_id: int,
|
||||
) -> list[WorkflowModel]:
|
||||
async with db_client.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(WorkflowModel)
|
||||
.options(selectinload(WorkflowModel.definitions))
|
||||
.where(WorkflowModel.organization_id == organization_id)
|
||||
)
|
||||
return list(result.scalars().unique().all())
|
||||
|
||||
|
||||
def _merge_service_secret_fields(incoming: dict, existing: dict):
|
||||
if (
|
||||
incoming.get("provider") is not None
|
||||
and existing.get("provider") is not None
|
||||
and incoming.get("provider") != existing.get("provider")
|
||||
):
|
||||
return
|
||||
for secret_field in SERVICE_SECRET_FIELDS:
|
||||
if secret_field not in existing:
|
||||
continue
|
||||
incoming_secret = incoming.get(secret_field)
|
||||
existing_secret = existing[secret_field]
|
||||
if incoming_secret is None:
|
||||
incoming[secret_field] = existing_secret
|
||||
elif contains_masked_key(incoming_secret):
|
||||
incoming[secret_field] = resolve_masked_api_keys(
|
||||
incoming_secret,
|
||||
existing_secret,
|
||||
)
|
||||
|
||||
|
||||
def _raise_if_masked_secret(value):
|
||||
if isinstance(value, dict):
|
||||
for key, nested in value.items():
|
||||
if key in SERVICE_SECRET_FIELDS and contains_masked_key(nested):
|
||||
raise ValueError(
|
||||
f"The {key} appears to be masked. Please provide the actual "
|
||||
"value, not the masked value."
|
||||
)
|
||||
_raise_if_masked_secret(nested)
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
_raise_if_masked_secret(item)
|
||||
|
||||
|
||||
def _mask_secret_fields(value):
|
||||
if isinstance(value, dict):
|
||||
for key, nested in list(value.items()):
|
||||
if key in SERVICE_SECRET_FIELDS and nested:
|
||||
value[key] = _mask_secret_value(nested)
|
||||
else:
|
||||
_mask_secret_fields(nested)
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
_mask_secret_fields(item)
|
||||
|
||||
|
||||
def _mask_secret_value(value):
|
||||
if isinstance(value, list):
|
||||
return [mask_key(item) for item in value]
|
||||
return mask_key(value)
|
||||
|
||||
|
||||
def _has_model_services(configuration: EffectiveAIModelConfiguration) -> bool:
|
||||
return any(
|
||||
service is not None
|
||||
for service in (
|
||||
configuration.llm,
|
||||
configuration.tts,
|
||||
configuration.stt,
|
||||
configuration.embeddings,
|
||||
configuration.realtime,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _convert_any_dograh_legacy_configuration(
|
||||
configuration: EffectiveAIModelConfiguration,
|
||||
dograh_key: str,
|
||||
) -> OrganizationAIModelConfigurationV2:
|
||||
speed = getattr(configuration.tts, "speed", 1.0)
|
||||
if speed not in DOGRAH_SPEED_OPTIONS:
|
||||
speed = 1.0
|
||||
return OrganizationAIModelConfigurationV2(
|
||||
mode="dograh",
|
||||
dograh=DograhManagedAIModelConfiguration(
|
||||
api_key=dograh_key,
|
||||
voice=getattr(configuration.tts, "voice", DOGRAH_DEFAULT_VOICE)
|
||||
or DOGRAH_DEFAULT_VOICE,
|
||||
speed=speed,
|
||||
language=getattr(configuration.stt, "language", DOGRAH_DEFAULT_LANGUAGE)
|
||||
or DOGRAH_DEFAULT_LANGUAGE,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _first_dograh_api_key(configuration: EffectiveAIModelConfiguration) -> str | None:
|
||||
for service in (
|
||||
configuration.llm,
|
||||
configuration.tts,
|
||||
configuration.stt,
|
||||
configuration.embeddings,
|
||||
configuration.realtime,
|
||||
):
|
||||
if service is None or _provider(service) != ServiceProviders.DOGRAH:
|
||||
continue
|
||||
try:
|
||||
return _single_api_key(service)
|
||||
except ValueError:
|
||||
continue
|
||||
return None
|
||||
|
||||
|
||||
def _provider(service):
|
||||
return getattr(service, "provider", None)
|
||||
|
||||
|
||||
def _single_api_key(service) -> str:
|
||||
if hasattr(service, "get_all_api_keys"):
|
||||
keys = service.get_all_api_keys()
|
||||
if len(keys) != 1:
|
||||
raise ValueError("Expected exactly one API key")
|
||||
return keys[0]
|
||||
key = getattr(service, "api_key", None)
|
||||
if not key:
|
||||
raise ValueError("Expected an API key")
|
||||
return key
|
||||
|
|
@ -151,21 +151,35 @@ def mask_workflow_configurations(config: Optional[Dict]) -> Optional[Dict]:
|
|||
|
||||
masked = copy.deepcopy(config)
|
||||
model_overrides = masked.get("model_overrides")
|
||||
if not isinstance(model_overrides, dict):
|
||||
return masked
|
||||
if isinstance(model_overrides, dict):
|
||||
for section in MODEL_OVERRIDE_FIELDS:
|
||||
override = model_overrides.get(section)
|
||||
if not isinstance(override, dict):
|
||||
continue
|
||||
for secret_field in SERVICE_SECRET_FIELDS:
|
||||
raw = override.get(secret_field)
|
||||
if raw:
|
||||
override[secret_field] = _mask_secret_value(raw)
|
||||
|
||||
for section in MODEL_OVERRIDE_FIELDS:
|
||||
override = model_overrides.get(section)
|
||||
if not isinstance(override, dict):
|
||||
continue
|
||||
for secret_field in SERVICE_SECRET_FIELDS:
|
||||
raw = override.get(secret_field)
|
||||
if raw:
|
||||
override[secret_field] = _mask_secret_value(raw)
|
||||
v2_override = masked.get("model_configuration_v2_override")
|
||||
if isinstance(v2_override, dict):
|
||||
_mask_nested_service_secrets(v2_override)
|
||||
|
||||
return masked
|
||||
|
||||
|
||||
def _mask_nested_service_secrets(value):
|
||||
if isinstance(value, dict):
|
||||
for key, nested in list(value.items()):
|
||||
if key in SERVICE_SECRET_FIELDS and nested:
|
||||
value[key] = _mask_secret_value(nested)
|
||||
else:
|
||||
_mask_nested_service_secrets(nested)
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
_mask_nested_service_secrets(item)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Workflow definition helpers – mask / merge node API keys
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -1472,11 +1472,26 @@ class AzureOpenAIEmbeddingsConfiguration(BaseEmbeddingsConfiguration):
|
|||
)
|
||||
|
||||
|
||||
DOGRAH_EMBEDDING_MODELS = ["default"]
|
||||
|
||||
|
||||
@register_embeddings
|
||||
class DograhEmbeddingsConfiguration(BaseEmbeddingsConfiguration):
|
||||
model_config = DOGRAH_PROVIDER_MODEL_CONFIG
|
||||
provider: Literal[ServiceProviders.DOGRAH] = ServiceProviders.DOGRAH
|
||||
model: str = Field(
|
||||
default="default",
|
||||
description="Dograh-managed embedding model.",
|
||||
json_schema_extra={"examples": DOGRAH_EMBEDDING_MODELS},
|
||||
)
|
||||
|
||||
|
||||
EmbeddingsConfig = Annotated[
|
||||
Union[
|
||||
OpenAIEmbeddingsConfiguration,
|
||||
OpenRouterEmbeddingsConfiguration,
|
||||
AzureOpenAIEmbeddingsConfiguration,
|
||||
DograhEmbeddingsConfiguration,
|
||||
],
|
||||
Field(discriminator="provider"),
|
||||
]
|
||||
|
|
|
|||
62
api/services/organization_preferences.py
Normal file
62
api/services/organization_preferences.py
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
from inspect import isawaitable
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import ValidationError
|
||||
|
||||
from api.db import db_client
|
||||
from api.enums import OrganizationConfigurationKey
|
||||
from api.schemas.organization_preferences import OrganizationPreferences
|
||||
|
||||
|
||||
async def get_organization_preferences(
|
||||
organization_id: int | None,
|
||||
db=None,
|
||||
) -> OrganizationPreferences:
|
||||
if organization_id is None:
|
||||
return OrganizationPreferences()
|
||||
|
||||
db = db or db_client
|
||||
row = await _get_configuration(
|
||||
db,
|
||||
organization_id,
|
||||
OrganizationConfigurationKey.ORGANIZATION_PREFERENCES.value,
|
||||
)
|
||||
if row is None:
|
||||
row = await _get_configuration(
|
||||
db,
|
||||
organization_id,
|
||||
OrganizationConfigurationKey.MODEL_CONFIGURATION_PREFERENCES.value,
|
||||
)
|
||||
return _parse_preferences(row.value if row is not None else None, organization_id)
|
||||
|
||||
|
||||
async def upsert_organization_preferences(
|
||||
organization_id: int,
|
||||
preferences: OrganizationPreferences,
|
||||
) -> OrganizationPreferences:
|
||||
await db_client.upsert_configuration(
|
||||
organization_id,
|
||||
OrganizationConfigurationKey.ORGANIZATION_PREFERENCES.value,
|
||||
preferences.model_dump(mode="json", exclude_none=True),
|
||||
)
|
||||
return preferences
|
||||
|
||||
|
||||
async def _get_configuration(db, organization_id: int, key: str):
|
||||
row = db.get_configuration(organization_id, key)
|
||||
if isawaitable(row):
|
||||
row = await row
|
||||
return row
|
||||
|
||||
|
||||
def _parse_preferences(value, organization_id: int) -> OrganizationPreferences:
|
||||
if not value or not isinstance(value, dict):
|
||||
return OrganizationPreferences()
|
||||
try:
|
||||
return OrganizationPreferences.model_validate(value)
|
||||
except ValidationError as exc:
|
||||
logger.warning(
|
||||
"Invalid organization preferences for organization "
|
||||
f"{organization_id}: {exc}. Returning defaults."
|
||||
)
|
||||
return OrganizationPreferences()
|
||||
|
|
@ -195,14 +195,17 @@ async def run_pipeline_telephony(
|
|||
# Resolve effective user config here so the transport can tune its
|
||||
# bot-stopped-speaking fallback based on is_realtime; pass the resolved
|
||||
# values into _run_pipeline so it doesn't fetch them again.
|
||||
from api.services.configuration.resolve import resolve_effective_config
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
get_effective_ai_model_configuration_for_workflow,
|
||||
)
|
||||
|
||||
user_config = await db_client.get_user_configurations(user_id)
|
||||
run_configs = (
|
||||
(workflow_run.definition.workflow_configurations or {}) if workflow_run else {}
|
||||
)
|
||||
user_config = resolve_effective_config(
|
||||
user_config, run_configs.get("model_overrides")
|
||||
user_config = await get_effective_ai_model_configuration_for_workflow(
|
||||
user_id=user_id,
|
||||
organization_id=workflow.organization_id if workflow else None,
|
||||
workflow_configurations=run_configs,
|
||||
)
|
||||
is_realtime = bool(user_config.is_realtime and user_config.realtime is not None)
|
||||
|
||||
|
|
@ -272,15 +275,18 @@ async def run_pipeline_smallwebrtc(
|
|||
# Resolve workflow_run + effective user_config here so the transport can
|
||||
# tune its bot-stopped-speaking fallback based on is_realtime. _run_pipeline
|
||||
# reuses these via kwargs so we don't fetch twice.
|
||||
from api.services.configuration.resolve import resolve_effective_config
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
get_effective_ai_model_configuration_for_workflow,
|
||||
)
|
||||
|
||||
workflow_run = await db_client.get_workflow_run(workflow_run_id, user_id)
|
||||
user_config = await db_client.get_user_configurations(user_id)
|
||||
run_configs = (
|
||||
(workflow_run.definition.workflow_configurations or {}) if workflow_run else {}
|
||||
)
|
||||
user_config = resolve_effective_config(
|
||||
user_config, run_configs.get("model_overrides")
|
||||
user_config = await get_effective_ai_model_configuration_for_workflow(
|
||||
user_id=user_id,
|
||||
organization_id=workflow.organization_id if workflow else None,
|
||||
workflow_configurations=run_configs,
|
||||
)
|
||||
is_realtime = bool(user_config.is_realtime and user_config.realtime is not None)
|
||||
|
||||
|
|
@ -380,11 +386,14 @@ async def _run_pipeline(
|
|||
# Resolve model overrides from the version onto global user config (skip
|
||||
# when the caller already resolved it).
|
||||
if resolved_user_config is None:
|
||||
from api.services.configuration.resolve import resolve_effective_config
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
get_effective_ai_model_configuration_for_workflow,
|
||||
)
|
||||
|
||||
user_config = await db_client.get_user_configurations(user_id)
|
||||
user_config = resolve_effective_config(
|
||||
user_config, run_configs.get("model_overrides")
|
||||
user_config = await get_effective_ai_model_configuration_for_workflow(
|
||||
user_id=user_id,
|
||||
organization_id=workflow.organization_id,
|
||||
workflow_configurations=run_configs,
|
||||
)
|
||||
else:
|
||||
user_config = resolved_user_config
|
||||
|
|
@ -508,10 +517,17 @@ async def _run_pipeline(
|
|||
embeddings_endpoint = None
|
||||
embeddings_api_version = None
|
||||
if user_config and user_config.embeddings:
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
apply_managed_embeddings_base_url,
|
||||
)
|
||||
|
||||
embeddings_api_key = user_config.embeddings.api_key
|
||||
embeddings_model = user_config.embeddings.model
|
||||
embeddings_provider = getattr(user_config.embeddings, "provider", None)
|
||||
embeddings_base_url = getattr(user_config.embeddings, "base_url", None)
|
||||
embeddings_base_url = apply_managed_embeddings_base_url(
|
||||
provider=embeddings_provider,
|
||||
base_url=getattr(user_config.embeddings, "base_url", None),
|
||||
)
|
||||
embeddings_endpoint = getattr(user_config.embeddings, "endpoint", None)
|
||||
embeddings_api_version = getattr(user_config.embeddings, "api_version", None)
|
||||
|
||||
|
|
|
|||
|
|
@ -10,8 +10,10 @@ from loguru import logger
|
|||
|
||||
from api.db import db_client
|
||||
from api.db.models import UserModel
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
get_effective_ai_model_configuration_for_workflow,
|
||||
)
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from api.services.configuration.resolve import resolve_effective_config
|
||||
from api.services.mps_service_key_client import mps_service_key_client
|
||||
|
||||
|
||||
|
|
@ -48,17 +50,20 @@ async def check_dograh_quota(
|
|||
if quota is insufficient.
|
||||
"""
|
||||
try:
|
||||
# Get user configurations
|
||||
user_config = await db_client.get_user_configurations(user.id)
|
||||
organization_id = user.selected_organization_id
|
||||
workflow_configurations = None
|
||||
|
||||
if workflow_id is not None:
|
||||
workflow = await db_client.get_workflow_by_id(workflow_id)
|
||||
if workflow:
|
||||
model_overrides = (workflow.workflow_configurations or {}).get(
|
||||
"model_overrides"
|
||||
)
|
||||
if model_overrides:
|
||||
user_config = resolve_effective_config(user_config, model_overrides)
|
||||
organization_id = workflow.organization_id
|
||||
workflow_configurations = workflow.workflow_configurations
|
||||
|
||||
user_config = await get_effective_ai_model_configuration_for_workflow(
|
||||
user_id=user.id,
|
||||
organization_id=organization_id,
|
||||
workflow_configurations=workflow_configurations,
|
||||
)
|
||||
|
||||
# Check if user is using any Dograh service
|
||||
using_dograh = False
|
||||
|
|
@ -76,6 +81,13 @@ async def check_dograh_quota(
|
|||
using_dograh = True
|
||||
dograh_api_keys.add(user_config.tts.api_key)
|
||||
|
||||
if (
|
||||
user_config.embeddings
|
||||
and user_config.embeddings.provider == ServiceProviders.DOGRAH
|
||||
):
|
||||
using_dograh = True
|
||||
dograh_api_keys.add(user_config.embeddings.api_key)
|
||||
|
||||
# If not using Dograh, quota check passes
|
||||
if not using_dograh:
|
||||
return QuotaCheckResult(has_quota=True)
|
||||
|
|
@ -84,7 +96,9 @@ async def check_dograh_quota(
|
|||
for api_key in dograh_api_keys:
|
||||
try:
|
||||
usage = await mps_service_key_client.check_service_key_usage(
|
||||
api_key, created_by=user.provider_id
|
||||
api_key,
|
||||
organization_id=organization_id,
|
||||
created_by=user.provider_id,
|
||||
)
|
||||
remaining = usage.get("remaining_credits", 0.0)
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@
|
|||
|
||||
import random
|
||||
|
||||
from api.db import db_client
|
||||
from api.db.models import WorkflowRunModel
|
||||
from api.services.workflow.dto import QANodeData
|
||||
|
||||
|
|
@ -54,7 +53,27 @@ async def resolve_user_llm_config(
|
|||
|
||||
llm_config: dict = {}
|
||||
if user_id:
|
||||
user_configuration = await db_client.get_user_configurations(user_id)
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
get_effective_ai_model_configuration_for_workflow,
|
||||
)
|
||||
|
||||
workflow_configurations = {}
|
||||
if workflow_run.definition:
|
||||
workflow_configurations = (
|
||||
workflow_run.definition.workflow_configurations or {}
|
||||
)
|
||||
elif workflow_run.workflow:
|
||||
workflow_configurations = (
|
||||
workflow_run.workflow.workflow_configurations or {}
|
||||
)
|
||||
|
||||
user_configuration = await get_effective_ai_model_configuration_for_workflow(
|
||||
user_id=user_id,
|
||||
organization_id=workflow_run.workflow.organization_id
|
||||
if workflow_run.workflow
|
||||
else None,
|
||||
workflow_configurations=workflow_configurations,
|
||||
)
|
||||
llm_config = user_configuration.model_dump(exclude_none=True).get("llm", {})
|
||||
|
||||
provider = llm_config.get("provider", "openai")
|
||||
|
|
|
|||
|
|
@ -32,7 +32,6 @@ from pipecat.utils.run_context import set_current_org_id
|
|||
|
||||
from api.db import db_client
|
||||
from api.enums import WorkflowRunMode, WorkflowRunState
|
||||
from api.services.configuration.resolve import resolve_effective_config
|
||||
from api.services.pipecat.audio_config import create_audio_config
|
||||
from api.services.pipecat.pipeline_builder import create_pipeline_task
|
||||
from api.services.pipecat.pipeline_metrics_aggregator import (
|
||||
|
|
@ -410,9 +409,14 @@ async def execute_text_chat_pending_turn(
|
|||
run_definition = workflow_run.definition
|
||||
run_configs = run_definition.workflow_configurations or {}
|
||||
|
||||
user_config = await db_client.get_user_configurations(workflow_run.workflow.user.id)
|
||||
user_config = resolve_effective_config(
|
||||
user_config, run_configs.get("model_overrides")
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
get_effective_ai_model_configuration_for_workflow,
|
||||
)
|
||||
|
||||
user_config = await get_effective_ai_model_configuration_for_workflow(
|
||||
user_id=workflow_run.workflow.user.id,
|
||||
organization_id=workflow.organization_id,
|
||||
workflow_configurations=run_configs,
|
||||
)
|
||||
if user_config.llm is None:
|
||||
raise ValueError("Text chat requires an LLM configuration")
|
||||
|
|
@ -466,9 +470,17 @@ async def execute_text_chat_pending_turn(
|
|||
embeddings_model = None
|
||||
embeddings_base_url = None
|
||||
if user_config.embeddings:
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
apply_managed_embeddings_base_url,
|
||||
)
|
||||
|
||||
embeddings_api_key = user_config.embeddings.api_key
|
||||
embeddings_model = user_config.embeddings.model
|
||||
embeddings_base_url = getattr(user_config.embeddings, "base_url", None)
|
||||
embeddings_provider = getattr(user_config.embeddings, "provider", None)
|
||||
embeddings_base_url = apply_managed_embeddings_base_url(
|
||||
provider=embeddings_provider,
|
||||
base_url=getattr(user_config.embeddings, "base_url", None),
|
||||
)
|
||||
|
||||
has_recordings = await db_client.has_active_recordings(workflow.organization_id)
|
||||
context_compaction_enabled = (workflow.workflow_configurations or {}).get(
|
||||
|
|
|
|||
|
|
@ -157,12 +157,24 @@ async def process_knowledge_base_document(
|
|||
embeddings_endpoint = None
|
||||
embeddings_api_version = None
|
||||
if document.created_by:
|
||||
user_config = await db_client.get_user_configurations(document.created_by)
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
apply_managed_embeddings_base_url,
|
||||
get_resolved_ai_model_configuration,
|
||||
)
|
||||
|
||||
resolved_config = await get_resolved_ai_model_configuration(
|
||||
user_id=document.created_by,
|
||||
organization_id=document.organization_id,
|
||||
)
|
||||
user_config = resolved_config.effective
|
||||
if user_config.embeddings:
|
||||
embeddings_provider = getattr(user_config.embeddings, "provider", None)
|
||||
embeddings_api_key = user_config.embeddings.api_key
|
||||
embeddings_model = user_config.embeddings.model
|
||||
embeddings_base_url = getattr(user_config.embeddings, "base_url", None)
|
||||
embeddings_base_url = apply_managed_embeddings_base_url(
|
||||
provider=embeddings_provider,
|
||||
base_url=getattr(user_config.embeddings, "base_url", None),
|
||||
)
|
||||
embeddings_endpoint = getattr(user_config.embeddings, "endpoint", None)
|
||||
embeddings_api_version = getattr(
|
||||
user_config.embeddings, "api_version", None
|
||||
|
|
|
|||
295
api/tests/test_ai_model_configuration_v2.py
Normal file
295
api/tests/test_ai_model_configuration_v2.py
Normal file
|
|
@ -0,0 +1,295 @@
|
|||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from api.schemas.ai_model_configuration import (
|
||||
DograhManagedAIModelConfiguration,
|
||||
OrganizationAIModelConfigurationV2,
|
||||
compile_ai_model_configuration_v2,
|
||||
)
|
||||
from api.schemas.user_configuration import UserConfiguration
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY,
|
||||
check_for_masked_keys_in_ai_model_configuration_v2,
|
||||
convert_legacy_ai_model_configuration_to_v2,
|
||||
mask_ai_model_configuration_v2,
|
||||
merge_ai_model_configuration_v2_secrets,
|
||||
migrate_workflow_configuration_model_override_to_v2,
|
||||
)
|
||||
from api.services.configuration.masking import mask_key
|
||||
from api.services.configuration.registry import (
|
||||
DeepgramSTTConfiguration,
|
||||
DograhLLMService,
|
||||
DograhSTTService,
|
||||
DograhTTSService,
|
||||
ElevenlabsTTSConfiguration,
|
||||
OpenAIEmbeddingsConfiguration,
|
||||
OpenAILLMService,
|
||||
)
|
||||
|
||||
|
||||
def test_dograh_v2_compiles_to_effective_managed_pipeline_with_embeddings():
|
||||
config = OrganizationAIModelConfigurationV2(
|
||||
mode="dograh",
|
||||
dograh=DograhManagedAIModelConfiguration(
|
||||
api_key="mps-secret",
|
||||
voice="default",
|
||||
speed=1.2,
|
||||
language="multi",
|
||||
),
|
||||
)
|
||||
|
||||
effective = compile_ai_model_configuration_v2(config)
|
||||
|
||||
assert effective.is_realtime is False
|
||||
assert effective.llm.provider == "dograh"
|
||||
assert effective.llm.model == "default"
|
||||
assert effective.tts.provider == "dograh"
|
||||
assert effective.tts.speed == 1.2
|
||||
assert effective.stt.provider == "dograh"
|
||||
assert effective.stt.language == "multi"
|
||||
assert effective.embeddings.provider == "dograh"
|
||||
assert effective.embeddings.model == "default"
|
||||
|
||||
|
||||
def test_dograh_v2_rejects_non_predefined_speed():
|
||||
with pytest.raises(ValidationError):
|
||||
OrganizationAIModelConfigurationV2(
|
||||
mode="dograh",
|
||||
dograh=DograhManagedAIModelConfiguration(
|
||||
api_key="mps-secret",
|
||||
speed=1.5,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def test_byok_v2_rejects_dograh_provider():
|
||||
with pytest.raises(ValidationError):
|
||||
OrganizationAIModelConfigurationV2.model_validate(
|
||||
{
|
||||
"mode": "byok",
|
||||
"byok": {
|
||||
"mode": "pipeline",
|
||||
"pipeline": {
|
||||
"llm": {
|
||||
"provider": "dograh",
|
||||
"api_key": "mps-secret",
|
||||
"model": "default",
|
||||
},
|
||||
"tts": {
|
||||
"provider": "dograh",
|
||||
"api_key": "mps-secret",
|
||||
"model": "default",
|
||||
"voice": "default",
|
||||
},
|
||||
"stt": {
|
||||
"provider": "dograh",
|
||||
"api_key": "mps-secret",
|
||||
"model": "default",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_masked_dograh_key_is_preserved_when_saving_same_mode():
|
||||
existing = OrganizationAIModelConfigurationV2(
|
||||
mode="dograh",
|
||||
dograh=DograhManagedAIModelConfiguration(api_key="mps-real-secret"),
|
||||
)
|
||||
incoming = OrganizationAIModelConfigurationV2(
|
||||
mode="dograh",
|
||||
dograh=DograhManagedAIModelConfiguration(api_key=mask_key("mps-real-secret")),
|
||||
)
|
||||
|
||||
merged = merge_ai_model_configuration_v2_secrets(incoming, existing)
|
||||
|
||||
assert merged.dograh.api_key == "mps-real-secret"
|
||||
check_for_masked_keys_in_ai_model_configuration_v2(merged)
|
||||
|
||||
|
||||
def test_masked_v2_configuration_masks_nested_service_keys():
|
||||
config = OrganizationAIModelConfigurationV2(
|
||||
mode="byok",
|
||||
byok={
|
||||
"mode": "pipeline",
|
||||
"pipeline": {
|
||||
"llm": {
|
||||
"provider": "openai",
|
||||
"api_key": "sk-real-secret",
|
||||
"model": "gpt-4.1",
|
||||
},
|
||||
"tts": {
|
||||
"provider": "elevenlabs",
|
||||
"api_key": "el-real-secret",
|
||||
"model": "eleven_flash_v2_5",
|
||||
"voice": "Rachel",
|
||||
},
|
||||
"stt": {
|
||||
"provider": "deepgram",
|
||||
"api_key": "dg-real-secret",
|
||||
"model": "nova-3-general",
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
masked = mask_ai_model_configuration_v2(config)
|
||||
|
||||
assert masked["byok"]["pipeline"]["llm"]["api_key"] == mask_key("sk-real-secret")
|
||||
assert masked["byok"]["pipeline"]["tts"]["api_key"] == mask_key("el-real-secret")
|
||||
assert masked["byok"]["pipeline"]["stt"]["api_key"] == mask_key("dg-real-secret")
|
||||
|
||||
|
||||
def test_legacy_all_dograh_pipeline_converts_to_dograh_v2():
|
||||
legacy = UserConfiguration(
|
||||
llm=DograhLLMService(
|
||||
provider="dograh",
|
||||
api_key=["mps-secret"],
|
||||
model="default",
|
||||
),
|
||||
tts=DograhTTSService(
|
||||
provider="dograh",
|
||||
api_key=["mps-secret"],
|
||||
model="default",
|
||||
voice="default",
|
||||
speed=1.0,
|
||||
),
|
||||
stt=DograhSTTService(
|
||||
provider="dograh",
|
||||
api_key=["mps-secret"],
|
||||
model="default",
|
||||
language="multi",
|
||||
),
|
||||
)
|
||||
|
||||
config = convert_legacy_ai_model_configuration_to_v2(legacy)
|
||||
|
||||
assert config.mode == "dograh"
|
||||
assert config.dograh.api_key == "mps-secret"
|
||||
|
||||
|
||||
def test_legacy_mixed_dograh_pipeline_converts_to_dograh_v2():
|
||||
legacy = UserConfiguration(
|
||||
llm=OpenAILLMService(
|
||||
provider="openai",
|
||||
api_key="sk-llm",
|
||||
model="gpt-4.1",
|
||||
),
|
||||
tts=DograhTTSService(
|
||||
provider="dograh",
|
||||
api_key="mps-tts",
|
||||
model="default",
|
||||
voice="default",
|
||||
),
|
||||
stt=DograhSTTService(
|
||||
provider="dograh",
|
||||
api_key="mps-stt",
|
||||
model="default",
|
||||
),
|
||||
embeddings=OpenAIEmbeddingsConfiguration(
|
||||
provider="openai",
|
||||
api_key="sk-emb",
|
||||
model="text-embedding-3-small",
|
||||
),
|
||||
)
|
||||
|
||||
config = convert_legacy_ai_model_configuration_to_v2(legacy)
|
||||
|
||||
assert config.mode == "dograh"
|
||||
assert config.dograh.api_key == "mps-tts"
|
||||
assert config.dograh.voice == "default"
|
||||
|
||||
|
||||
def test_legacy_byok_pipeline_converts_to_byok_v2():
|
||||
legacy = UserConfiguration(
|
||||
llm=OpenAILLMService(
|
||||
provider="openai",
|
||||
api_key="sk-llm",
|
||||
model="gpt-4.1",
|
||||
),
|
||||
tts=ElevenlabsTTSConfiguration(
|
||||
provider="elevenlabs",
|
||||
api_key="el-tts",
|
||||
model="eleven_flash_v2_5",
|
||||
voice="Rachel",
|
||||
),
|
||||
stt=DeepgramSTTConfiguration(
|
||||
provider="deepgram",
|
||||
api_key="dg-stt",
|
||||
model="nova-3-general",
|
||||
),
|
||||
embeddings=OpenAIEmbeddingsConfiguration(
|
||||
provider="openai",
|
||||
api_key="sk-emb",
|
||||
model="text-embedding-3-small",
|
||||
),
|
||||
)
|
||||
|
||||
config = convert_legacy_ai_model_configuration_to_v2(legacy)
|
||||
|
||||
assert config.mode == "byok"
|
||||
assert config.byok.mode == "pipeline"
|
||||
assert config.byok.pipeline.llm.provider == "openai"
|
||||
assert config.byok.pipeline.tts.provider == "elevenlabs"
|
||||
|
||||
|
||||
def test_workflow_model_override_migration_removes_v1_override_and_sets_v2():
|
||||
base = UserConfiguration(
|
||||
llm=OpenAILLMService(
|
||||
provider="openai",
|
||||
api_key="sk-llm",
|
||||
model="gpt-4.1",
|
||||
),
|
||||
tts=ElevenlabsTTSConfiguration(
|
||||
provider="elevenlabs",
|
||||
api_key="el-tts",
|
||||
model="eleven_flash_v2_5",
|
||||
voice="Rachel",
|
||||
),
|
||||
stt=DeepgramSTTConfiguration(
|
||||
provider="deepgram",
|
||||
api_key="dg-stt",
|
||||
model="nova-3-general",
|
||||
),
|
||||
)
|
||||
workflow_configurations = {
|
||||
"ambient_noise_configuration": {"enabled": False},
|
||||
"model_overrides": {
|
||||
"tts": {
|
||||
"provider": "dograh",
|
||||
"api_key": "mps-workflow",
|
||||
"model": "default",
|
||||
"voice": "default",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
migrated, changed = migrate_workflow_configuration_model_override_to_v2(
|
||||
workflow_configurations,
|
||||
base,
|
||||
)
|
||||
|
||||
assert changed is True
|
||||
assert "model_overrides" not in migrated
|
||||
assert migrated["ambient_noise_configuration"] == {"enabled": False}
|
||||
v2_override = migrated[WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY]
|
||||
assert v2_override["mode"] == "dograh"
|
||||
assert v2_override["dograh"]["api_key"] == "mps-workflow"
|
||||
|
||||
|
||||
def test_workflow_model_override_migration_removes_invalid_v1_override_marker():
|
||||
base = UserConfiguration()
|
||||
workflow_configurations = {
|
||||
"ambient_noise_configuration": {"enabled": False},
|
||||
"model_overrides": None,
|
||||
}
|
||||
|
||||
migrated, changed = migrate_workflow_configuration_model_override_to_v2(
|
||||
workflow_configurations,
|
||||
base,
|
||||
)
|
||||
|
||||
assert changed is True
|
||||
assert "model_overrides" not in migrated
|
||||
assert migrated["ambient_noise_configuration"] == {"enabled": False}
|
||||
|
|
@ -1,3 +1,4 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
|
@ -14,14 +15,14 @@ from api.services.configuration.registry import (
|
|||
)
|
||||
|
||||
|
||||
def _make_test_app():
|
||||
def _make_test_app(selected_organization_id=None):
|
||||
app = FastAPI()
|
||||
app.include_router(router)
|
||||
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 1
|
||||
mock_user.is_superuser = False
|
||||
mock_user.selected_organization_id = None
|
||||
mock_user.selected_organization_id = selected_organization_id
|
||||
|
||||
app.dependency_overrides[get_user] = lambda: mock_user
|
||||
return app
|
||||
|
|
@ -210,3 +211,38 @@ class TestMaskedKeyRejection:
|
|||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_preference_only_update_does_not_validate_or_save_model_config(self):
|
||||
"""Saving a test phone number through the legacy endpoint must not touch models."""
|
||||
app = _make_test_app(selected_organization_id=11)
|
||||
client = TestClient(app)
|
||||
preferences = SimpleNamespace(test_phone_number=None, timezone=None)
|
||||
|
||||
with (
|
||||
patch("api.routes.user.db_client") as mock_db,
|
||||
patch("api.routes.user.UserConfigurationValidator") as mock_validator,
|
||||
patch(
|
||||
"api.routes.user.get_organization_preferences",
|
||||
new=AsyncMock(return_value=preferences),
|
||||
),
|
||||
patch(
|
||||
"api.routes.user.upsert_organization_preferences",
|
||||
new=AsyncMock(return_value=preferences),
|
||||
) as upsert_preferences,
|
||||
):
|
||||
existing = _existing_openai_config()
|
||||
mock_db.get_user_configurations = AsyncMock(return_value=existing)
|
||||
mock_db.update_user_configuration = AsyncMock()
|
||||
mock_db.get_organization_by_id = AsyncMock(return_value=None)
|
||||
mock_validator.return_value.validate = AsyncMock()
|
||||
|
||||
response = client.put(
|
||||
"/user/configurations/user",
|
||||
json={"test_phone_number": "+15551234567"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["test_phone_number"] == "+15551234567"
|
||||
mock_db.update_user_configuration.assert_not_called()
|
||||
mock_validator.return_value.validate.assert_not_called()
|
||||
upsert_preferences.assert_awaited_once()
|
||||
|
|
|
|||
|
|
@ -103,6 +103,61 @@ def test_initiate_call_executes_as_workflow_owner_for_shared_org_workflow():
|
|||
assert initiate_kwargs["workflow_id"] == workflow.id
|
||||
assert initiate_kwargs["user_id"] == workflow.user_id
|
||||
assert "user_id=99" in initiate_kwargs["webhook_url"]
|
||||
mock_db.get_user_configurations.assert_not_called()
|
||||
|
||||
|
||||
def test_initiate_call_uses_organization_preference_phone_number():
|
||||
app = _make_test_app()
|
||||
client = TestClient(app)
|
||||
|
||||
workflow = _workflow()
|
||||
provider = _provider()
|
||||
quota_mock = AsyncMock(
|
||||
return_value=SimpleNamespace(has_quota=True, error_message="")
|
||||
)
|
||||
|
||||
with (
|
||||
patch("api.routes.telephony.db_client") as mock_db,
|
||||
patch(
|
||||
"api.routes.telephony.check_dograh_quota_by_user_id",
|
||||
new=quota_mock,
|
||||
),
|
||||
patch(
|
||||
"api.routes.telephony.get_default_telephony_provider",
|
||||
new=AsyncMock(return_value=provider),
|
||||
),
|
||||
patch(
|
||||
"api.routes.telephony.get_backend_endpoints",
|
||||
new=AsyncMock(return_value=("https://api.example.com", "wss://ignored")),
|
||||
),
|
||||
):
|
||||
mock_db.get_user_configurations = AsyncMock(
|
||||
return_value=SimpleNamespace(test_phone_number="+15550000000")
|
||||
)
|
||||
mock_db.get_configuration = Mock(
|
||||
return_value=SimpleNamespace(value={"test_phone_number": "+15557654321"})
|
||||
)
|
||||
mock_db.get_default_telephony_configuration = AsyncMock(
|
||||
return_value=SimpleNamespace(id=55)
|
||||
)
|
||||
mock_db.get_workflow = AsyncMock(return_value=workflow)
|
||||
mock_db.create_workflow_run = AsyncMock(
|
||||
return_value=SimpleNamespace(
|
||||
id=501,
|
||||
name="WR-TEL-OUT-00000001",
|
||||
initial_context={},
|
||||
)
|
||||
)
|
||||
mock_db.update_workflow_run = AsyncMock()
|
||||
|
||||
response = client.post(
|
||||
"/telephony/initiate-call",
|
||||
json={"workflow_id": workflow.id},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert provider.initiate_call.await_args.kwargs["to_number"] == "+15557654321"
|
||||
mock_db.get_user_configurations.assert_not_called()
|
||||
|
||||
|
||||
def test_initiate_call_rejects_existing_run_for_different_workflow():
|
||||
|
|
|
|||
|
|
@ -51,6 +51,38 @@ async def _create_user_and_workflow(
|
|||
return user, workflow
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_chat_session_creation_requires_selected_organization():
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from api.app import app
|
||||
from api.services.auth.depends import get_user
|
||||
|
||||
user = UserModel(provider_id="textchat-user-no-selected-org")
|
||||
|
||||
async def mock_get_user():
|
||||
return user
|
||||
|
||||
original_override = app.dependency_overrides.get(get_user)
|
||||
app.dependency_overrides[get_user] = mock_get_user
|
||||
|
||||
try:
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app=app), base_url="http://test"
|
||||
) as client:
|
||||
response = await client.post(
|
||||
"/api/v1/workflow/123/text-chat/sessions", json={}
|
||||
)
|
||||
finally:
|
||||
if original_override:
|
||||
app.dependency_overrides[get_user] = original_override
|
||||
else:
|
||||
app.dependency_overrides.pop(get_user, None)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.json() == {"detail": "No organization selected"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_chat_session_creation_executes_initial_assistant_turn(
|
||||
db_session,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue