mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-10 08:05:22 +02:00
chore: centralize user org selection
This commit is contained in:
parent
94686b73c4
commit
8b3f060012
4 changed files with 68 additions and 37 deletions
|
|
@ -34,7 +34,7 @@ from api.schemas.telephony_phone_number import (
|
|||
PhoneNumberUpdateRequest,
|
||||
ProviderSyncStatus,
|
||||
)
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.auth.depends import get_user, get_user_with_selected_organization
|
||||
from api.services.configuration.ai_model_configuration import (
|
||||
check_for_masked_keys_in_ai_model_configuration_v2,
|
||||
compile_ai_model_configuration_v2,
|
||||
|
|
@ -193,12 +193,6 @@ async def get_telephony_config_warnings(user: UserModel = Depends(get_user)):
|
|||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _require_selected_organization(user: UserModel) -> int:
|
||||
if not user.selected_organization_id:
|
||||
raise HTTPException(status_code=400, detail="No organization selected")
|
||||
return user.selected_organization_id
|
||||
|
||||
|
||||
def _byok_provider_schemas(service_type: ServiceType) -> dict[str, dict]:
|
||||
return {
|
||||
provider: model_cls.model_json_schema()
|
||||
|
|
@ -231,8 +225,9 @@ async def _model_configuration_v2_response(
|
|||
|
||||
|
||||
@router.get("/model-configurations/v2/defaults")
|
||||
async def get_model_configuration_v2_defaults(user: UserModel = Depends(get_user)):
|
||||
_require_selected_organization(user)
|
||||
async def get_model_configuration_v2_defaults(
|
||||
user: UserModel = Depends(get_user_with_selected_organization),
|
||||
):
|
||||
byok_default_providers = {
|
||||
service: provider
|
||||
for service, provider in DEFAULT_SERVICE_PROVIDERS.items()
|
||||
|
|
@ -271,8 +266,9 @@ async def get_model_configuration_v2_defaults(user: UserModel = Depends(get_user
|
|||
"/model-configurations/v2",
|
||||
response_model=OrganizationAIModelConfigurationResponse,
|
||||
)
|
||||
async def get_model_configuration_v2(user: UserModel = Depends(get_user)):
|
||||
_require_selected_organization(user)
|
||||
async def get_model_configuration_v2(
|
||||
user: UserModel = Depends(get_user_with_selected_organization),
|
||||
):
|
||||
return await _model_configuration_v2_response(user=user)
|
||||
|
||||
|
||||
|
|
@ -282,9 +278,9 @@ async def get_model_configuration_v2(user: UserModel = Depends(get_user)):
|
|||
)
|
||||
async def save_model_configuration_v2(
|
||||
request: OrganizationAIModelConfigurationV2,
|
||||
user: UserModel = Depends(get_user),
|
||||
user: UserModel = Depends(get_user_with_selected_organization),
|
||||
):
|
||||
organization_id = _require_selected_organization(user)
|
||||
organization_id = user.selected_organization_id
|
||||
existing = await get_organization_ai_model_configuration_v2(organization_id)
|
||||
configuration = merge_ai_model_configuration_v2_secrets(request, existing)
|
||||
try:
|
||||
|
|
@ -309,8 +305,9 @@ async def save_model_configuration_v2(
|
|||
|
||||
|
||||
@router.get("/model-configurations/v2/migration-preview")
|
||||
async def preview_model_configuration_v2_migration(user: UserModel = Depends(get_user)):
|
||||
_require_selected_organization(user)
|
||||
async def preview_model_configuration_v2_migration(
|
||||
user: UserModel = Depends(get_user_with_selected_organization),
|
||||
):
|
||||
legacy = await db_client.get_user_configurations(user.id)
|
||||
try:
|
||||
configuration = convert_legacy_ai_model_configuration_to_v2(legacy)
|
||||
|
|
@ -330,9 +327,9 @@ async def preview_model_configuration_v2_migration(user: UserModel = Depends(get
|
|||
)
|
||||
async def migrate_model_configuration_v2(
|
||||
force: bool = Query(default=False),
|
||||
user: UserModel = Depends(get_user),
|
||||
user: UserModel = Depends(get_user_with_selected_organization),
|
||||
):
|
||||
organization_id = _require_selected_organization(user)
|
||||
organization_id = user.selected_organization_id
|
||||
existing = await get_organization_ai_model_configuration_v2(organization_id)
|
||||
if existing is not None and not force:
|
||||
raise HTTPException(
|
||||
|
|
@ -370,8 +367,10 @@ async def migrate_model_configuration_v2(
|
|||
"/model-configurations/preferences",
|
||||
response_model=OrganizationAIModelConfigurationPreferences,
|
||||
)
|
||||
async def get_model_configuration_preferences(user: UserModel = Depends(get_user)):
|
||||
organization_id = _require_selected_organization(user)
|
||||
async def get_model_configuration_preferences(
|
||||
user: UserModel = Depends(get_user_with_selected_organization),
|
||||
):
|
||||
organization_id = user.selected_organization_id
|
||||
return await get_organization_ai_model_configuration_preferences(organization_id)
|
||||
|
||||
|
||||
|
|
@ -381,9 +380,9 @@ async def get_model_configuration_preferences(user: UserModel = Depends(get_user
|
|||
)
|
||||
async def save_model_configuration_preferences(
|
||||
request: OrganizationAIModelConfigurationPreferences,
|
||||
user: UserModel = Depends(get_user),
|
||||
user: UserModel = Depends(get_user_with_selected_organization),
|
||||
):
|
||||
organization_id = _require_selected_organization(user)
|
||||
organization_id = user.selected_organization_id
|
||||
return await upsert_organization_ai_model_configuration_preferences(
|
||||
organization_id,
|
||||
request,
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from pydantic import BaseModel, Field
|
|||
from api.db import db_client
|
||||
from api.db.models import UserModel, WorkflowRunTextSessionModel
|
||||
from api.enums import WorkflowRunMode
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.auth.depends import get_user_with_selected_organization
|
||||
from api.services.quota_service import check_dograh_quota
|
||||
from api.services.workflow.text_chat_session_service import (
|
||||
TextChatPendingTurnLostError,
|
||||
|
|
@ -96,12 +96,6 @@ def _revision_conflict_detail(e: Any) -> dict[str, Any]:
|
|||
}
|
||||
|
||||
|
||||
def _require_selected_organization_id(user: UserModel) -> int:
|
||||
if user.selected_organization_id is None:
|
||||
raise HTTPException(status_code=403, detail="Organization context is required")
|
||||
return user.selected_organization_id
|
||||
|
||||
|
||||
async def _ensure_text_chat_quota(user: UserModel, workflow_id: int) -> None:
|
||||
quota_result = await check_dograh_quota(user, workflow_id=workflow_id)
|
||||
if not quota_result.has_quota:
|
||||
|
|
@ -114,9 +108,8 @@ async def _load_text_session_or_404(
|
|||
user: UserModel,
|
||||
) -> WorkflowRunTextSessionModel:
|
||||
set_current_run_id(run_id)
|
||||
organization_id = _require_selected_organization_id(user)
|
||||
text_session = await db_client.get_workflow_run_text_session(
|
||||
run_id, organization_id=organization_id
|
||||
run_id, organization_id=user.selected_organization_id
|
||||
)
|
||||
if not text_session or not text_session.workflow_run:
|
||||
raise HTTPException(status_code=404, detail="Text chat session not found")
|
||||
|
|
@ -158,9 +151,8 @@ async def _execute_pending_turn_response(
|
|||
async def create_text_chat_session(
|
||||
workflow_id: int,
|
||||
request: CreateTextChatSessionRequest,
|
||||
user: UserModel = Depends(get_user),
|
||||
user: UserModel = Depends(get_user_with_selected_organization),
|
||||
) -> WorkflowRunTextSessionResponse:
|
||||
organization_id = _require_selected_organization_id(user)
|
||||
await _ensure_text_chat_quota(user, workflow_id)
|
||||
|
||||
session_name = request.name or f"WR-TEXT-{uuid4().hex[:6].upper()}"
|
||||
|
|
@ -172,7 +164,7 @@ async def create_text_chat_session(
|
|||
user_id=user.id,
|
||||
initial_context=request.initial_context,
|
||||
use_draft=True,
|
||||
organization_id=organization_id,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
|
@ -220,7 +212,7 @@ async def create_text_chat_session(
|
|||
async def get_text_chat_session(
|
||||
workflow_id: int,
|
||||
run_id: int,
|
||||
user: UserModel = Depends(get_user),
|
||||
user: UserModel = Depends(get_user_with_selected_organization),
|
||||
) -> WorkflowRunTextSessionResponse:
|
||||
text_session = await _load_text_session_or_404(workflow_id, run_id, user)
|
||||
return _build_response(text_session)
|
||||
|
|
@ -234,7 +226,7 @@ async def append_text_chat_message(
|
|||
workflow_id: int,
|
||||
run_id: int,
|
||||
request: AppendTextChatMessageRequest,
|
||||
user: UserModel = Depends(get_user),
|
||||
user: UserModel = Depends(get_user_with_selected_organization),
|
||||
) -> WorkflowRunTextSessionResponse:
|
||||
text_session = await _load_text_session_or_404(workflow_id, run_id, user)
|
||||
await _ensure_text_chat_quota(user, workflow_id)
|
||||
|
|
@ -264,7 +256,7 @@ async def rewind_text_chat_session(
|
|||
workflow_id: int,
|
||||
run_id: int,
|
||||
request: RewindTextChatSessionRequest,
|
||||
user: UserModel = Depends(get_user),
|
||||
user: UserModel = Depends(get_user_with_selected_organization),
|
||||
) -> WorkflowRunTextSessionResponse:
|
||||
text_session = await _load_text_session_or_404(workflow_id, run_id, user)
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from typing import Annotated, Optional
|
||||
|
||||
import httpx
|
||||
from fastapi import Header, HTTPException, Query, WebSocket
|
||||
from fastapi import Depends, Header, HTTPException, Query, WebSocket
|
||||
from loguru import logger
|
||||
from pydantic import ValidationError
|
||||
|
||||
|
|
@ -142,6 +142,14 @@ async def get_user(
|
|||
return user_model
|
||||
|
||||
|
||||
async def get_user_with_selected_organization(
|
||||
user: Annotated[UserModel, Depends(get_user)],
|
||||
) -> UserModel:
|
||||
if not user.selected_organization_id:
|
||||
raise HTTPException(status_code=400, detail="No organization selected")
|
||||
return user
|
||||
|
||||
|
||||
async def _handle_oss_auth(authorization: str | None) -> UserModel:
|
||||
"""
|
||||
Handle authentication for OSS deployment mode.
|
||||
|
|
|
|||
|
|
@ -51,6 +51,38 @@ async def _create_user_and_workflow(
|
|||
return user, workflow
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_chat_session_creation_requires_selected_organization():
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from api.app import app
|
||||
from api.services.auth.depends import get_user
|
||||
|
||||
user = UserModel(provider_id="textchat-user-no-selected-org")
|
||||
|
||||
async def mock_get_user():
|
||||
return user
|
||||
|
||||
original_override = app.dependency_overrides.get(get_user)
|
||||
app.dependency_overrides[get_user] = mock_get_user
|
||||
|
||||
try:
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app=app), base_url="http://test"
|
||||
) as client:
|
||||
response = await client.post(
|
||||
"/api/v1/workflow/123/text-chat/sessions", json={}
|
||||
)
|
||||
finally:
|
||||
if original_override:
|
||||
app.dependency_overrides[get_user] = original_override
|
||||
else:
|
||||
app.dependency_overrides.pop(get_user, None)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.json() == {"detail": "No organization selected"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_chat_session_creation_executes_initial_assistant_turn(
|
||||
db_session,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue