mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-13 08:15:21 +02:00
feat: refactor user configuration table
This commit is contained in:
parent
03daaba7a1
commit
e5cc1308ed
31 changed files with 932 additions and 419 deletions
|
|
@ -0,0 +1,52 @@
|
|||
"""add key to user_configurations
|
||||
|
||||
Turns user_configurations into a per-user keyed JSON store mirroring
|
||||
organization_configurations. Existing rows (the legacy v1 AI model
|
||||
configuration blob) are backfilled with key MODEL_CONFIGURATION.
|
||||
|
||||
Revision ID: 91cc6ba3e1c7
|
||||
Revises: 384be6596b36
|
||||
Create Date: 2026-06-12 21:04:25.561529
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "91cc6ba3e1c7"
|
||||
down_revision: Union[str, None] = "384be6596b36"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Backfill existing rows (all legacy model-config blobs) via the server
|
||||
# default, then drop the default — application code always supplies key.
|
||||
op.add_column(
|
||||
"user_configurations",
|
||||
sa.Column(
|
||||
"key",
|
||||
sa.String(),
|
||||
nullable=False,
|
||||
server_default="MODEL_CONFIGURATION",
|
||||
),
|
||||
)
|
||||
|
||||
op.create_unique_constraint(
|
||||
"_user_configuration_key_uc", "user_configurations", ["user_id", "key"]
|
||||
)
|
||||
op.alter_column("user_configurations", "key", server_default=None)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_constraint(
|
||||
"_user_configuration_key_uc", "user_configurations", type_="unique"
|
||||
)
|
||||
# Non-model-config rows (e.g. ONBOARDING) have no meaning in the old
|
||||
# single-blob schema; the old code would read them as the user's model
|
||||
# config, so they must not survive the downgrade.
|
||||
op.execute("DELETE FROM user_configurations WHERE key != 'MODEL_CONFIGURATION'")
|
||||
op.drop_column("user_configurations", "key")
|
||||
|
|
@ -82,12 +82,24 @@ class UserModel(Base):
|
|||
|
||||
|
||||
class UserConfigurationModel(Base):
|
||||
"""Per-user keyed JSON store, mirroring organization_configurations.
|
||||
|
||||
Keys are defined in UserConfigurationKey. The legacy v1 AI model
|
||||
configuration lives under MODEL_CONFIGURATION; last_validated_at is only
|
||||
meaningful for that key.
|
||||
"""
|
||||
|
||||
__tablename__ = "user_configurations"
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
user_id = Column(Integer, ForeignKey("users.id"), nullable=True)
|
||||
key = Column(String, nullable=False)
|
||||
configuration = Column(JSON, nullable=False, default=dict)
|
||||
last_validated_at = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint("user_id", "key", name="_user_configuration_key_uc"),
|
||||
)
|
||||
|
||||
|
||||
# New Organization model
|
||||
class OrganizationModel(Base):
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ from api.db.models import (
|
|||
WorkflowModel,
|
||||
WorkflowRunModel,
|
||||
)
|
||||
from api.enums import OrganizationConfigurationKey
|
||||
from api.enums import OrganizationConfigurationKey, UserConfigurationKey
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
|
||||
|
||||
|
|
@ -343,7 +343,9 @@ class OrganizationUsageClient(BaseDBClient):
|
|||
if user_id:
|
||||
config_result = await session.execute(
|
||||
select(UserConfigurationModel).where(
|
||||
UserConfigurationModel.user_id == user_id
|
||||
UserConfigurationModel.user_id == user_id,
|
||||
UserConfigurationModel.key
|
||||
== UserConfigurationKey.MODEL_CONFIGURATION.value,
|
||||
)
|
||||
)
|
||||
config_obj = config_result.scalar_one_or_none()
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from sqlalchemy.future import select
|
|||
|
||||
from api.db.base_client import BaseDBClient
|
||||
from api.db.models import UserConfigurationModel, UserModel
|
||||
from api.enums import UserConfigurationKey
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
|
||||
|
||||
|
|
@ -65,16 +66,51 @@ class UserClient(BaseDBClient):
|
|||
)
|
||||
return result.scalars().first()
|
||||
|
||||
async def _get_user_configuration_row(
|
||||
self, session, user_id: int, key: str
|
||||
) -> UserConfigurationModel | None:
|
||||
result = await session.execute(
|
||||
select(UserConfigurationModel).where(
|
||||
UserConfigurationModel.user_id == user_id,
|
||||
UserConfigurationModel.key == key,
|
||||
)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
async def get_user_configuration_value(self, user_id: int, key: str) -> dict | None:
|
||||
"""Get the JSON value stored for a user under `key`, or None."""
|
||||
async with self.async_session() as session:
|
||||
row = await self._get_user_configuration_row(session, user_id, key)
|
||||
return row.configuration if row else None
|
||||
|
||||
async def upsert_user_configuration_value(
|
||||
self, user_id: int, key: str, value: dict
|
||||
) -> dict:
|
||||
"""Create or update the JSON value stored for a user under `key`."""
|
||||
async with self.async_session() as session:
|
||||
row = await self._get_user_configuration_row(session, user_id, key)
|
||||
if row:
|
||||
row.configuration = value
|
||||
else:
|
||||
row = UserConfigurationModel(
|
||||
user_id=user_id, key=key, configuration=value
|
||||
)
|
||||
session.add(row)
|
||||
try:
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
await session.refresh(row)
|
||||
return row.configuration
|
||||
|
||||
async def get_user_configurations(
|
||||
self, user_id: int
|
||||
) -> EffectiveAIModelConfiguration:
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(UserConfigurationModel).where(
|
||||
UserConfigurationModel.user_id == user_id
|
||||
)
|
||||
configuration_obj = await self._get_user_configuration_row(
|
||||
session, user_id, UserConfigurationKey.MODEL_CONFIGURATION.value
|
||||
)
|
||||
configuration_obj = result.scalars().first()
|
||||
if not configuration_obj:
|
||||
return EffectiveAIModelConfiguration()
|
||||
|
||||
|
|
@ -97,38 +133,18 @@ class UserClient(BaseDBClient):
|
|||
async def update_user_configuration(
|
||||
self, user_id: int, configuration: EffectiveAIModelConfiguration
|
||||
) -> EffectiveAIModelConfiguration:
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(UserConfigurationModel).where(
|
||||
UserConfigurationModel.user_id == user_id
|
||||
)
|
||||
)
|
||||
configuration_obj = result.scalars().first()
|
||||
if not configuration_obj:
|
||||
configuration_obj = UserConfigurationModel(
|
||||
user_id=user_id, configuration=configuration.model_dump()
|
||||
)
|
||||
session.add(configuration_obj)
|
||||
else:
|
||||
configuration_obj.configuration = configuration.model_dump()
|
||||
try:
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
await session.refresh(configuration_obj)
|
||||
return EffectiveAIModelConfiguration.model_validate(
|
||||
configuration_obj.configuration
|
||||
value = await self.upsert_user_configuration_value(
|
||||
user_id,
|
||||
UserConfigurationKey.MODEL_CONFIGURATION.value,
|
||||
configuration.model_dump(),
|
||||
)
|
||||
return EffectiveAIModelConfiguration.model_validate(value)
|
||||
|
||||
async def update_user_configuration_last_validated_at(self, user_id: int) -> None:
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(UserConfigurationModel).where(
|
||||
UserConfigurationModel.user_id == user_id
|
||||
)
|
||||
configuration_obj = await self._get_user_configuration_row(
|
||||
session, user_id, UserConfigurationKey.MODEL_CONFIGURATION.value
|
||||
)
|
||||
configuration_obj = result.scalars().first()
|
||||
if not configuration_obj:
|
||||
raise ValueError(f"User configuration with ID {user_id} not found")
|
||||
configuration_obj.last_validated_at = datetime.now()
|
||||
|
|
|
|||
|
|
@ -96,6 +96,15 @@ class OrganizationConfigurationKey(Enum):
|
|||
MODEL_CONFIGURATION_PREFERENCES = "MODEL_CONFIGURATION_PREFERENCES" # Deprecated; read fallback for old org preferences
|
||||
|
||||
|
||||
class UserConfigurationKey(Enum):
|
||||
"""Keys for the per-user keyed JSON store (user_configurations)."""
|
||||
|
||||
MODEL_CONFIGURATION = (
|
||||
"MODEL_CONFIGURATION" # Legacy per-user v1 AI model configuration
|
||||
)
|
||||
ONBOARDING = "ONBOARDING" # Post-signup onboarding state (gate, tooltips, actions)
|
||||
|
||||
|
||||
class WorkflowStatus(Enum):
|
||||
"""Workflow status values"""
|
||||
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from api.db import db_client
|
|||
from api.db.models import (
|
||||
UserModel,
|
||||
)
|
||||
from api.schemas.onboarding_state import OnboardingState, OnboardingStateUpdate
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
get_resolved_ai_model_configuration,
|
||||
|
|
@ -26,6 +27,10 @@ from api.services.organization_preferences import (
|
|||
get_organization_preferences,
|
||||
upsert_organization_preferences,
|
||||
)
|
||||
from api.services.user_onboarding import (
|
||||
get_onboarding_state,
|
||||
update_onboarding_state,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/user")
|
||||
|
||||
|
|
@ -92,9 +97,6 @@ class UserConfigurationRequestResponseSchema(BaseModel):
|
|||
test_phone_number: str | None = None
|
||||
timezone: str | None = None
|
||||
organization_pricing: dict[str, Union[float, str, bool]] | None = None
|
||||
# Post-signup onboarding gate. Set once on submit/skip.
|
||||
onboarding_completed_at: datetime | None = None
|
||||
onboarding_skipped: bool | None = None
|
||||
|
||||
|
||||
@router.get("/configurations/user")
|
||||
|
|
@ -206,6 +208,21 @@ async def update_user_configurations(
|
|||
return masked_config
|
||||
|
||||
|
||||
@router.get("/onboarding-state")
|
||||
async def get_user_onboarding_state(
|
||||
user: UserModel = Depends(get_user),
|
||||
) -> OnboardingState:
|
||||
return await get_onboarding_state(user.id)
|
||||
|
||||
|
||||
@router.put("/onboarding-state")
|
||||
async def update_user_onboarding_state(
|
||||
request: OnboardingStateUpdate,
|
||||
user: UserModel = Depends(get_user),
|
||||
) -> OnboardingState:
|
||||
return await update_onboarding_state(user.id, request)
|
||||
|
||||
|
||||
@router.get("/configurations/user/validate")
|
||||
async def validate_user_configurations(
|
||||
validity_ttl_seconds: int = Query(default=60, ge=0, le=86400),
|
||||
|
|
|
|||
|
|
@ -34,10 +34,6 @@ class EffectiveAIModelConfiguration(BaseModel):
|
|||
test_phone_number: str | None = None
|
||||
timezone: str | None = None
|
||||
last_validated_at: datetime | None = None
|
||||
# Post-signup onboarding gate: set once the user submits or skips the
|
||||
# onboarding form, so it shows only once per user.
|
||||
onboarding_completed_at: datetime | None = None
|
||||
onboarding_skipped: bool = False
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
|
|
|
|||
47
api/schemas/onboarding_state.py
Normal file
47
api/schemas/onboarding_state.py
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class OnboardingState(BaseModel):
|
||||
"""Per-user onboarding state, stored under UserConfigurationKey.ONBOARDING.
|
||||
|
||||
Server-authoritative replacement for the browser-localStorage onboarding
|
||||
store, so the post-signup gate and one-time tooltips hold across devices.
|
||||
"""
|
||||
|
||||
# Post-signup onboarding form gate: set once on submit/skip.
|
||||
completed_at: datetime | None = None
|
||||
skipped: bool = False
|
||||
# One-time UI affordances (tooltip keys, milestone action keys). Kept as
|
||||
# free-form strings — the UI owns the vocabulary.
|
||||
seen_tooltips: list[str] = Field(default_factory=list)
|
||||
completed_actions: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class OnboardingStateUpdate(BaseModel):
|
||||
"""Partial update merged into the stored state.
|
||||
|
||||
Scalars overwrite when supplied; list entries are unioned into the stored
|
||||
lists, so concurrent updates (e.g. two tabs marking different tooltips)
|
||||
don't drop each other's items.
|
||||
"""
|
||||
|
||||
completed_at: datetime | None = None
|
||||
skipped: bool | None = None
|
||||
seen_tooltips: list[str] | None = None
|
||||
completed_actions: list[str] | None = None
|
||||
|
||||
def apply_to(self, state: OnboardingState) -> OnboardingState:
|
||||
merged = state.model_copy(deep=True)
|
||||
if self.completed_at is not None:
|
||||
merged.completed_at = self.completed_at
|
||||
if self.skipped is not None:
|
||||
merged.skipped = self.skipped
|
||||
for tooltip in self.seen_tooltips or []:
|
||||
if tooltip not in merged.seen_tooltips:
|
||||
merged.seen_tooltips.append(tooltip)
|
||||
for action in self.completed_actions or []:
|
||||
if action not in merged.completed_actions:
|
||||
merged.completed_actions.append(action)
|
||||
return merged
|
||||
|
|
@ -141,10 +141,6 @@ def mask_user_config(config: EffectiveAIModelConfiguration) -> Dict[str, Any]:
|
|||
"is_realtime": config.is_realtime,
|
||||
"test_phone_number": config.test_phone_number,
|
||||
"timezone": config.timezone,
|
||||
# Onboarding gate flags (not secrets) — surfaced so the UI can decide
|
||||
# whether to show the post-signup onboarding form on boot.
|
||||
"onboarding_completed_at": config.onboarding_completed_at,
|
||||
"onboarding_skipped": config.onboarding_skipped,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -113,13 +113,6 @@ def merge_user_configurations(
|
|||
if "timezone" in incoming_partial:
|
||||
merged["timezone"] = incoming_partial["timezone"]
|
||||
|
||||
# Onboarding gate flags: overwrite only when supplied.
|
||||
if "onboarding_completed_at" in incoming_partial:
|
||||
merged["onboarding_completed_at"] = incoming_partial["onboarding_completed_at"]
|
||||
|
||||
if "onboarding_skipped" in incoming_partial:
|
||||
merged["onboarding_skipped"] = incoming_partial["onboarding_skipped"]
|
||||
|
||||
return EffectiveAIModelConfiguration.model_validate(merged)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -257,12 +257,12 @@ SPEACHES_PROVIDER_MODEL_CONFIG = provider_model_config(
|
|||
)
|
||||
AZURE_SPEECH_PROVIDER_MODEL_CONFIG = provider_model_config(
|
||||
"Azure Speech Services",
|
||||
description="Azure Cognitive Services Speech - TTS and STT via the Azure Speech SDK.",
|
||||
description="Azure Cognitive Services Speech — TTS and STT via the Azure Speech SDK.",
|
||||
provider_docs_url="https://learn.microsoft.com/en-us/azure/ai-services/speech-service/",
|
||||
)
|
||||
AZURE_REALTIME_PROVIDER_MODEL_CONFIG = provider_model_config(
|
||||
"Azure OpenAI Realtime",
|
||||
description="Azure OpenAI Realtime API - low-latency speech-to-speech conversations.",
|
||||
description="Azure OpenAI Realtime API — low-latency speech-to-speech conversations.",
|
||||
provider_docs_url="https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/realtime-audio-quickstart",
|
||||
)
|
||||
|
||||
|
|
@ -360,7 +360,7 @@ class GoogleVertexLLMConfiguration(BaseLLMConfiguration):
|
|||
api_key: str | list[str] | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Not used for Vertex AI - authentication is via the service account "
|
||||
"Not used for Vertex AI — authentication is via the service account "
|
||||
"in `credentials` (or ADC). Leave blank."
|
||||
),
|
||||
)
|
||||
|
|
@ -425,7 +425,7 @@ class AWSBedrockLLMConfiguration(BaseLLMConfiguration):
|
|||
provider: Literal[ServiceProviders.AWS_BEDROCK] = ServiceProviders.AWS_BEDROCK
|
||||
model: str = Field(
|
||||
default="us.amazon.nova-pro-v1:0",
|
||||
description="Bedrock model ID - include the region inference-profile prefix (e.g. 'us.').",
|
||||
description="Bedrock model ID — include the region inference-profile prefix (e.g. 'us.').",
|
||||
json_schema_extra={"examples": AWS_BEDROCK_MODELS, "allow_custom_input": True},
|
||||
)
|
||||
aws_access_key: str = Field(
|
||||
|
|
@ -442,7 +442,7 @@ class AWSBedrockLLMConfiguration(BaseLLMConfiguration):
|
|||
)
|
||||
api_key: str | list[str] | None = Field(
|
||||
default=None,
|
||||
description="Not used for Bedrock - authentication is via the AWS credentials above. Leave blank.",
|
||||
description="Not used for Bedrock — authentication is via the AWS credentials above. Leave blank.",
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -682,7 +682,7 @@ class GoogleVertexRealtimeLLMConfiguration(BaseLLMConfiguration):
|
|||
api_key: str | list[str] | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Not used for Vertex AI - authentication is via the service account "
|
||||
"Not used for Vertex AI — authentication is via the service account "
|
||||
"in `credentials` (or ADC). Leave blank."
|
||||
),
|
||||
)
|
||||
|
|
|
|||
37
api/services/user_onboarding.py
Normal file
37
api/services/user_onboarding.py
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
from loguru import logger
|
||||
from pydantic import ValidationError
|
||||
|
||||
from api.db import db_client
|
||||
from api.enums import UserConfigurationKey
|
||||
from api.schemas.onboarding_state import OnboardingState, OnboardingStateUpdate
|
||||
|
||||
|
||||
async def get_onboarding_state(user_id: int) -> OnboardingState:
|
||||
value = await db_client.get_user_configuration_value(
|
||||
user_id, UserConfigurationKey.ONBOARDING.value
|
||||
)
|
||||
return _parse_state(value, user_id)
|
||||
|
||||
|
||||
async def update_onboarding_state(
|
||||
user_id: int, update: OnboardingStateUpdate
|
||||
) -> OnboardingState:
|
||||
state = update.apply_to(await get_onboarding_state(user_id))
|
||||
await db_client.upsert_user_configuration_value(
|
||||
user_id,
|
||||
UserConfigurationKey.ONBOARDING.value,
|
||||
state.model_dump(mode="json", exclude_none=True),
|
||||
)
|
||||
return state
|
||||
|
||||
|
||||
def _parse_state(value, user_id: int) -> OnboardingState:
|
||||
if not value or not isinstance(value, dict):
|
||||
return OnboardingState()
|
||||
try:
|
||||
return OnboardingState.model_validate(value)
|
||||
except ValidationError as exc:
|
||||
logger.warning(
|
||||
f"Invalid onboarding state for user {user_id}: {exc}. Returning defaults."
|
||||
)
|
||||
return OnboardingState()
|
||||
|
|
@ -176,7 +176,7 @@ class _ToolDocumentRefsMixin(BaseModel):
|
|||
@node_spec(
|
||||
name="startCall",
|
||||
display_name="Start Call",
|
||||
description="Entry point of the workflow - plays a greeting and opens the conversation.",
|
||||
description="Entry point of the workflow — plays a greeting and opens the conversation.",
|
||||
llm_hint=(
|
||||
"The entry point of every workflow (exactly one required). Plays an "
|
||||
"optional greeting, can fetch context from an external API before the "
|
||||
|
|
@ -344,7 +344,7 @@ class StartCallNodeData(
|
|||
@node_spec(
|
||||
name="agentNode",
|
||||
display_name="Agent Node",
|
||||
description="Conversational step - the LLM runs one focused exchange.",
|
||||
description="Conversational step — the LLM runs one focused exchange.",
|
||||
llm_hint=(
|
||||
"Mid-call step executed by the LLM. Most workflows are a chain of agent "
|
||||
"nodes connected by edges that describe transition conditions. Each agent "
|
||||
|
|
@ -613,9 +613,9 @@ class GlobalNodeData(BaseNodeData, _PromptedNodeDataMixin):
|
|||
"description": (
|
||||
"Path segment that uniquely identifies "
|
||||
"this trigger. Used in both URLs:\n"
|
||||
" • Production: `/api/v1/public/agent/<trigger_path>` - executes "
|
||||
" • Production: `/api/v1/public/agent/<trigger_path>` — executes "
|
||||
"the published agent.\n"
|
||||
" • Test: `/api/v1/public/agent/test/<trigger_path>` - executes "
|
||||
" • Test: `/api/v1/public/agent/test/<trigger_path>` — executes "
|
||||
"the latest draft.\n"
|
||||
"Can be customized to a descriptive value up to 36 characters "
|
||||
"using letters, numbers, hyphens, or underscores."
|
||||
|
|
@ -708,7 +708,7 @@ class TriggerNodeData(BaseNodeData):
|
|||
"display_name": "Payload Template",
|
||||
"description": (
|
||||
"JSON body of the request. Values are Jinja-rendered against the "
|
||||
"run context - `{{workflow_run_id}}`, `{{gathered_context.foo}}`, "
|
||||
"run context — `{{workflow_run_id}}`, `{{gathered_context.foo}}`, "
|
||||
"`{{annotations.qa_xxx}}`, etc."
|
||||
),
|
||||
"ui_type": PropertyType.json,
|
||||
|
|
|
|||
|
|
@ -229,7 +229,7 @@ class WorkflowGraph:
|
|||
kind=ItemKind.workflow,
|
||||
id=None,
|
||||
field=None,
|
||||
message="Workflow has no start node - exactly one is required",
|
||||
message="Workflow has no start node — exactly one is required",
|
||||
)
|
||||
)
|
||||
elif len(start_nodes) > 1:
|
||||
|
|
@ -239,7 +239,7 @@ class WorkflowGraph:
|
|||
id=None,
|
||||
field=None,
|
||||
message=(
|
||||
f"Workflow has {len(start_nodes)} start nodes - "
|
||||
f"Workflow has {len(start_nodes)} start nodes — "
|
||||
f"exactly one is required"
|
||||
),
|
||||
)
|
||||
|
|
|
|||
131
api/tests/test_onboarding_state.py
Normal file
131
api/tests/test_onboarding_state.py
Normal file
|
|
@ -0,0 +1,131 @@
|
|||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from api.routes.user import router
|
||||
from api.schemas.onboarding_state import OnboardingState, OnboardingStateUpdate
|
||||
from api.services.auth.depends import get_user
|
||||
|
||||
|
||||
def _make_test_app():
|
||||
app = FastAPI()
|
||||
app.include_router(router)
|
||||
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 1
|
||||
mock_user.is_superuser = False
|
||||
mock_user.selected_organization_id = None
|
||||
|
||||
app.dependency_overrides[get_user] = lambda: mock_user
|
||||
return app
|
||||
|
||||
|
||||
class TestOnboardingStateUpdateMerge:
|
||||
def test_lists_union_without_duplicates(self):
|
||||
state = OnboardingState(
|
||||
seen_tooltips=["web_call"], completed_actions=["web_call_started"]
|
||||
)
|
||||
update = OnboardingStateUpdate(
|
||||
seen_tooltips=["web_call", "customize_workflow"],
|
||||
completed_actions=["welcome_form_completed"],
|
||||
)
|
||||
|
||||
merged = update.apply_to(state)
|
||||
|
||||
assert merged.seen_tooltips == ["web_call", "customize_workflow"]
|
||||
assert merged.completed_actions == [
|
||||
"web_call_started",
|
||||
"welcome_form_completed",
|
||||
]
|
||||
|
||||
def test_omitted_fields_preserve_existing_state(self):
|
||||
completed_at = datetime(2026, 6, 12, tzinfo=UTC)
|
||||
state = OnboardingState(
|
||||
completed_at=completed_at, skipped=True, seen_tooltips=["web_call"]
|
||||
)
|
||||
|
||||
merged = OnboardingStateUpdate().apply_to(state)
|
||||
|
||||
assert merged.completed_at == completed_at
|
||||
assert merged.skipped is True
|
||||
assert merged.seen_tooltips == ["web_call"]
|
||||
|
||||
def test_scalars_overwrite_when_supplied(self):
|
||||
state = OnboardingState()
|
||||
completed_at = datetime(2026, 6, 12, tzinfo=UTC)
|
||||
|
||||
merged = OnboardingStateUpdate(
|
||||
completed_at=completed_at, skipped=True
|
||||
).apply_to(state)
|
||||
|
||||
assert merged.completed_at == completed_at
|
||||
assert merged.skipped is True
|
||||
|
||||
|
||||
class TestOnboardingStateRoutes:
|
||||
def test_get_returns_defaults_when_no_row(self):
|
||||
app = _make_test_app()
|
||||
client = TestClient(app)
|
||||
|
||||
with patch(
|
||||
"api.services.user_onboarding.db_client.get_user_configuration_value",
|
||||
new=AsyncMock(return_value=None),
|
||||
):
|
||||
response = client.get("/user/onboarding-state")
|
||||
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["completed_at"] is None
|
||||
assert body["skipped"] is False
|
||||
assert body["seen_tooltips"] == []
|
||||
assert body["completed_actions"] == []
|
||||
|
||||
def test_get_returns_defaults_on_invalid_stored_value(self):
|
||||
app = _make_test_app()
|
||||
client = TestClient(app)
|
||||
|
||||
with patch(
|
||||
"api.services.user_onboarding.db_client.get_user_configuration_value",
|
||||
new=AsyncMock(return_value={"skipped": "not-a-bool"}),
|
||||
):
|
||||
response = client.get("/user/onboarding-state")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["skipped"] is False
|
||||
|
||||
def test_put_merges_into_stored_state_and_persists(self):
|
||||
app = _make_test_app()
|
||||
client = TestClient(app)
|
||||
|
||||
existing = {"seen_tooltips": ["web_call"]}
|
||||
upsert = AsyncMock(side_effect=lambda user_id, key, value: value)
|
||||
with (
|
||||
patch(
|
||||
"api.services.user_onboarding.db_client.get_user_configuration_value",
|
||||
new=AsyncMock(return_value=existing),
|
||||
),
|
||||
patch(
|
||||
"api.services.user_onboarding.db_client.upsert_user_configuration_value",
|
||||
new=upsert,
|
||||
),
|
||||
):
|
||||
response = client.put(
|
||||
"/user/onboarding-state",
|
||||
json={
|
||||
"completed_at": "2026-06-12T00:00:00Z",
|
||||
"seen_tooltips": ["customize_workflow"],
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["seen_tooltips"] == ["web_call", "customize_workflow"]
|
||||
assert body["completed_at"] is not None
|
||||
|
||||
upsert.assert_awaited_once()
|
||||
user_id, key, stored = upsert.await_args.args
|
||||
assert user_id == 1
|
||||
assert key == "ONBOARDING"
|
||||
assert stored["seen_tooltips"] == ["web_call", "customize_workflow"]
|
||||
Loading…
Add table
Add a link
Reference in a new issue