mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-16 08:25:18 +02:00
feat: add config v2 to simplify billing (#428)
* feat: add model config v2 * chore: centralize user org selection * chore: move preferences to platform settings * fix: decouple org preference and ai model preferences
This commit is contained in:
parent
49e68b49d5
commit
cdbd06c8d9
42 changed files with 5135 additions and 264 deletions
|
|
@ -3,9 +3,12 @@ from loguru import logger
|
|||
|
||||
from api.db import db_client
|
||||
from api.db.models import UserModel
|
||||
from api.enums import PostHogEvent
|
||||
from api.enums import OrganizationConfigurationKey, PostHogEvent
|
||||
from api.schemas.auth import AuthResponse, LoginRequest, SignupRequest, UserResponse
|
||||
from api.services.auth.depends import create_user_configuration_with_mps_key, get_user
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
convert_legacy_ai_model_configuration_to_v2,
|
||||
)
|
||||
from api.services.posthog_client import capture_event
|
||||
from api.utils.auth import create_jwt_token, hash_password, verify_password
|
||||
|
||||
|
|
@ -47,6 +50,12 @@ async def signup(request: SignupRequest):
|
|||
)
|
||||
if mps_config:
|
||||
await db_client.update_user_configuration(user.id, mps_config)
|
||||
model_config_v2 = convert_legacy_ai_model_configuration_to_v2(mps_config)
|
||||
await db_client.upsert_configuration(
|
||||
organization.id,
|
||||
OrganizationConfigurationKey.MODEL_CONFIGURATION_V2.value,
|
||||
model_config_v2.model_dump(mode="json", exclude_none=True),
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to create default configuration for OSS user", exc_info=True
|
||||
|
|
|
|||
|
|
@ -369,6 +369,10 @@ async def search_chunks(
|
|||
|
||||
try:
|
||||
# Import here to avoid circular dependency
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
apply_managed_embeddings_base_url,
|
||||
get_resolved_ai_model_configuration,
|
||||
)
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from api.services.gen_ai import (
|
||||
AzureOpenAIEmbeddingService,
|
||||
|
|
@ -376,10 +380,15 @@ async def search_chunks(
|
|||
)
|
||||
|
||||
# Try to get user's embeddings configuration
|
||||
user_config = await db_client.get_user_configurations(user.id)
|
||||
resolved_config = await get_resolved_ai_model_configuration(
|
||||
user_id=user.id,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
user_config = resolved_config.effective
|
||||
embeddings_api_key = None
|
||||
embeddings_model = None
|
||||
embeddings_provider = None
|
||||
embeddings_base_url = None
|
||||
embeddings_endpoint = None
|
||||
embeddings_api_version = None
|
||||
|
||||
|
|
@ -388,6 +397,10 @@ async def search_chunks(
|
|||
embeddings_model = user_config.embeddings.model
|
||||
embeddings_provider = getattr(user_config.embeddings, "provider", None)
|
||||
embeddings_endpoint = getattr(user_config.embeddings, "endpoint", None)
|
||||
embeddings_base_url = apply_managed_embeddings_base_url(
|
||||
provider=embeddings_provider,
|
||||
base_url=getattr(user_config.embeddings, "base_url", None),
|
||||
)
|
||||
embeddings_api_version = getattr(
|
||||
user_config.embeddings, "api_version", None
|
||||
)
|
||||
|
|
@ -406,9 +419,7 @@ async def search_chunks(
|
|||
db_client=db_client,
|
||||
api_key=embeddings_api_key,
|
||||
model_id=embeddings_model or "text-embedding-3-small",
|
||||
base_url=getattr(user_config.embeddings, "base_url", None)
|
||||
if user_config.embeddings
|
||||
else None,
|
||||
base_url=embeddings_base_url,
|
||||
)
|
||||
|
||||
# Perform search
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
|
@ -10,6 +10,14 @@ from api.db import db_client
|
|||
from api.db.models import UserModel
|
||||
from api.db.telephony_configuration_client import TelephonyConfigurationInUseError
|
||||
from api.enums import OrganizationConfigurationKey, PostHogEvent
|
||||
from api.schemas.ai_model_configuration import (
|
||||
DOGRAH_DEFAULT_LANGUAGE,
|
||||
DOGRAH_DEFAULT_VOICE,
|
||||
DOGRAH_SPEED_OPTIONS,
|
||||
OrganizationAIModelConfigurationResponse,
|
||||
OrganizationAIModelConfigurationV2,
|
||||
)
|
||||
from api.schemas.organization_preferences import OrganizationPreferences
|
||||
from api.schemas.telephony_config import (
|
||||
TelephonyConfigRequest,
|
||||
TelephonyConfigurationCreateRequest,
|
||||
|
|
@ -26,8 +34,31 @@ from api.schemas.telephony_phone_number import (
|
|||
PhoneNumberUpdateRequest,
|
||||
ProviderSyncStatus,
|
||||
)
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.configuration.masking import is_mask_of, mask_key
|
||||
from api.services.auth.depends import get_user, get_user_with_selected_organization
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
check_for_masked_keys_in_ai_model_configuration_v2,
|
||||
compile_ai_model_configuration_v2,
|
||||
convert_legacy_ai_model_configuration_to_v2,
|
||||
get_organization_ai_model_configuration_v2,
|
||||
get_resolved_ai_model_configuration,
|
||||
mask_ai_model_configuration_v2,
|
||||
merge_ai_model_configuration_v2_secrets,
|
||||
migrate_workflow_model_configurations_to_v2,
|
||||
upsert_organization_ai_model_configuration_v2,
|
||||
)
|
||||
from api.services.configuration.check_validity import UserConfigurationValidator
|
||||
from api.services.configuration.defaults import DEFAULT_SERVICE_PROVIDERS
|
||||
from api.services.configuration.masking import is_mask_of, mask_key, mask_user_config
|
||||
from api.services.configuration.registry import (
|
||||
DOGRAH_STT_LANGUAGES,
|
||||
REGISTRY,
|
||||
ServiceProviders,
|
||||
ServiceType,
|
||||
)
|
||||
from api.services.organization_preferences import (
|
||||
get_organization_preferences,
|
||||
upsert_organization_preferences,
|
||||
)
|
||||
from api.services.posthog_client import capture_event
|
||||
from api.services.telephony import registry as telephony_registry
|
||||
from api.services.telephony.factory import get_telephony_provider_by_id
|
||||
|
|
@ -159,6 +190,222 @@ async def get_telephony_config_warnings(user: UserModel = Depends(get_user)):
|
|||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AI model configurations v2
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _byok_provider_schemas(service_type: ServiceType) -> dict[str, dict]:
|
||||
return {
|
||||
provider: model_cls.model_json_schema()
|
||||
for provider, model_cls in REGISTRY[service_type].items()
|
||||
if provider != ServiceProviders.DOGRAH.value
|
||||
}
|
||||
|
||||
|
||||
async def _model_configuration_v2_response(
|
||||
*,
|
||||
user: UserModel,
|
||||
configuration: OrganizationAIModelConfigurationV2 | None = None,
|
||||
) -> OrganizationAIModelConfigurationResponse:
|
||||
resolved = await get_resolved_ai_model_configuration(
|
||||
user_id=user.id,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
raw_configuration = (
|
||||
configuration
|
||||
if configuration is not None
|
||||
else resolved.organization_configuration
|
||||
)
|
||||
return OrganizationAIModelConfigurationResponse(
|
||||
configuration=mask_ai_model_configuration_v2(raw_configuration),
|
||||
effective_configuration=mask_user_config(resolved.effective),
|
||||
source=resolved.source,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/model-configurations/v2/defaults")
|
||||
async def get_model_configuration_v2_defaults(
|
||||
user: UserModel = Depends(get_user_with_selected_organization),
|
||||
):
|
||||
byok_default_providers = {
|
||||
service: provider
|
||||
for service, provider in DEFAULT_SERVICE_PROVIDERS.items()
|
||||
if provider != ServiceProviders.DOGRAH.value
|
||||
}
|
||||
return {
|
||||
"dograh": {
|
||||
"voices": [DOGRAH_DEFAULT_VOICE],
|
||||
"speeds": list(DOGRAH_SPEED_OPTIONS),
|
||||
"languages": DOGRAH_STT_LANGUAGES,
|
||||
"defaults": {
|
||||
"voice": DOGRAH_DEFAULT_VOICE,
|
||||
"speed": 1.0,
|
||||
"language": DOGRAH_DEFAULT_LANGUAGE,
|
||||
},
|
||||
},
|
||||
"byok": {
|
||||
"pipeline": {
|
||||
"llm": _byok_provider_schemas(ServiceType.LLM),
|
||||
"tts": _byok_provider_schemas(ServiceType.TTS),
|
||||
"stt": _byok_provider_schemas(ServiceType.STT),
|
||||
"embeddings": _byok_provider_schemas(ServiceType.EMBEDDINGS),
|
||||
"default_providers": byok_default_providers,
|
||||
},
|
||||
"realtime": {
|
||||
"realtime": _byok_provider_schemas(ServiceType.REALTIME),
|
||||
"llm": _byok_provider_schemas(ServiceType.LLM),
|
||||
"embeddings": _byok_provider_schemas(ServiceType.EMBEDDINGS),
|
||||
"default_providers": byok_default_providers,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@router.get(
|
||||
"/model-configurations/v2",
|
||||
response_model=OrganizationAIModelConfigurationResponse,
|
||||
)
|
||||
async def get_model_configuration_v2(
|
||||
user: UserModel = Depends(get_user_with_selected_organization),
|
||||
):
|
||||
return await _model_configuration_v2_response(user=user)
|
||||
|
||||
|
||||
@router.put(
|
||||
"/model-configurations/v2",
|
||||
response_model=OrganizationAIModelConfigurationResponse,
|
||||
)
|
||||
async def save_model_configuration_v2(
|
||||
request: OrganizationAIModelConfigurationV2,
|
||||
user: UserModel = Depends(get_user_with_selected_organization),
|
||||
):
|
||||
organization_id = user.selected_organization_id
|
||||
existing = await get_organization_ai_model_configuration_v2(organization_id)
|
||||
configuration = merge_ai_model_configuration_v2_secrets(request, existing)
|
||||
try:
|
||||
check_for_masked_keys_in_ai_model_configuration_v2(configuration)
|
||||
effective = compile_ai_model_configuration_v2(configuration)
|
||||
await UserConfigurationValidator().validate(
|
||||
effective,
|
||||
organization_id=organization_id,
|
||||
created_by=user.provider_id,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=422, detail=exc.args[0])
|
||||
|
||||
await upsert_organization_ai_model_configuration_v2(
|
||||
organization_id,
|
||||
configuration,
|
||||
)
|
||||
return await _model_configuration_v2_response(
|
||||
user=user,
|
||||
configuration=configuration,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/model-configurations/v2/migration-preview")
|
||||
async def preview_model_configuration_v2_migration(
|
||||
user: UserModel = Depends(get_user_with_selected_organization),
|
||||
):
|
||||
legacy = await db_client.get_user_configurations(user.id)
|
||||
try:
|
||||
configuration = convert_legacy_ai_model_configuration_to_v2(legacy)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=422, detail=str(exc))
|
||||
return {
|
||||
"configuration": mask_ai_model_configuration_v2(configuration),
|
||||
"effective_configuration": mask_user_config(
|
||||
compile_ai_model_configuration_v2(configuration)
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@router.post(
|
||||
"/model-configurations/v2/migrate",
|
||||
response_model=OrganizationAIModelConfigurationResponse,
|
||||
)
|
||||
async def migrate_model_configuration_v2(
|
||||
force: bool = Query(default=False),
|
||||
user: UserModel = Depends(get_user_with_selected_organization),
|
||||
):
|
||||
organization_id = user.selected_organization_id
|
||||
existing = await get_organization_ai_model_configuration_v2(organization_id)
|
||||
if existing is not None and not force:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="Organization already has a v2 model configuration",
|
||||
)
|
||||
|
||||
legacy = await db_client.get_user_configurations(user.id)
|
||||
try:
|
||||
configuration = convert_legacy_ai_model_configuration_to_v2(legacy)
|
||||
effective = compile_ai_model_configuration_v2(configuration)
|
||||
await UserConfigurationValidator().validate(
|
||||
effective,
|
||||
organization_id=organization_id,
|
||||
created_by=user.provider_id,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=422, detail=exc.args[0])
|
||||
|
||||
await upsert_organization_ai_model_configuration_v2(
|
||||
organization_id,
|
||||
configuration,
|
||||
)
|
||||
await migrate_workflow_model_configurations_to_v2(
|
||||
organization_id=organization_id,
|
||||
fallback_user_config=legacy,
|
||||
)
|
||||
return await _model_configuration_v2_response(
|
||||
user=user,
|
||||
configuration=configuration,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/preferences", response_model=OrganizationPreferences)
|
||||
async def get_preferences(
|
||||
user: UserModel = Depends(get_user_with_selected_organization),
|
||||
):
|
||||
organization_id = user.selected_organization_id
|
||||
return await get_organization_preferences(organization_id)
|
||||
|
||||
|
||||
@router.put("/preferences", response_model=OrganizationPreferences)
|
||||
async def save_preferences(
|
||||
request: OrganizationPreferences,
|
||||
user: UserModel = Depends(get_user_with_selected_organization),
|
||||
):
|
||||
organization_id = user.selected_organization_id
|
||||
return await upsert_organization_preferences(
|
||||
organization_id,
|
||||
request,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/model-configurations/preferences",
|
||||
response_model=OrganizationPreferences,
|
||||
include_in_schema=False,
|
||||
)
|
||||
async def get_model_configuration_preferences_legacy(
|
||||
user: UserModel = Depends(get_user_with_selected_organization),
|
||||
):
|
||||
return await get_preferences(user=user)
|
||||
|
||||
|
||||
@router.put(
|
||||
"/model-configurations/preferences",
|
||||
response_model=OrganizationPreferences,
|
||||
include_in_schema=False,
|
||||
)
|
||||
async def save_model_configuration_preferences_legacy(
|
||||
request: OrganizationPreferences,
|
||||
user: UserModel = Depends(get_user_with_selected_organization),
|
||||
):
|
||||
return await save_preferences(request=request, user=user)
|
||||
|
||||
|
||||
def preserve_masked_fields(provider: str, request_dict: dict, existing: dict):
|
||||
"""If the client re-submitted a masked sensitive field, restore the original."""
|
||||
for field_name in _sensitive_fields(provider):
|
||||
|
|
|
|||
|
|
@ -53,7 +53,7 @@ class InitiateCallRequest(BaseModel):
|
|||
workflow_run_id: int | None = None
|
||||
phone_number: str | None = None
|
||||
# Optional explicit telephony config to use for the test call. If omitted,
|
||||
# falls back to the user's per-user default (when set), then the org default.
|
||||
# falls back to the org default.
|
||||
telephony_configuration_id: int | None = None
|
||||
# Optional caller-ID phone number to dial out from. Must belong to the
|
||||
# resolved telephony configuration; otherwise the provider picks one.
|
||||
|
|
@ -82,7 +82,12 @@ async def initiate_call(
|
|||
"""Initiate a call using the configured telephony provider from web browser. This is
|
||||
supposed to be a test call method for the draft version of the agent."""
|
||||
|
||||
user_configuration = await db_client.get_user_configurations(user.id)
|
||||
from api.services.organization_preferences import get_organization_preferences
|
||||
|
||||
preferences = await get_organization_preferences(
|
||||
user.selected_organization_id,
|
||||
db=db_client,
|
||||
)
|
||||
|
||||
# Resolve which telephony config to use: explicit request value, otherwise
|
||||
# the org's default outbound config.
|
||||
|
|
@ -116,13 +121,12 @@ async def initiate_call(
|
|||
detail="telephony_not_configured",
|
||||
)
|
||||
|
||||
phone_number = request.phone_number or user_configuration.test_phone_number
|
||||
phone_number = request.phone_number or preferences.test_phone_number
|
||||
|
||||
if not phone_number:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Phone number must be provided in request or set in user "
|
||||
"configuration",
|
||||
detail="Phone number must be provided in request or set in organization preferences",
|
||||
)
|
||||
|
||||
workflow = await db_client.get_workflow(
|
||||
|
|
|
|||
|
|
@ -10,6 +10,9 @@ from api.db.models import (
|
|||
UserModel,
|
||||
)
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
get_resolved_ai_model_configuration,
|
||||
)
|
||||
from api.services.configuration.check_validity import (
|
||||
APIKeyStatusResponse,
|
||||
UserConfigurationValidator,
|
||||
|
|
@ -19,6 +22,10 @@ from api.services.configuration.masking import check_for_masked_keys, mask_user_
|
|||
from api.services.configuration.merge import merge_user_configurations
|
||||
from api.services.configuration.registry import REGISTRY, ServiceType
|
||||
from api.services.mps_service_key_client import mps_service_key_client
|
||||
from api.services.organization_preferences import (
|
||||
get_organization_preferences,
|
||||
upsert_organization_preferences,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/user")
|
||||
|
||||
|
|
@ -91,8 +98,17 @@ class UserConfigurationRequestResponseSchema(BaseModel):
|
|||
async def get_user_configurations(
|
||||
user: UserModel = Depends(get_user),
|
||||
) -> UserConfigurationRequestResponseSchema:
|
||||
user_configurations = await db_client.get_user_configurations(user.id)
|
||||
masked_config = mask_user_config(user_configurations)
|
||||
resolved_config = await get_resolved_ai_model_configuration(
|
||||
user_id=user.id,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
masked_config = mask_user_config(resolved_config.effective)
|
||||
if user.selected_organization_id:
|
||||
preferences = await get_organization_preferences(user.selected_organization_id)
|
||||
if preferences.test_phone_number is not None:
|
||||
masked_config["test_phone_number"] = preferences.test_phone_number
|
||||
if preferences.timezone is not None:
|
||||
masked_config["timezone"] = preferences.timezone
|
||||
|
||||
# Add organization pricing info if available
|
||||
if user.selected_organization_id:
|
||||
|
|
@ -118,34 +134,61 @@ async def update_user_configurations(
|
|||
|
||||
# Remove organization_pricing from incoming dict as it's read-only
|
||||
incoming_dict.pop("organization_pricing", None)
|
||||
preferences_update = {
|
||||
key: incoming_dict.pop(key)
|
||||
for key in ("test_phone_number", "timezone")
|
||||
if key in incoming_dict
|
||||
}
|
||||
|
||||
# Merge via helper
|
||||
try:
|
||||
user_configurations = merge_user_configurations(existing_config, incoming_dict)
|
||||
except ValidationError as e:
|
||||
raise HTTPException(status_code=422, detail=str(e))
|
||||
if incoming_dict:
|
||||
# Merge via helper
|
||||
try:
|
||||
user_configurations = merge_user_configurations(
|
||||
existing_config, incoming_dict
|
||||
)
|
||||
except ValidationError as e:
|
||||
raise HTTPException(status_code=422, detail=str(e))
|
||||
|
||||
try:
|
||||
check_for_masked_keys(user_configurations)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
try:
|
||||
check_for_masked_keys(user_configurations)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
try:
|
||||
validator = UserConfigurationValidator()
|
||||
await validator.validate(
|
||||
user_configurations,
|
||||
organization_id=user.selected_organization_id,
|
||||
created_by=user.provider_id,
|
||||
try:
|
||||
validator = UserConfigurationValidator()
|
||||
await validator.validate(
|
||||
user_configurations,
|
||||
organization_id=user.selected_organization_id,
|
||||
created_by=user.provider_id,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=422, detail=e.args[0])
|
||||
|
||||
user_configurations = await db_client.update_user_configuration(
|
||||
user.id, user_configurations
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=422, detail=e.args[0])
|
||||
else:
|
||||
user_configurations = existing_config
|
||||
|
||||
user_configurations = await db_client.update_user_configuration(
|
||||
user.id, user_configurations
|
||||
)
|
||||
if user.selected_organization_id and preferences_update:
|
||||
preferences = await get_organization_preferences(user.selected_organization_id)
|
||||
if "test_phone_number" in preferences_update:
|
||||
preferences.test_phone_number = preferences_update["test_phone_number"]
|
||||
if "timezone" in preferences_update:
|
||||
preferences.timezone = preferences_update["timezone"]
|
||||
await upsert_organization_preferences(
|
||||
user.selected_organization_id,
|
||||
preferences,
|
||||
)
|
||||
|
||||
# Return masked version of updated config
|
||||
masked_config = mask_user_config(user_configurations)
|
||||
if user.selected_organization_id:
|
||||
preferences = await get_organization_preferences(user.selected_organization_id)
|
||||
if preferences.test_phone_number is not None:
|
||||
masked_config["test_phone_number"] = preferences.test_phone_number
|
||||
if preferences.timezone is not None:
|
||||
masked_config["timezone"] = preferences.timezone
|
||||
|
||||
# Add organization pricing info if available
|
||||
if user.selected_organization_id:
|
||||
|
|
@ -165,7 +208,11 @@ async def validate_user_configurations(
|
|||
validity_ttl_seconds: int = Query(default=60, ge=0, le=86400),
|
||||
user: UserModel = Depends(get_user),
|
||||
) -> APIKeyStatusResponse:
|
||||
configurations = await db_client.get_user_configurations(user.id)
|
||||
resolved_config = await get_resolved_ai_model_configuration(
|
||||
user_id=user.id,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
configurations = resolved_config.effective
|
||||
|
||||
if (
|
||||
configurations.last_validated_at
|
||||
|
|
|
|||
|
|
@ -16,9 +16,18 @@ from api.db.agent_trigger_client import TriggerPathConflictError
|
|||
from api.db.models import UserModel
|
||||
from api.db.workflow_template_client import WorkflowTemplateClient
|
||||
from api.enums import CallType, PostHogEvent, StorageBackend
|
||||
from api.schemas.ai_model_configuration import OrganizationAIModelConfigurationV2
|
||||
from api.schemas.workflow import WorkflowRunResponseSchema
|
||||
from api.sdk_expose import sdk_expose
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY,
|
||||
check_for_masked_keys_in_ai_model_configuration_v2,
|
||||
compile_ai_model_configuration_v2,
|
||||
convert_legacy_ai_model_configuration_to_v2,
|
||||
get_resolved_ai_model_configuration,
|
||||
merge_ai_model_configuration_v2_secrets,
|
||||
)
|
||||
from api.services.configuration.check_validity import UserConfigurationValidator
|
||||
from api.services.configuration.masking import (
|
||||
mask_workflow_configurations,
|
||||
|
|
@ -955,12 +964,74 @@ async def update_workflow(
|
|||
existing_def,
|
||||
)
|
||||
|
||||
# Validate model_overrides: resolve onto global config, then
|
||||
# run the same validator used by the user-configurations endpoint.
|
||||
# Also stamp the current global API key into the override so the override
|
||||
# remains functional if the global config later switches to a different provider.
|
||||
# Validate model overrides. v2 uses a complete workflow-level model
|
||||
# configuration; legacy v1 uses partial service overlays.
|
||||
workflow_configurations = request.workflow_configurations
|
||||
if workflow_configurations and workflow_configurations.get("model_overrides"):
|
||||
if workflow_configurations and workflow_configurations.get(
|
||||
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY
|
||||
):
|
||||
existing_workflow = await db_client.get_workflow(
|
||||
workflow_id, organization_id=user.selected_organization_id
|
||||
)
|
||||
if existing_workflow is None:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Workflow with id {workflow_id} not found"
|
||||
)
|
||||
existing_draft = await db_client.get_draft_version(workflow_id)
|
||||
existing_configs = (
|
||||
existing_draft.workflow_configurations
|
||||
if existing_draft
|
||||
else existing_workflow.released_definition.workflow_configurations
|
||||
)
|
||||
existing_v2_override = (existing_configs or {}).get(
|
||||
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY
|
||||
)
|
||||
try:
|
||||
incoming_v2_override = (
|
||||
OrganizationAIModelConfigurationV2.model_validate(
|
||||
workflow_configurations[
|
||||
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY
|
||||
]
|
||||
)
|
||||
)
|
||||
existing_v2_override_config = (
|
||||
OrganizationAIModelConfigurationV2.model_validate(
|
||||
existing_v2_override
|
||||
)
|
||||
if existing_v2_override
|
||||
else None
|
||||
)
|
||||
v2_override = merge_ai_model_configuration_v2_secrets(
|
||||
incoming_v2_override,
|
||||
existing_v2_override_config,
|
||||
)
|
||||
if existing_v2_override_config is None:
|
||||
resolved_config = await get_resolved_ai_model_configuration(
|
||||
user_id=user.id,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
v2_override = merge_ai_model_configuration_v2_secrets(
|
||||
v2_override,
|
||||
resolved_config.organization_configuration,
|
||||
)
|
||||
check_for_masked_keys_in_ai_model_configuration_v2(v2_override)
|
||||
effective = compile_ai_model_configuration_v2(v2_override)
|
||||
await UserConfigurationValidator().validate(
|
||||
effective,
|
||||
organization_id=user.selected_organization_id,
|
||||
created_by=user.provider_id,
|
||||
)
|
||||
except (ValidationError, ValueError) as e:
|
||||
raise HTTPException(status_code=422, detail=str(e))
|
||||
workflow_configurations = {
|
||||
**workflow_configurations,
|
||||
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY: v2_override.model_dump(
|
||||
mode="json",
|
||||
exclude_none=True,
|
||||
),
|
||||
}
|
||||
workflow_configurations.pop("model_overrides", None)
|
||||
elif workflow_configurations and workflow_configurations.get("model_overrides"):
|
||||
existing_workflow = await db_client.get_workflow(
|
||||
workflow_id, organization_id=user.selected_organization_id
|
||||
)
|
||||
|
|
@ -978,24 +1049,46 @@ async def update_workflow(
|
|||
workflow_configurations,
|
||||
existing_configs,
|
||||
)
|
||||
user_config = await db_client.get_user_configurations(user.id)
|
||||
resolved_config = await get_resolved_ai_model_configuration(
|
||||
user_id=user.id,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
user_config = resolved_config.effective
|
||||
try:
|
||||
enriched_overrides = enrich_overrides_with_api_keys(
|
||||
workflow_configurations["model_overrides"],
|
||||
user_config,
|
||||
)
|
||||
effective = resolve_effective_config(user_config, enriched_overrides)
|
||||
await UserConfigurationValidator().validate(
|
||||
effective,
|
||||
organization_id=user.selected_organization_id,
|
||||
created_by=user.provider_id,
|
||||
)
|
||||
if resolved_config.source == "organization_v2":
|
||||
v2_override = convert_legacy_ai_model_configuration_to_v2(effective)
|
||||
await UserConfigurationValidator().validate(
|
||||
compile_ai_model_configuration_v2(v2_override),
|
||||
organization_id=user.selected_organization_id,
|
||||
created_by=user.provider_id,
|
||||
)
|
||||
else:
|
||||
await UserConfigurationValidator().validate(
|
||||
effective,
|
||||
organization_id=user.selected_organization_id,
|
||||
created_by=user.provider_id,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=422, detail=str(e))
|
||||
workflow_configurations = {
|
||||
**workflow_configurations,
|
||||
"model_overrides": enriched_overrides,
|
||||
}
|
||||
if resolved_config.source == "organization_v2":
|
||||
workflow_configurations = {
|
||||
**workflow_configurations,
|
||||
WORKFLOW_MODEL_CONFIGURATION_V2_OVERRIDE_KEY: v2_override.model_dump(
|
||||
mode="json",
|
||||
exclude_none=True,
|
||||
),
|
||||
}
|
||||
workflow_configurations.pop("model_overrides", None)
|
||||
else:
|
||||
workflow_configurations = {
|
||||
**workflow_configurations,
|
||||
"model_overrides": enriched_overrides,
|
||||
}
|
||||
|
||||
# Reject upfront if any new trigger path collides with another
|
||||
# workflow's trigger — keeps the workflow record from
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from pydantic import BaseModel, Field
|
|||
from api.db import db_client
|
||||
from api.db.models import UserModel, WorkflowRunTextSessionModel
|
||||
from api.enums import WorkflowRunMode
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.auth.depends import get_user_with_selected_organization
|
||||
from api.services.quota_service import check_dograh_quota
|
||||
from api.services.workflow.text_chat_session_service import (
|
||||
TextChatPendingTurnLostError,
|
||||
|
|
@ -96,12 +96,6 @@ def _revision_conflict_detail(e: Any) -> dict[str, Any]:
|
|||
}
|
||||
|
||||
|
||||
def _require_selected_organization_id(user: UserModel) -> int:
|
||||
if user.selected_organization_id is None:
|
||||
raise HTTPException(status_code=403, detail="Organization context is required")
|
||||
return user.selected_organization_id
|
||||
|
||||
|
||||
async def _ensure_text_chat_quota(user: UserModel, workflow_id: int) -> None:
|
||||
quota_result = await check_dograh_quota(user, workflow_id=workflow_id)
|
||||
if not quota_result.has_quota:
|
||||
|
|
@ -114,9 +108,8 @@ async def _load_text_session_or_404(
|
|||
user: UserModel,
|
||||
) -> WorkflowRunTextSessionModel:
|
||||
set_current_run_id(run_id)
|
||||
organization_id = _require_selected_organization_id(user)
|
||||
text_session = await db_client.get_workflow_run_text_session(
|
||||
run_id, organization_id=organization_id
|
||||
run_id, organization_id=user.selected_organization_id
|
||||
)
|
||||
if not text_session or not text_session.workflow_run:
|
||||
raise HTTPException(status_code=404, detail="Text chat session not found")
|
||||
|
|
@ -158,9 +151,8 @@ async def _execute_pending_turn_response(
|
|||
async def create_text_chat_session(
|
||||
workflow_id: int,
|
||||
request: CreateTextChatSessionRequest,
|
||||
user: UserModel = Depends(get_user),
|
||||
user: UserModel = Depends(get_user_with_selected_organization),
|
||||
) -> WorkflowRunTextSessionResponse:
|
||||
organization_id = _require_selected_organization_id(user)
|
||||
await _ensure_text_chat_quota(user, workflow_id)
|
||||
|
||||
session_name = request.name or f"WR-TEXT-{uuid4().hex[:6].upper()}"
|
||||
|
|
@ -172,7 +164,7 @@ async def create_text_chat_session(
|
|||
user_id=user.id,
|
||||
initial_context=request.initial_context,
|
||||
use_draft=True,
|
||||
organization_id=organization_id,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
|
@ -220,7 +212,7 @@ async def create_text_chat_session(
|
|||
async def get_text_chat_session(
|
||||
workflow_id: int,
|
||||
run_id: int,
|
||||
user: UserModel = Depends(get_user),
|
||||
user: UserModel = Depends(get_user_with_selected_organization),
|
||||
) -> WorkflowRunTextSessionResponse:
|
||||
text_session = await _load_text_session_or_404(workflow_id, run_id, user)
|
||||
return _build_response(text_session)
|
||||
|
|
@ -234,7 +226,7 @@ async def append_text_chat_message(
|
|||
workflow_id: int,
|
||||
run_id: int,
|
||||
request: AppendTextChatMessageRequest,
|
||||
user: UserModel = Depends(get_user),
|
||||
user: UserModel = Depends(get_user_with_selected_organization),
|
||||
) -> WorkflowRunTextSessionResponse:
|
||||
text_session = await _load_text_session_or_404(workflow_id, run_id, user)
|
||||
await _ensure_text_chat_quota(user, workflow_id)
|
||||
|
|
@ -264,7 +256,7 @@ async def rewind_text_chat_session(
|
|||
workflow_id: int,
|
||||
run_id: int,
|
||||
request: RewindTextChatSessionRequest,
|
||||
user: UserModel = Depends(get_user),
|
||||
user: UserModel = Depends(get_user_with_selected_organization),
|
||||
) -> WorkflowRunTextSessionResponse:
|
||||
text_session = await _load_text_session_or_404(workflow_id, run_id, user)
|
||||
try:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue