fix: decouple org preference and ai model preferences

This commit is contained in:
Abhishek Kumar 2026-06-09 15:40:34 +05:30
parent e26d902425
commit 01d898fc72
21 changed files with 460 additions and 238 deletions

View file

@ -447,11 +447,21 @@ class OrganizationUsageClient(BaseDBClient):
pref_result = await session.execute(
select(OrganizationConfigurationModel).where(
OrganizationConfigurationModel.organization_id == organization_id,
OrganizationConfigurationModel.key
== OrganizationConfigurationKey.MODEL_CONFIGURATION_PREFERENCES.value,
OrganizationConfigurationModel.key.in_(
[
OrganizationConfigurationKey.ORGANIZATION_PREFERENCES.value,
OrganizationConfigurationKey.MODEL_CONFIGURATION_PREFERENCES.value,
]
),
)
)
pref_obj = pref_result.scalar_one_or_none()
pref_rows = pref_result.scalars().all()
pref_by_key = {pref.key: pref for pref in pref_rows}
pref_obj = pref_by_key.get(
OrganizationConfigurationKey.ORGANIZATION_PREFERENCES.value
) or pref_by_key.get(
OrganizationConfigurationKey.MODEL_CONFIGURATION_PREFERENCES.value
)
if pref_obj and pref_obj.value:
user_timezone = pref_obj.value.get("timezone") or user_timezone

View file

@ -92,9 +92,8 @@ class OrganizationConfigurationKey(Enum):
MODEL_CONFIGURATION_V2 = (
"MODEL_CONFIGURATION_V2" # Org-level v2 AI model configuration
)
MODEL_CONFIGURATION_PREFERENCES = (
"MODEL_CONFIGURATION_PREFERENCES" # Org-level model configuration preferences
)
ORGANIZATION_PREFERENCES = "ORGANIZATION_PREFERENCES" # Org-level defaults such as timezone/test call number
MODEL_CONFIGURATION_PREFERENCES = "MODEL_CONFIGURATION_PREFERENCES" # Deprecated; read fallback for old org preferences
class WorkflowStatus(Enum):

View file

@ -14,10 +14,10 @@ from api.schemas.ai_model_configuration import (
DOGRAH_DEFAULT_LANGUAGE,
DOGRAH_DEFAULT_VOICE,
DOGRAH_SPEED_OPTIONS,
OrganizationAIModelConfigurationPreferences,
OrganizationAIModelConfigurationResponse,
OrganizationAIModelConfigurationV2,
)
from api.schemas.organization_preferences import OrganizationPreferences
from api.schemas.telephony_config import (
TelephonyConfigRequest,
TelephonyConfigurationCreateRequest,
@ -39,13 +39,11 @@ from api.services.configuration.ai_model_configuration import (
check_for_masked_keys_in_ai_model_configuration_v2,
compile_ai_model_configuration_v2,
convert_legacy_ai_model_configuration_to_v2,
get_organization_ai_model_configuration_preferences,
get_organization_ai_model_configuration_v2,
get_resolved_ai_model_configuration,
mask_ai_model_configuration_v2,
merge_ai_model_configuration_v2_secrets,
migrate_workflow_model_configurations_to_v2,
upsert_organization_ai_model_configuration_preferences,
upsert_organization_ai_model_configuration_v2,
)
from api.services.configuration.check_validity import UserConfigurationValidator
@ -57,6 +55,10 @@ from api.services.configuration.registry import (
ServiceProviders,
ServiceType,
)
from api.services.organization_preferences import (
get_organization_preferences,
upsert_organization_preferences,
)
from api.services.posthog_client import capture_event
from api.services.telephony import registry as telephony_registry
from api.services.telephony.factory import get_telephony_provider_by_id
@ -218,8 +220,6 @@ async def _model_configuration_v2_response(
return OrganizationAIModelConfigurationResponse(
configuration=mask_ai_model_configuration_v2(raw_configuration),
effective_configuration=mask_user_config(resolved.effective),
preferences=resolved.preferences
or OrganizationAIModelConfigurationPreferences(),
source=resolved.source,
)
@ -363,30 +363,47 @@ async def migrate_model_configuration_v2(
)
@router.get(
"/model-configurations/preferences",
response_model=OrganizationAIModelConfigurationPreferences,
)
async def get_model_configuration_preferences(
@router.get("/preferences", response_model=OrganizationPreferences)
async def get_preferences(
user: UserModel = Depends(get_user_with_selected_organization),
):
organization_id = user.selected_organization_id
return await get_organization_ai_model_configuration_preferences(organization_id)
return await get_organization_preferences(organization_id)
@router.put("/preferences", response_model=OrganizationPreferences)
async def save_preferences(
request: OrganizationPreferences,
user: UserModel = Depends(get_user_with_selected_organization),
):
organization_id = user.selected_organization_id
return await upsert_organization_preferences(
organization_id,
request,
)
@router.get(
"/model-configurations/preferences",
response_model=OrganizationPreferences,
include_in_schema=False,
)
async def get_model_configuration_preferences_legacy(
user: UserModel = Depends(get_user_with_selected_organization),
):
return await get_preferences(user=user)
@router.put(
"/model-configurations/preferences",
response_model=OrganizationAIModelConfigurationPreferences,
response_model=OrganizationPreferences,
include_in_schema=False,
)
async def save_model_configuration_preferences(
request: OrganizationAIModelConfigurationPreferences,
async def save_model_configuration_preferences_legacy(
request: OrganizationPreferences,
user: UserModel = Depends(get_user_with_selected_organization),
):
organization_id = user.selected_organization_id
return await upsert_organization_ai_model_configuration_preferences(
organization_id,
request,
)
return await save_preferences(request=request, user=user)
def preserve_masked_fields(provider: str, request_dict: dict, existing: dict):

View file

@ -53,7 +53,7 @@ class InitiateCallRequest(BaseModel):
workflow_run_id: int | None = None
phone_number: str | None = None
# Optional explicit telephony config to use for the test call. If omitted,
# falls back to the user's per-user default (when set), then the org default.
# falls back to the org default.
telephony_configuration_id: int | None = None
# Optional caller-ID phone number to dial out from. Must belong to the
# resolved telephony configuration; otherwise the provider picks one.
@ -82,12 +82,9 @@ async def initiate_call(
"""Initiate a call using the configured telephony provider from web browser. This is
supposed to be a test call method for the draft version of the agent."""
from api.services.configuration.ai_model_configuration import (
get_organization_ai_model_configuration_preferences,
)
from api.services.organization_preferences import get_organization_preferences
user_configuration = await db_client.get_user_configurations(user.id)
preferences = await get_organization_ai_model_configuration_preferences(
preferences = await get_organization_preferences(
user.selected_organization_id,
db=db_client,
)
@ -124,17 +121,12 @@ async def initiate_call(
detail="telephony_not_configured",
)
phone_number = (
request.phone_number
or preferences.test_phone_number
or user_configuration.test_phone_number
)
phone_number = request.phone_number or preferences.test_phone_number
if not phone_number:
raise HTTPException(
status_code=400,
detail="Phone number must be provided in request or set in user "
"configuration",
detail="Phone number must be provided in request or set in organization preferences",
)
workflow = await db_client.get_workflow(

View file

@ -11,9 +11,7 @@ from api.db.models import (
)
from api.services.auth.depends import get_user
from api.services.configuration.ai_model_configuration import (
get_organization_ai_model_configuration_preferences,
get_resolved_ai_model_configuration,
upsert_organization_ai_model_configuration_preferences,
)
from api.services.configuration.check_validity import (
APIKeyStatusResponse,
@ -24,6 +22,10 @@ from api.services.configuration.masking import check_for_masked_keys, mask_user_
from api.services.configuration.merge import merge_user_configurations
from api.services.configuration.registry import REGISTRY, ServiceType
from api.services.mps_service_key_client import mps_service_key_client
from api.services.organization_preferences import (
get_organization_preferences,
upsert_organization_preferences,
)
router = APIRouter(prefix="/user")
@ -101,13 +103,12 @@ async def get_user_configurations(
organization_id=user.selected_organization_id,
)
masked_config = mask_user_config(resolved_config.effective)
if resolved_config.preferences:
if resolved_config.preferences.test_phone_number is not None:
masked_config["test_phone_number"] = (
resolved_config.preferences.test_phone_number
)
if resolved_config.preferences.timezone is not None:
masked_config["timezone"] = resolved_config.preferences.timezone
if user.selected_organization_id:
preferences = await get_organization_preferences(user.selected_organization_id)
if preferences.test_phone_number is not None:
masked_config["test_phone_number"] = preferences.test_phone_number
if preferences.timezone is not None:
masked_config["timezone"] = preferences.timezone
# Add organization pricing info if available
if user.selected_organization_id:
@ -133,43 +134,49 @@ async def update_user_configurations(
# Remove organization_pricing from incoming dict as it's read-only
incoming_dict.pop("organization_pricing", None)
preferences_update = {
key: incoming_dict.pop(key)
for key in ("test_phone_number", "timezone")
if key in incoming_dict
}
# Merge via helper
try:
user_configurations = merge_user_configurations(existing_config, incoming_dict)
except ValidationError as e:
raise HTTPException(status_code=422, detail=str(e))
if incoming_dict:
# Merge via helper
try:
user_configurations = merge_user_configurations(
existing_config, incoming_dict
)
except ValidationError as e:
raise HTTPException(status_code=422, detail=str(e))
try:
check_for_masked_keys(user_configurations)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
try:
check_for_masked_keys(user_configurations)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
try:
validator = UserConfigurationValidator()
await validator.validate(
user_configurations,
organization_id=user.selected_organization_id,
created_by=user.provider_id,
try:
validator = UserConfigurationValidator()
await validator.validate(
user_configurations,
organization_id=user.selected_organization_id,
created_by=user.provider_id,
)
except ValueError as e:
raise HTTPException(status_code=422, detail=e.args[0])
user_configurations = await db_client.update_user_configuration(
user.id, user_configurations
)
except ValueError as e:
raise HTTPException(status_code=422, detail=e.args[0])
else:
user_configurations = existing_config
user_configurations = await db_client.update_user_configuration(
user.id, user_configurations
)
if user.selected_organization_id and (
request.test_phone_number is not None or request.timezone is not None
):
preferences = await get_organization_ai_model_configuration_preferences(
user.selected_organization_id
)
if request.test_phone_number is not None:
preferences.test_phone_number = request.test_phone_number
if request.timezone is not None:
preferences.timezone = request.timezone
await upsert_organization_ai_model_configuration_preferences(
if user.selected_organization_id and preferences_update:
preferences = await get_organization_preferences(user.selected_organization_id)
if "test_phone_number" in preferences_update:
preferences.test_phone_number = preferences_update["test_phone_number"]
if "timezone" in preferences_update:
preferences.timezone = preferences_update["timezone"]
await upsert_organization_preferences(
user.selected_organization_id,
preferences,
)
@ -177,9 +184,7 @@ async def update_user_configurations(
# Return masked version of updated config
masked_config = mask_user_config(user_configurations)
if user.selected_organization_id:
preferences = await get_organization_ai_model_configuration_preferences(
user.selected_organization_id
)
preferences = await get_organization_preferences(user.selected_organization_id)
if preferences.test_phone_number is not None:
masked_config["test_phone_number"] = preferences.test_phone_number
if preferences.timezone is not None:

View file

@ -93,15 +93,9 @@ class OrganizationAIModelConfigurationV2(BaseModel):
return self
class OrganizationAIModelConfigurationPreferences(BaseModel):
test_phone_number: str | None = None
timezone: str | None = None
class OrganizationAIModelConfigurationResponse(BaseModel):
configuration: dict | None
effective_configuration: dict
preferences: OrganizationAIModelConfigurationPreferences
source: Literal["organization_v2", "legacy_user_v1", "empty"]

View file

@ -0,0 +1,6 @@
from pydantic import BaseModel
class OrganizationPreferences(BaseModel):
test_phone_number: str | None = None
timezone: str | None = None

View file

@ -2,7 +2,6 @@ from __future__ import annotations
import copy
from dataclasses import dataclass
from inspect import isawaitable
from typing import Literal
from loguru import logger
@ -22,7 +21,6 @@ from api.schemas.ai_model_configuration import (
BYOKPipelineAIModelConfiguration,
BYOKRealtimeAIModelConfiguration,
DograhManagedAIModelConfiguration,
OrganizationAIModelConfigurationPreferences,
OrganizationAIModelConfigurationV2,
compile_ai_model_configuration_v2,
)
@ -45,7 +43,6 @@ class ResolvedAIModelConfiguration:
effective: EffectiveAIModelConfiguration
source: AIModelConfigurationSource
organization_configuration: OrganizationAIModelConfigurationV2 | None = None
preferences: OrganizationAIModelConfigurationPreferences | None = None
@dataclass
@ -60,9 +57,6 @@ async def get_resolved_ai_model_configuration(
user_id: int | None,
organization_id: int | None,
) -> ResolvedAIModelConfiguration:
preferences = await get_organization_ai_model_configuration_preferences(
organization_id
)
organization_configuration = await get_organization_ai_model_configuration_v2(
organization_id
)
@ -71,21 +65,18 @@ async def get_resolved_ai_model_configuration(
effective=compile_ai_model_configuration_v2(organization_configuration),
source="organization_v2",
organization_configuration=organization_configuration,
preferences=preferences,
)
if user_id is None:
return ResolvedAIModelConfiguration(
effective=EffectiveAIModelConfiguration(),
source="empty",
preferences=preferences,
)
legacy = await db_client.get_user_configurations(user_id)
return ResolvedAIModelConfiguration(
effective=legacy,
source="legacy_user_v1" if _has_model_services(legacy) else "empty",
preferences=preferences,
)
@ -135,33 +126,6 @@ async def get_organization_ai_model_configuration_v2(
return None
async def get_organization_ai_model_configuration_preferences(
organization_id: int | None,
db=None,
) -> OrganizationAIModelConfigurationPreferences:
if organization_id is None:
return OrganizationAIModelConfigurationPreferences()
db = db or db_client
row = db.get_configuration(
organization_id,
OrganizationConfigurationKey.MODEL_CONFIGURATION_PREFERENCES.value,
)
if isawaitable(row):
row = await row
if row is None or not row.value:
return OrganizationAIModelConfigurationPreferences()
if not isinstance(row.value, dict):
return OrganizationAIModelConfigurationPreferences()
try:
return OrganizationAIModelConfigurationPreferences.model_validate(row.value)
except ValidationError as exc:
logger.warning(
"Invalid org AI model configuration preferences for organization "
f"{organization_id}: {exc}. Returning defaults."
)
return OrganizationAIModelConfigurationPreferences()
async def upsert_organization_ai_model_configuration_v2(
organization_id: int,
configuration: OrganizationAIModelConfigurationV2,
@ -174,18 +138,6 @@ async def upsert_organization_ai_model_configuration_v2(
return configuration
async def upsert_organization_ai_model_configuration_preferences(
organization_id: int,
preferences: OrganizationAIModelConfigurationPreferences,
) -> OrganizationAIModelConfigurationPreferences:
await db_client.upsert_configuration(
organization_id,
OrganizationConfigurationKey.MODEL_CONFIGURATION_PREFERENCES.value,
preferences.model_dump(mode="json", exclude_none=True),
)
return preferences
async def migrate_workflow_model_configurations_to_v2(
*,
organization_id: int,

View file

@ -0,0 +1,62 @@
from inspect import isawaitable
from loguru import logger
from pydantic import ValidationError
from api.db import db_client
from api.enums import OrganizationConfigurationKey
from api.schemas.organization_preferences import OrganizationPreferences
async def get_organization_preferences(
organization_id: int | None,
db=None,
) -> OrganizationPreferences:
if organization_id is None:
return OrganizationPreferences()
db = db or db_client
row = await _get_configuration(
db,
organization_id,
OrganizationConfigurationKey.ORGANIZATION_PREFERENCES.value,
)
if row is None:
row = await _get_configuration(
db,
organization_id,
OrganizationConfigurationKey.MODEL_CONFIGURATION_PREFERENCES.value,
)
return _parse_preferences(row.value if row is not None else None, organization_id)
async def upsert_organization_preferences(
organization_id: int,
preferences: OrganizationPreferences,
) -> OrganizationPreferences:
await db_client.upsert_configuration(
organization_id,
OrganizationConfigurationKey.ORGANIZATION_PREFERENCES.value,
preferences.model_dump(mode="json", exclude_none=True),
)
return preferences
async def _get_configuration(db, organization_id: int, key: str):
row = db.get_configuration(organization_id, key)
if isawaitable(row):
row = await row
return row
def _parse_preferences(value, organization_id: int) -> OrganizationPreferences:
if not value or not isinstance(value, dict):
return OrganizationPreferences()
try:
return OrganizationPreferences.model_validate(value)
except ValidationError as exc:
logger.warning(
"Invalid organization preferences for organization "
f"{organization_id}: {exc}. Returning defaults."
)
return OrganizationPreferences()

View file

@ -1,3 +1,4 @@
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
from fastapi import FastAPI
@ -14,14 +15,14 @@ from api.services.configuration.registry import (
)
def _make_test_app():
def _make_test_app(selected_organization_id=None):
app = FastAPI()
app.include_router(router)
mock_user = MagicMock()
mock_user.id = 1
mock_user.is_superuser = False
mock_user.selected_organization_id = None
mock_user.selected_organization_id = selected_organization_id
app.dependency_overrides[get_user] = lambda: mock_user
return app
@ -210,3 +211,38 @@ class TestMaskedKeyRejection:
)
assert response.status_code == 200
def test_preference_only_update_does_not_validate_or_save_model_config(self):
"""Saving a test phone number through the legacy endpoint must not touch models."""
app = _make_test_app(selected_organization_id=11)
client = TestClient(app)
preferences = SimpleNamespace(test_phone_number=None, timezone=None)
with (
patch("api.routes.user.db_client") as mock_db,
patch("api.routes.user.UserConfigurationValidator") as mock_validator,
patch(
"api.routes.user.get_organization_preferences",
new=AsyncMock(return_value=preferences),
),
patch(
"api.routes.user.upsert_organization_preferences",
new=AsyncMock(return_value=preferences),
) as upsert_preferences,
):
existing = _existing_openai_config()
mock_db.get_user_configurations = AsyncMock(return_value=existing)
mock_db.update_user_configuration = AsyncMock()
mock_db.get_organization_by_id = AsyncMock(return_value=None)
mock_validator.return_value.validate = AsyncMock()
response = client.put(
"/user/configurations/user",
json={"test_phone_number": "+15551234567"},
)
assert response.status_code == 200
assert response.json()["test_phone_number"] == "+15551234567"
mock_db.update_user_configuration.assert_not_called()
mock_validator.return_value.validate.assert_not_called()
upsert_preferences.assert_awaited_once()

View file

@ -103,6 +103,61 @@ def test_initiate_call_executes_as_workflow_owner_for_shared_org_workflow():
assert initiate_kwargs["workflow_id"] == workflow.id
assert initiate_kwargs["user_id"] == workflow.user_id
assert "user_id=99" in initiate_kwargs["webhook_url"]
mock_db.get_user_configurations.assert_not_called()
def test_initiate_call_uses_organization_preference_phone_number():
app = _make_test_app()
client = TestClient(app)
workflow = _workflow()
provider = _provider()
quota_mock = AsyncMock(
return_value=SimpleNamespace(has_quota=True, error_message="")
)
with (
patch("api.routes.telephony.db_client") as mock_db,
patch(
"api.routes.telephony.check_dograh_quota_by_user_id",
new=quota_mock,
),
patch(
"api.routes.telephony.get_default_telephony_provider",
new=AsyncMock(return_value=provider),
),
patch(
"api.routes.telephony.get_backend_endpoints",
new=AsyncMock(return_value=("https://api.example.com", "wss://ignored")),
),
):
mock_db.get_user_configurations = AsyncMock(
return_value=SimpleNamespace(test_phone_number="+15550000000")
)
mock_db.get_configuration = Mock(
return_value=SimpleNamespace(value={"test_phone_number": "+15557654321"})
)
mock_db.get_default_telephony_configuration = AsyncMock(
return_value=SimpleNamespace(id=55)
)
mock_db.get_workflow = AsyncMock(return_value=workflow)
mock_db.create_workflow_run = AsyncMock(
return_value=SimpleNamespace(
id=501,
name="WR-TEL-OUT-00000001",
initial_context={},
)
)
mock_db.update_workflow_run = AsyncMock()
response = client.post(
"/telephony/initiate-call",
json={"workflow_id": workflow.id},
)
assert response.status_code == 200
assert provider.initiate_call.await_args.kwargs["to_number"] == "+15557654321"
mock_db.get_user_configurations.assert_not_called()
def test_initiate_call_rejects_existing_run_for_different_workflow():