mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-19 08:28:10 +02:00
Merge branch 'main' of https://github.com/dograh-hq/dograh
This commit is contained in:
commit
9a1b980f91
81 changed files with 3817 additions and 602 deletions
|
|
@ -0,0 +1,52 @@
|
|||
"""add key to user_configurations
|
||||
|
||||
Turns user_configurations into a per-user keyed JSON store mirroring
|
||||
organization_configurations. Existing rows (the legacy v1 AI model
|
||||
configuration blob) are backfilled with key MODEL_CONFIGURATION.
|
||||
|
||||
Revision ID: 91cc6ba3e1c7
|
||||
Revises: efe356f488f9
|
||||
Create Date: 2026-06-12 21:04:25.561529
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "91cc6ba3e1c7"
|
||||
down_revision: Union[str, None] = "efe356f488f9"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Backfill existing rows (all legacy model-config blobs) via the server
|
||||
# default, then drop the default — application code always supplies key.
|
||||
op.add_column(
|
||||
"user_configurations",
|
||||
sa.Column(
|
||||
"key",
|
||||
sa.String(),
|
||||
nullable=False,
|
||||
server_default="MODEL_CONFIGURATION",
|
||||
),
|
||||
)
|
||||
|
||||
op.create_unique_constraint(
|
||||
"_user_configuration_key_uc", "user_configurations", ["user_id", "key"]
|
||||
)
|
||||
op.alter_column("user_configurations", "key", server_default=None)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_constraint(
|
||||
"_user_configuration_key_uc", "user_configurations", type_="unique"
|
||||
)
|
||||
# Non-model-config rows (e.g. ONBOARDING) have no meaning in the old
|
||||
# single-blob schema; the old code would read them as the user's model
|
||||
# config, so they must not survive the downgrade.
|
||||
op.execute("DELETE FROM user_configurations WHERE key != 'MODEL_CONFIGURATION'")
|
||||
op.drop_column("user_configurations", "key")
|
||||
|
|
@ -82,12 +82,24 @@ class UserModel(Base):
|
|||
|
||||
|
||||
class UserConfigurationModel(Base):
|
||||
"""Per-user keyed JSON store, mirroring organization_configurations.
|
||||
|
||||
Keys are defined in UserConfigurationKey. The legacy v1 AI model
|
||||
configuration lives under MODEL_CONFIGURATION; last_validated_at is only
|
||||
meaningful for that key.
|
||||
"""
|
||||
|
||||
__tablename__ = "user_configurations"
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
user_id = Column(Integer, ForeignKey("users.id"), nullable=True)
|
||||
key = Column(String, nullable=False)
|
||||
configuration = Column(JSON, nullable=False, default=dict)
|
||||
last_validated_at = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint("user_id", "key", name="_user_configuration_key_uc"),
|
||||
)
|
||||
|
||||
|
||||
# New Organization model
|
||||
class OrganizationModel(Base):
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ from api.db.models import (
|
|||
WorkflowModel,
|
||||
WorkflowRunModel,
|
||||
)
|
||||
from api.enums import OrganizationConfigurationKey
|
||||
from api.enums import OrganizationConfigurationKey, UserConfigurationKey
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
from api.utils.recording_artifacts import get_recording_storage_key
|
||||
|
||||
|
|
@ -347,7 +347,9 @@ class OrganizationUsageClient(BaseDBClient):
|
|||
if user_id:
|
||||
config_result = await session.execute(
|
||||
select(UserConfigurationModel).where(
|
||||
UserConfigurationModel.user_id == user_id
|
||||
UserConfigurationModel.user_id == user_id,
|
||||
UserConfigurationModel.key
|
||||
== UserConfigurationKey.MODEL_CONFIGURATION.value,
|
||||
)
|
||||
)
|
||||
config_obj = config_result.scalar_one_or_none()
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from sqlalchemy.future import select
|
|||
|
||||
from api.db.base_client import BaseDBClient
|
||||
from api.db.models import UserConfigurationModel, UserModel
|
||||
from api.enums import UserConfigurationKey
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
|
||||
|
||||
|
|
@ -65,16 +66,51 @@ class UserClient(BaseDBClient):
|
|||
)
|
||||
return result.scalars().first()
|
||||
|
||||
async def _get_user_configuration_row(
|
||||
self, session, user_id: int, key: str
|
||||
) -> UserConfigurationModel | None:
|
||||
result = await session.execute(
|
||||
select(UserConfigurationModel).where(
|
||||
UserConfigurationModel.user_id == user_id,
|
||||
UserConfigurationModel.key == key,
|
||||
)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
async def get_user_configuration_value(self, user_id: int, key: str) -> dict | None:
|
||||
"""Get the JSON value stored for a user under `key`, or None."""
|
||||
async with self.async_session() as session:
|
||||
row = await self._get_user_configuration_row(session, user_id, key)
|
||||
return row.configuration if row else None
|
||||
|
||||
async def upsert_user_configuration_value(
|
||||
self, user_id: int, key: str, value: dict
|
||||
) -> dict:
|
||||
"""Create or update the JSON value stored for a user under `key`."""
|
||||
async with self.async_session() as session:
|
||||
row = await self._get_user_configuration_row(session, user_id, key)
|
||||
if row:
|
||||
row.configuration = value
|
||||
else:
|
||||
row = UserConfigurationModel(
|
||||
user_id=user_id, key=key, configuration=value
|
||||
)
|
||||
session.add(row)
|
||||
try:
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
await session.refresh(row)
|
||||
return row.configuration
|
||||
|
||||
async def get_user_configurations(
|
||||
self, user_id: int
|
||||
) -> EffectiveAIModelConfiguration:
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(UserConfigurationModel).where(
|
||||
UserConfigurationModel.user_id == user_id
|
||||
)
|
||||
configuration_obj = await self._get_user_configuration_row(
|
||||
session, user_id, UserConfigurationKey.MODEL_CONFIGURATION.value
|
||||
)
|
||||
configuration_obj = result.scalars().first()
|
||||
if not configuration_obj:
|
||||
return EffectiveAIModelConfiguration()
|
||||
|
||||
|
|
@ -97,38 +133,18 @@ class UserClient(BaseDBClient):
|
|||
async def update_user_configuration(
|
||||
self, user_id: int, configuration: EffectiveAIModelConfiguration
|
||||
) -> EffectiveAIModelConfiguration:
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(UserConfigurationModel).where(
|
||||
UserConfigurationModel.user_id == user_id
|
||||
)
|
||||
)
|
||||
configuration_obj = result.scalars().first()
|
||||
if not configuration_obj:
|
||||
configuration_obj = UserConfigurationModel(
|
||||
user_id=user_id, configuration=configuration.model_dump()
|
||||
)
|
||||
session.add(configuration_obj)
|
||||
else:
|
||||
configuration_obj.configuration = configuration.model_dump()
|
||||
try:
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
await session.refresh(configuration_obj)
|
||||
return EffectiveAIModelConfiguration.model_validate(
|
||||
configuration_obj.configuration
|
||||
value = await self.upsert_user_configuration_value(
|
||||
user_id,
|
||||
UserConfigurationKey.MODEL_CONFIGURATION.value,
|
||||
configuration.model_dump(),
|
||||
)
|
||||
return EffectiveAIModelConfiguration.model_validate(value)
|
||||
|
||||
async def update_user_configuration_last_validated_at(self, user_id: int) -> None:
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(UserConfigurationModel).where(
|
||||
UserConfigurationModel.user_id == user_id
|
||||
)
|
||||
configuration_obj = await self._get_user_configuration_row(
|
||||
session, user_id, UserConfigurationKey.MODEL_CONFIGURATION.value
|
||||
)
|
||||
configuration_obj = result.scalars().first()
|
||||
if not configuration_obj:
|
||||
raise ValueError(f"User configuration with ID {user_id} not found")
|
||||
configuration_obj.last_validated_at = datetime.now()
|
||||
|
|
|
|||
|
|
@ -96,6 +96,15 @@ class OrganizationConfigurationKey(Enum):
|
|||
MODEL_CONFIGURATION_PREFERENCES = "MODEL_CONFIGURATION_PREFERENCES" # Deprecated; read fallback for old org preferences
|
||||
|
||||
|
||||
class UserConfigurationKey(Enum):
|
||||
"""Keys for the per-user keyed JSON store (user_configurations)."""
|
||||
|
||||
MODEL_CONFIGURATION = (
|
||||
"MODEL_CONFIGURATION" # Legacy per-user v1 AI model configuration
|
||||
)
|
||||
ONBOARDING = "ONBOARDING" # Post-signup onboarding state (gate, tooltips, actions)
|
||||
|
||||
|
||||
class WorkflowStatus(Enum):
|
||||
"""Workflow status values"""
|
||||
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from api.db import db_client
|
|||
from api.db.models import (
|
||||
UserModel,
|
||||
)
|
||||
from api.schemas.onboarding_state import OnboardingState, OnboardingStateUpdate
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
get_resolved_ai_model_configuration,
|
||||
|
|
@ -26,6 +27,10 @@ from api.services.organization_preferences import (
|
|||
get_organization_preferences,
|
||||
upsert_organization_preferences,
|
||||
)
|
||||
from api.services.user_onboarding import (
|
||||
get_onboarding_state,
|
||||
update_onboarding_state,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/user")
|
||||
|
||||
|
|
@ -203,6 +208,21 @@ async def update_user_configurations(
|
|||
return masked_config
|
||||
|
||||
|
||||
@router.get("/onboarding-state")
|
||||
async def get_user_onboarding_state(
|
||||
user: UserModel = Depends(get_user),
|
||||
) -> OnboardingState:
|
||||
return await get_onboarding_state(user.id)
|
||||
|
||||
|
||||
@router.put("/onboarding-state")
|
||||
async def update_user_onboarding_state(
|
||||
request: OnboardingStateUpdate,
|
||||
user: UserModel = Depends(get_user),
|
||||
) -> OnboardingState:
|
||||
return await update_onboarding_state(user.id, request)
|
||||
|
||||
|
||||
@router.get("/configurations/user/validate")
|
||||
async def validate_user_configurations(
|
||||
validity_ttl_seconds: int = Query(default=60, ge=0, le=86400),
|
||||
|
|
|
|||
47
api/schemas/onboarding_state.py
Normal file
47
api/schemas/onboarding_state.py
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class OnboardingState(BaseModel):
|
||||
"""Per-user onboarding state, stored under UserConfigurationKey.ONBOARDING.
|
||||
|
||||
Server-authoritative replacement for the browser-localStorage onboarding
|
||||
store, so the post-signup gate and one-time tooltips hold across devices.
|
||||
"""
|
||||
|
||||
# Post-signup onboarding form gate: set once on submit/skip.
|
||||
completed_at: datetime | None = None
|
||||
skipped: bool = False
|
||||
# One-time UI affordances (tooltip keys, milestone action keys). Kept as
|
||||
# free-form strings — the UI owns the vocabulary.
|
||||
seen_tooltips: list[str] = Field(default_factory=list)
|
||||
completed_actions: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class OnboardingStateUpdate(BaseModel):
|
||||
"""Partial update merged into the stored state.
|
||||
|
||||
Scalars overwrite when supplied; list entries are unioned into the stored
|
||||
lists, so concurrent updates (e.g. two tabs marking different tooltips)
|
||||
don't drop each other's items.
|
||||
"""
|
||||
|
||||
completed_at: datetime | None = None
|
||||
skipped: bool | None = None
|
||||
seen_tooltips: list[str] | None = None
|
||||
completed_actions: list[str] | None = None
|
||||
|
||||
def apply_to(self, state: OnboardingState) -> OnboardingState:
|
||||
merged = state.model_copy(deep=True)
|
||||
if self.completed_at is not None:
|
||||
merged.completed_at = self.completed_at
|
||||
if self.skipped is not None:
|
||||
merged.skipped = self.skipped
|
||||
for tooltip in self.seen_tooltips or []:
|
||||
if tooltip not in merged.seen_tooltips:
|
||||
merged.seen_tooltips.append(tooltip)
|
||||
for action in self.completed_actions or []:
|
||||
if action not in merged.completed_actions:
|
||||
merged.completed_actions.append(action)
|
||||
return merged
|
||||
37
api/services/user_onboarding.py
Normal file
37
api/services/user_onboarding.py
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
from loguru import logger
|
||||
from pydantic import ValidationError
|
||||
|
||||
from api.db import db_client
|
||||
from api.enums import UserConfigurationKey
|
||||
from api.schemas.onboarding_state import OnboardingState, OnboardingStateUpdate
|
||||
|
||||
|
||||
async def get_onboarding_state(user_id: int) -> OnboardingState:
|
||||
value = await db_client.get_user_configuration_value(
|
||||
user_id, UserConfigurationKey.ONBOARDING.value
|
||||
)
|
||||
return _parse_state(value, user_id)
|
||||
|
||||
|
||||
async def update_onboarding_state(
|
||||
user_id: int, update: OnboardingStateUpdate
|
||||
) -> OnboardingState:
|
||||
state = update.apply_to(await get_onboarding_state(user_id))
|
||||
await db_client.upsert_user_configuration_value(
|
||||
user_id,
|
||||
UserConfigurationKey.ONBOARDING.value,
|
||||
state.model_dump(mode="json", exclude_none=True),
|
||||
)
|
||||
return state
|
||||
|
||||
|
||||
def _parse_state(value, user_id: int) -> OnboardingState:
|
||||
if not value or not isinstance(value, dict):
|
||||
return OnboardingState()
|
||||
try:
|
||||
return OnboardingState.model_validate(value)
|
||||
except ValidationError as exc:
|
||||
logger.warning(
|
||||
f"Invalid onboarding state for user {user_id}: {exc}. Returning defaults."
|
||||
)
|
||||
return OnboardingState()
|
||||
131
api/tests/test_onboarding_state.py
Normal file
131
api/tests/test_onboarding_state.py
Normal file
|
|
@ -0,0 +1,131 @@
|
|||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from api.routes.user import router
|
||||
from api.schemas.onboarding_state import OnboardingState, OnboardingStateUpdate
|
||||
from api.services.auth.depends import get_user
|
||||
|
||||
|
||||
def _make_test_app():
|
||||
app = FastAPI()
|
||||
app.include_router(router)
|
||||
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 1
|
||||
mock_user.is_superuser = False
|
||||
mock_user.selected_organization_id = None
|
||||
|
||||
app.dependency_overrides[get_user] = lambda: mock_user
|
||||
return app
|
||||
|
||||
|
||||
class TestOnboardingStateUpdateMerge:
|
||||
def test_lists_union_without_duplicates(self):
|
||||
state = OnboardingState(
|
||||
seen_tooltips=["web_call"], completed_actions=["web_call_started"]
|
||||
)
|
||||
update = OnboardingStateUpdate(
|
||||
seen_tooltips=["web_call", "customize_workflow"],
|
||||
completed_actions=["welcome_form_completed"],
|
||||
)
|
||||
|
||||
merged = update.apply_to(state)
|
||||
|
||||
assert merged.seen_tooltips == ["web_call", "customize_workflow"]
|
||||
assert merged.completed_actions == [
|
||||
"web_call_started",
|
||||
"welcome_form_completed",
|
||||
]
|
||||
|
||||
def test_omitted_fields_preserve_existing_state(self):
|
||||
completed_at = datetime(2026, 6, 12, tzinfo=UTC)
|
||||
state = OnboardingState(
|
||||
completed_at=completed_at, skipped=True, seen_tooltips=["web_call"]
|
||||
)
|
||||
|
||||
merged = OnboardingStateUpdate().apply_to(state)
|
||||
|
||||
assert merged.completed_at == completed_at
|
||||
assert merged.skipped is True
|
||||
assert merged.seen_tooltips == ["web_call"]
|
||||
|
||||
def test_scalars_overwrite_when_supplied(self):
|
||||
state = OnboardingState()
|
||||
completed_at = datetime(2026, 6, 12, tzinfo=UTC)
|
||||
|
||||
merged = OnboardingStateUpdate(
|
||||
completed_at=completed_at, skipped=True
|
||||
).apply_to(state)
|
||||
|
||||
assert merged.completed_at == completed_at
|
||||
assert merged.skipped is True
|
||||
|
||||
|
||||
class TestOnboardingStateRoutes:
|
||||
def test_get_returns_defaults_when_no_row(self):
|
||||
app = _make_test_app()
|
||||
client = TestClient(app)
|
||||
|
||||
with patch(
|
||||
"api.services.user_onboarding.db_client.get_user_configuration_value",
|
||||
new=AsyncMock(return_value=None),
|
||||
):
|
||||
response = client.get("/user/onboarding-state")
|
||||
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["completed_at"] is None
|
||||
assert body["skipped"] is False
|
||||
assert body["seen_tooltips"] == []
|
||||
assert body["completed_actions"] == []
|
||||
|
||||
def test_get_returns_defaults_on_invalid_stored_value(self):
|
||||
app = _make_test_app()
|
||||
client = TestClient(app)
|
||||
|
||||
with patch(
|
||||
"api.services.user_onboarding.db_client.get_user_configuration_value",
|
||||
new=AsyncMock(return_value={"skipped": "not-a-bool"}),
|
||||
):
|
||||
response = client.get("/user/onboarding-state")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["skipped"] is False
|
||||
|
||||
def test_put_merges_into_stored_state_and_persists(self):
|
||||
app = _make_test_app()
|
||||
client = TestClient(app)
|
||||
|
||||
existing = {"seen_tooltips": ["web_call"]}
|
||||
upsert = AsyncMock(side_effect=lambda user_id, key, value: value)
|
||||
with (
|
||||
patch(
|
||||
"api.services.user_onboarding.db_client.get_user_configuration_value",
|
||||
new=AsyncMock(return_value=existing),
|
||||
),
|
||||
patch(
|
||||
"api.services.user_onboarding.db_client.upsert_user_configuration_value",
|
||||
new=upsert,
|
||||
),
|
||||
):
|
||||
response = client.put(
|
||||
"/user/onboarding-state",
|
||||
json={
|
||||
"completed_at": "2026-06-12T00:00:00Z",
|
||||
"seen_tooltips": ["customize_workflow"],
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["seen_tooltips"] == ["web_call", "customize_workflow"]
|
||||
assert body["completed_at"] is not None
|
||||
|
||||
upsert.assert_awaited_once()
|
||||
user_id, key, stored = upsert.await_args.args
|
||||
assert user_id == 1
|
||||
assert key == "ONBOARDING"
|
||||
assert stored["seen_tooltips"] == ["web_call", "customize_workflow"]
|
||||
Loading…
Add table
Add a link
Reference in a new issue