diff --git a/api/routes/organization.py b/api/routes/organization.py index 4006045..c132649 100644 --- a/api/routes/organization.py +++ b/api/routes/organization.py @@ -34,7 +34,7 @@ from api.schemas.telephony_phone_number import ( PhoneNumberUpdateRequest, ProviderSyncStatus, ) -from api.services.auth.depends import get_user +from api.services.auth.depends import get_user, get_user_with_selected_organization from api.services.configuration.ai_model_configuration import ( check_for_masked_keys_in_ai_model_configuration_v2, compile_ai_model_configuration_v2, @@ -193,12 +193,6 @@ async def get_telephony_config_warnings(user: UserModel = Depends(get_user)): # --------------------------------------------------------------------------- -def _require_selected_organization(user: UserModel) -> int: - if not user.selected_organization_id: - raise HTTPException(status_code=400, detail="No organization selected") - return user.selected_organization_id - - def _byok_provider_schemas(service_type: ServiceType) -> dict[str, dict]: return { provider: model_cls.model_json_schema() @@ -231,8 +225,9 @@ async def _model_configuration_v2_response( @router.get("/model-configurations/v2/defaults") -async def get_model_configuration_v2_defaults(user: UserModel = Depends(get_user)): - _require_selected_organization(user) +async def get_model_configuration_v2_defaults( + user: UserModel = Depends(get_user_with_selected_organization), +): byok_default_providers = { service: provider for service, provider in DEFAULT_SERVICE_PROVIDERS.items() @@ -271,8 +266,9 @@ async def get_model_configuration_v2_defaults(user: UserModel = Depends(get_user "/model-configurations/v2", response_model=OrganizationAIModelConfigurationResponse, ) -async def get_model_configuration_v2(user: UserModel = Depends(get_user)): - _require_selected_organization(user) +async def get_model_configuration_v2( + user: UserModel = Depends(get_user_with_selected_organization), +): return await _model_configuration_v2_response(user=user) @@ -282,9 +278,9 @@ async def get_model_configuration_v2(user: UserModel = Depends(get_user)): ) async def save_model_configuration_v2( request: OrganizationAIModelConfigurationV2, - user: UserModel = Depends(get_user), + user: UserModel = Depends(get_user_with_selected_organization), ): - organization_id = _require_selected_organization(user) + organization_id = user.selected_organization_id existing = await get_organization_ai_model_configuration_v2(organization_id) configuration = merge_ai_model_configuration_v2_secrets(request, existing) try: @@ -309,8 +305,9 @@ async def save_model_configuration_v2( @router.get("/model-configurations/v2/migration-preview") -async def preview_model_configuration_v2_migration(user: UserModel = Depends(get_user)): - _require_selected_organization(user) +async def preview_model_configuration_v2_migration( + user: UserModel = Depends(get_user_with_selected_organization), +): legacy = await db_client.get_user_configurations(user.id) try: configuration = convert_legacy_ai_model_configuration_to_v2(legacy) @@ -330,9 +327,9 @@ async def preview_model_configuration_v2_migration(user: UserModel = Depends(get ) async def migrate_model_configuration_v2( force: bool = Query(default=False), - user: UserModel = Depends(get_user), + user: UserModel = Depends(get_user_with_selected_organization), ): - organization_id = _require_selected_organization(user) + organization_id = user.selected_organization_id existing = await get_organization_ai_model_configuration_v2(organization_id) if existing is not None and not force: raise HTTPException( @@ -370,8 +367,10 @@ async def migrate_model_configuration_v2( "/model-configurations/preferences", response_model=OrganizationAIModelConfigurationPreferences, ) -async def get_model_configuration_preferences(user: UserModel = Depends(get_user)): - organization_id = _require_selected_organization(user) +async def get_model_configuration_preferences( + user: UserModel = Depends(get_user_with_selected_organization), +): + organization_id = user.selected_organization_id return await get_organization_ai_model_configuration_preferences(organization_id) @@ -381,9 +380,9 @@ async def get_model_configuration_preferences(user: UserModel = Depends(get_user ) async def save_model_configuration_preferences( request: OrganizationAIModelConfigurationPreferences, - user: UserModel = Depends(get_user), + user: UserModel = Depends(get_user_with_selected_organization), ): - organization_id = _require_selected_organization(user) + organization_id = user.selected_organization_id return await upsert_organization_ai_model_configuration_preferences( organization_id, request, diff --git a/api/routes/workflow_text_chat.py b/api/routes/workflow_text_chat.py index 71d1b90..b465011 100644 --- a/api/routes/workflow_text_chat.py +++ b/api/routes/workflow_text_chat.py @@ -9,7 +9,7 @@ from pydantic import BaseModel, Field from api.db import db_client from api.db.models import UserModel, WorkflowRunTextSessionModel from api.enums import WorkflowRunMode -from api.services.auth.depends import get_user +from api.services.auth.depends import get_user_with_selected_organization from api.services.quota_service import check_dograh_quota from api.services.workflow.text_chat_session_service import ( TextChatPendingTurnLostError, @@ -96,12 +96,6 @@ def _revision_conflict_detail(e: Any) -> dict[str, Any]: } -def _require_selected_organization_id(user: UserModel) -> int: - if user.selected_organization_id is None: - raise HTTPException(status_code=403, detail="Organization context is required") - return user.selected_organization_id - - async def _ensure_text_chat_quota(user: UserModel, workflow_id: int) -> None: quota_result = await check_dograh_quota(user, workflow_id=workflow_id) if not quota_result.has_quota: @@ -114,9 +108,8 @@ async def _load_text_session_or_404( user: UserModel, ) -> WorkflowRunTextSessionModel: set_current_run_id(run_id) - organization_id = _require_selected_organization_id(user) text_session = await db_client.get_workflow_run_text_session( - run_id, organization_id=organization_id + run_id, organization_id=user.selected_organization_id ) if not text_session or not text_session.workflow_run: raise HTTPException(status_code=404, detail="Text chat session not found") @@ -158,9 +151,8 @@ async def _execute_pending_turn_response( async def create_text_chat_session( workflow_id: int, request: CreateTextChatSessionRequest, - user: UserModel = Depends(get_user), + user: UserModel = Depends(get_user_with_selected_organization), ) -> WorkflowRunTextSessionResponse: - organization_id = _require_selected_organization_id(user) await _ensure_text_chat_quota(user, workflow_id) session_name = request.name or f"WR-TEXT-{uuid4().hex[:6].upper()}" @@ -172,7 +164,7 @@ async def create_text_chat_session( user_id=user.id, initial_context=request.initial_context, use_draft=True, - organization_id=organization_id, + organization_id=user.selected_organization_id, ) except ValueError as e: raise HTTPException(status_code=404, detail=str(e)) @@ -220,7 +212,7 @@ async def create_text_chat_session( async def get_text_chat_session( workflow_id: int, run_id: int, - user: UserModel = Depends(get_user), + user: UserModel = Depends(get_user_with_selected_organization), ) -> WorkflowRunTextSessionResponse: text_session = await _load_text_session_or_404(workflow_id, run_id, user) return _build_response(text_session) @@ -234,7 +226,7 @@ async def append_text_chat_message( workflow_id: int, run_id: int, request: AppendTextChatMessageRequest, - user: UserModel = Depends(get_user), + user: UserModel = Depends(get_user_with_selected_organization), ) -> WorkflowRunTextSessionResponse: text_session = await _load_text_session_or_404(workflow_id, run_id, user) await _ensure_text_chat_quota(user, workflow_id) @@ -264,7 +256,7 @@ async def rewind_text_chat_session( workflow_id: int, run_id: int, request: RewindTextChatSessionRequest, - user: UserModel = Depends(get_user), + user: UserModel = Depends(get_user_with_selected_organization), ) -> WorkflowRunTextSessionResponse: text_session = await _load_text_session_or_404(workflow_id, run_id, user) try: diff --git a/api/services/auth/depends.py b/api/services/auth/depends.py index 328eb40..7a28492 100644 --- a/api/services/auth/depends.py +++ b/api/services/auth/depends.py @@ -1,7 +1,7 @@ from typing import Annotated, Optional import httpx -from fastapi import Header, HTTPException, Query, WebSocket +from fastapi import Depends, Header, HTTPException, Query, WebSocket from loguru import logger from pydantic import ValidationError @@ -142,6 +142,14 @@ async def get_user( return user_model +async def get_user_with_selected_organization( + user: Annotated[UserModel, Depends(get_user)], +) -> UserModel: + if not user.selected_organization_id: + raise HTTPException(status_code=400, detail="No organization selected") + return user + + async def _handle_oss_auth(authorization: str | None) -> UserModel: """ Handle authentication for OSS deployment mode. diff --git a/api/tests/test_workflow_text_chat.py b/api/tests/test_workflow_text_chat.py index 1b830bf..219f333 100644 --- a/api/tests/test_workflow_text_chat.py +++ b/api/tests/test_workflow_text_chat.py @@ -51,6 +51,38 @@ async def _create_user_and_workflow( return user, workflow +@pytest.mark.asyncio +async def test_text_chat_session_creation_requires_selected_organization(): + from httpx import ASGITransport, AsyncClient + + from api.app import app + from api.services.auth.depends import get_user + + user = UserModel(provider_id="textchat-user-no-selected-org") + + async def mock_get_user(): + return user + + original_override = app.dependency_overrides.get(get_user) + app.dependency_overrides[get_user] = mock_get_user + + try: + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as client: + response = await client.post( + "/api/v1/workflow/123/text-chat/sessions", json={} + ) + finally: + if original_override: + app.dependency_overrides[get_user] = original_override + else: + app.dependency_overrides.pop(get_user, None) + + assert response.status_code == 400 + assert response.json() == {"detail": "No organization selected"} + + @pytest.mark.asyncio async def test_text_chat_session_creation_executes_initial_assistant_turn( db_session,