feat: enforce API access for integration routes

This commit is contained in:
Anish Sarkar 2026-06-19 20:28:12 +05:30
parent 70a0828b95
commit 7ec6fa4d1f
8 changed files with 125 additions and 74 deletions

View file

@ -18,6 +18,7 @@ from fastapi.responses import StreamingResponse
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext
from app.config import config as app_config from app.config import config as app_config
from app.db import ( from app.db import (
Permission, Permission,
@ -42,7 +43,7 @@ from app.podcasts.voices import (
provider_from_service, provider_from_service,
render_voice_preview, render_voice_preview,
) )
from app.users import current_active_user from app.users import get_auth_context
from app.utils.rbac import check_permission from app.utils.rbac import check_permission
from .schemas import ( from .schemas import (
@ -63,8 +64,9 @@ async def list_podcasts(
skip: int = 0, skip: int = 0,
limit: int = 100, limit: int = 100,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
if skip < 0 or limit < 1: if skip < 0 or limit < 1:
raise HTTPException(status_code=400, detail="Invalid pagination parameters") raise HTTPException(status_code=400, detail="Invalid pagination parameters")
@ -132,8 +134,9 @@ async def list_languages():
@router.get("/podcasts/voices/{voice_id}/preview") @router.get("/podcasts/voices/{voice_id}/preview")
async def preview_voice( async def preview_voice(
voice_id: str, voice_id: str,
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
"""A short audio sample of a voice, so users pick by sound.""" """A short audio sample of a voice, so users pick by sound."""
if not app_config.TTS_SERVICE: if not app_config.TTS_SERVICE:
raise HTTPException(status_code=503, detail="No TTS provider configured") raise HTTPException(status_code=503, detail="No TTS provider configured")
@ -156,8 +159,9 @@ async def preview_voice(
async def create_podcast( async def create_podcast(
body: CreatePodcastRequest, body: CreatePodcastRequest,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
await _require(session, user, body.search_space_id, Permission.PODCASTS_CREATE) await _require(session, user, body.search_space_id, Permission.PODCASTS_CREATE)
service = PodcastService(session) service = PodcastService(session)
@ -185,8 +189,9 @@ async def create_podcast(
async def get_podcast( async def get_podcast(
podcast_id: int, podcast_id: int,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
podcast = await _load(session, user, podcast_id, Permission.PODCASTS_READ) podcast = await _load(session, user, podcast_id, Permission.PODCASTS_READ)
return PodcastDetail.of(podcast) return PodcastDetail.of(podcast)
@ -196,8 +201,9 @@ async def update_spec(
podcast_id: int, podcast_id: int,
body: UpdateSpecRequest, body: UpdateSpecRequest,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
podcast = await _load(session, user, podcast_id, Permission.PODCASTS_UPDATE) podcast = await _load(session, user, podcast_id, Permission.PODCASTS_UPDATE)
async with _lifecycle_errors(): async with _lifecycle_errors():
await PodcastService(session).update_spec( await PodcastService(session).update_spec(
@ -211,8 +217,9 @@ async def update_spec(
async def approve_brief( async def approve_brief(
podcast_id: int, podcast_id: int,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
"""Approve the brief and start drafting the transcript.""" """Approve the brief and start drafting the transcript."""
podcast = await _load(session, user, podcast_id, Permission.PODCASTS_UPDATE) podcast = await _load(session, user, podcast_id, Permission.PODCASTS_UPDATE)
async with _lifecycle_errors(): async with _lifecycle_errors():
@ -228,8 +235,9 @@ async def approve_brief(
async def regenerate_transcript( async def regenerate_transcript(
podcast_id: int, podcast_id: int,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
"""Reopen the brief gate for a fresh take; drafting waits for re-approval.""" """Reopen the brief gate for a fresh take; drafting waits for re-approval."""
podcast = await _load(session, user, podcast_id, Permission.PODCASTS_UPDATE) podcast = await _load(session, user, podcast_id, Permission.PODCASTS_UPDATE)
async with _lifecycle_errors(): async with _lifecycle_errors():
@ -242,8 +250,9 @@ async def regenerate_transcript(
async def revert_regeneration( async def revert_regeneration(
podcast_id: int, podcast_id: int,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
"""Back out of a regeneration and return to the finished episode.""" """Back out of a regeneration and return to the finished episode."""
podcast = await _load(session, user, podcast_id, Permission.PODCASTS_UPDATE) podcast = await _load(session, user, podcast_id, Permission.PODCASTS_UPDATE)
async with _lifecycle_errors(): async with _lifecycle_errors():
@ -256,8 +265,9 @@ async def revert_regeneration(
async def cancel_podcast( async def cancel_podcast(
podcast_id: int, podcast_id: int,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
podcast = await _load(session, user, podcast_id, Permission.PODCASTS_UPDATE) podcast = await _load(session, user, podcast_id, Permission.PODCASTS_UPDATE)
async with _lifecycle_errors(): async with _lifecycle_errors():
await PodcastService(session).cancel(podcast) await PodcastService(session).cancel(podcast)
@ -269,8 +279,9 @@ async def cancel_podcast(
async def delete_podcast( async def delete_podcast(
podcast_id: int, podcast_id: int,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
podcast = await _load(session, user, podcast_id, Permission.PODCASTS_DELETE) podcast = await _load(session, user, podcast_id, Permission.PODCASTS_DELETE)
await purge_audio(podcast) await purge_audio(podcast)
await session.delete(podcast) await session.delete(podcast)
@ -282,8 +293,9 @@ async def delete_podcast(
async def stream_podcast( async def stream_podcast(
podcast_id: int, podcast_id: int,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
podcast = await _load(session, user, podcast_id, Permission.PODCASTS_READ) podcast = await _load(session, user, podcast_id, Permission.PODCASTS_READ)
if podcast.storage_key: if podcast.storage_key:
@ -323,13 +335,14 @@ async def stream_podcast(
async def _require( async def _require(
session: AsyncSession, session: AsyncSession,
user: User, auth: AuthContext,
search_space_id: int, search_space_id: int,
permission: Permission, permission: Permission,
) -> None: ) -> None:
user = auth.user
await check_permission( await check_permission(
session, session,
user, auth,
search_space_id, search_space_id,
permission.value, permission.value,
"You don't have permission for podcasts in this search space", "You don't have permission for podcasts in this search space",
@ -338,10 +351,11 @@ async def _require(
async def _load( async def _load(
session: AsyncSession, session: AsyncSession,
user: User, auth: AuthContext,
podcast_id: int, podcast_id: int,
permission: Permission, permission: Permission,
) -> Podcast: ) -> Podcast:
user = auth.user
podcast = await PodcastRepository(session).get(podcast_id) podcast = await PodcastRepository(session).get(podcast_id)
if podcast is None: if podcast is None:
raise HTTPException(status_code=404, detail="Podcast not found") raise HTTPException(status_code=404, detail="Podcast not found")

View file

@ -7,9 +7,10 @@ from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext
from app.db import Permission, User, get_async_session from app.db import Permission, User, get_async_session
from app.services.export_service import build_export_zip from app.services.export_service import build_export_zip
from app.users import current_active_user from app.users import get_auth_context
from app.utils.rbac import check_permission from app.utils.rbac import check_permission
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -24,12 +25,13 @@ async def export_knowledge_base(
None, description="Export only this folder's subtree" None, description="Export only this folder's subtree"
), ),
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
"""Export documents as a ZIP of markdown files preserving folder structure.""" """Export documents as a ZIP of markdown files preserving folder structure."""
await check_permission( await check_permission(
session, session,
user, auth,
search_space_id, search_space_id,
Permission.DOCUMENTS_READ.value, Permission.DOCUMENTS_READ.value,
"You don't have permission to export documents in this search space", "You don't have permission to export documents in this search space",

View file

@ -20,6 +20,7 @@ from sqlalchemy import or_, select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from starlette.responses import JSONResponse, RedirectResponse, Response from starlette.responses import JSONResponse, RedirectResponse, Response
from app.auth.context import AuthContext
from app.config import config from app.config import config
from app.db import ( from app.db import (
ExternalChatAccount, ExternalChatAccount,
@ -51,7 +52,7 @@ from app.observability.metrics import (
record_gateway_inbox_write, record_gateway_inbox_write,
record_gateway_webhook_parse_error, record_gateway_webhook_parse_error,
) )
from app.users import current_active_user from app.users import get_auth_context
from app.utils.oauth_security import OAuthStateManager, TokenEncryption from app.utils.oauth_security import OAuthStateManager, TokenEncryption
from app.utils.rbac import check_search_space_access from app.utils.rbac import check_search_space_access
@ -250,14 +251,15 @@ def _telegram_message(payload: dict[str, Any]) -> dict[str, Any] | None:
@router.get("/slack/install") @router.get("/slack/install")
async def install_slack_gateway( async def install_slack_gateway(
search_space_id: int, search_space_id: int,
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
) -> dict[str, str]: ) -> dict[str, str]:
user = auth.user
if not _slack_gateway_enabled(): if not _slack_gateway_enabled():
raise HTTPException( raise HTTPException(
status_code=500, detail="Slack gateway OAuth is not configured" status_code=500, detail="Slack gateway OAuth is not configured"
) )
await check_search_space_access(session, user, search_space_id) await check_search_space_access(session, auth, search_space_id)
state = _get_state_manager().generate_secure_state(search_space_id, user.id) state = _get_state_manager().generate_secure_state(search_space_id, user.id)
auth_params = { auth_params = {
"client_id": config.GATEWAY_SLACK_CLIENT_ID, "client_id": config.GATEWAY_SLACK_CLIENT_ID,
@ -409,14 +411,15 @@ async def slack_gateway_callback(
@router.get("/discord/install") @router.get("/discord/install")
async def install_discord_gateway( async def install_discord_gateway(
search_space_id: int, search_space_id: int,
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
) -> dict[str, str]: ) -> dict[str, str]:
user = auth.user
if not _discord_gateway_enabled(): if not _discord_gateway_enabled():
raise HTTPException( raise HTTPException(
status_code=500, detail="Discord gateway OAuth is not configured" status_code=500, detail="Discord gateway OAuth is not configured"
) )
await check_search_space_access(session, user, search_space_id) await check_search_space_access(session, auth, search_space_id)
state = _get_state_manager().generate_secure_state(search_space_id, user.id) state = _get_state_manager().generate_secure_state(search_space_id, user.id)
auth_params = { auth_params = {
"client_id": config.DISCORD_CLIENT_ID, "client_id": config.DISCORD_CLIENT_ID,
@ -712,10 +715,11 @@ async def telegram_webhook(
@router.post("/bindings/start", response_model=StartBindingResponse) @router.post("/bindings/start", response_model=StartBindingResponse)
async def start_binding( async def start_binding(
body: StartBindingRequest, body: StartBindingRequest,
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
) -> StartBindingResponse: ) -> StartBindingResponse:
await check_search_space_access(session, user, body.search_space_id) user = auth.user
await check_search_space_access(session, auth, body.search_space_id)
code = generate_pairing_code() code = generate_pairing_code()
if body.platform == ExternalChatPlatform.TELEGRAM: if body.platform == ExternalChatPlatform.TELEGRAM:
if not _telegram_gateway_enabled(): if not _telegram_gateway_enabled():
@ -774,9 +778,10 @@ async def start_binding(
@router.get("/bindings") @router.get("/bindings")
async def list_bindings( async def list_bindings(
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
user = auth.user
result = await session.execute( result = await session.execute(
select(ExternalChatBinding, ExternalChatAccount) select(ExternalChatBinding, ExternalChatAccount)
.join( .join(
@ -803,9 +808,10 @@ async def list_bindings(
@router.get("/connections") @router.get("/connections")
async def list_connections( async def list_connections(
platform: ExternalChatPlatform | None = None, platform: ExternalChatPlatform | None = None,
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
user = auth.user
active_whatsapp_mode = _active_whatsapp_account_mode() active_whatsapp_mode = _active_whatsapp_account_mode()
if platform == ExternalChatPlatform.WHATSAPP and active_whatsapp_mode is None: if platform == ExternalChatPlatform.WHATSAPP and active_whatsapp_mode is None:
return [] return []
@ -946,9 +952,10 @@ async def list_connections(
@router.get("/platforms") @router.get("/platforms")
async def list_platforms( async def list_platforms(
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
user = auth.user
result = await session.execute( result = await session.execute(
select(ExternalChatAccount).where( select(ExternalChatAccount).where(
(ExternalChatAccount.owner_user_id == user.id) (ExternalChatAccount.owner_user_id == user.id)
@ -970,8 +977,9 @@ async def list_platforms(
@config_router.get("/config") @config_router.get("/config")
async def get_gateway_config( async def get_gateway_config(
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
) -> dict[str, bool | str]: ) -> dict[str, bool | str]:
user = auth.user
if not config.GATEWAY_ENABLED: if not config.GATEWAY_ENABLED:
return { return {
"enabled": False, "enabled": False,
@ -993,9 +1001,10 @@ async def get_gateway_config(
async def update_binding_search_space( async def update_binding_search_space(
binding_id: int, binding_id: int,
body: UpdateBindingSearchSpaceRequest, body: UpdateBindingSearchSpaceRequest,
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
) -> dict[str, bool]: ) -> dict[str, bool]:
user = auth.user
binding = await session.get(ExternalChatBinding, binding_id) binding = await session.get(ExternalChatBinding, binding_id)
if binding is None or binding.user_id != user.id: if binding is None or binding.user_id != user.id:
raise HTTPException(status_code=404, detail="Binding not found") raise HTTPException(status_code=404, detail="Binding not found")
@ -1010,7 +1019,7 @@ async def update_binding_search_space(
if account is None or _is_inactive_whatsapp_account(account): if account is None or _is_inactive_whatsapp_account(account):
raise HTTPException(status_code=404, detail="Binding not found") raise HTTPException(status_code=404, detail="Binding not found")
await check_search_space_access(session, user, body.search_space_id) await check_search_space_access(session, auth, body.search_space_id)
if binding.search_space_id != body.search_space_id: if binding.search_space_id != body.search_space_id:
binding.search_space_id = body.search_space_id binding.search_space_id = body.search_space_id
binding.new_chat_thread_id = None binding.new_chat_thread_id = None
@ -1023,9 +1032,10 @@ async def update_binding_search_space(
async def update_gateway_account_search_space( async def update_gateway_account_search_space(
account_id: int, account_id: int,
body: UpdateAccountSearchSpaceRequest, body: UpdateAccountSearchSpaceRequest,
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
) -> dict[str, bool]: ) -> dict[str, bool]:
user = auth.user
account = await session.get(ExternalChatAccount, account_id) account = await session.get(ExternalChatAccount, account_id)
if ( if (
account is None account is None
@ -1036,7 +1046,7 @@ async def update_gateway_account_search_space(
): ):
raise HTTPException(status_code=404, detail="Gateway account not found") raise HTTPException(status_code=404, detail="Gateway account not found")
await check_search_space_access(session, user, body.search_space_id) await check_search_space_access(session, auth, body.search_space_id)
account.owner_search_space_id = body.search_space_id account.owner_search_space_id = body.search_space_id
account.updated_at = datetime.now(UTC) account.updated_at = datetime.now(UTC)
@ -1061,9 +1071,10 @@ async def update_gateway_account_search_space(
@router.delete("/bindings/{binding_id}") @router.delete("/bindings/{binding_id}")
async def delete_binding( async def delete_binding(
binding_id: int, binding_id: int,
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
) -> dict[str, bool]: ) -> dict[str, bool]:
user = auth.user
binding = await session.get(ExternalChatBinding, binding_id) binding = await session.get(ExternalChatBinding, binding_id)
if binding is None or binding.user_id != user.id: if binding is None or binding.user_id != user.id:
raise HTTPException(status_code=404, detail="Binding not found") raise HTTPException(status_code=404, detail="Binding not found")
@ -1078,9 +1089,10 @@ async def delete_binding(
@router.delete("/accounts/{account_id}") @router.delete("/accounts/{account_id}")
async def delete_gateway_account( async def delete_gateway_account(
account_id: int, account_id: int,
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
) -> dict[str, bool]: ) -> dict[str, bool]:
user = auth.user
account = await session.get(ExternalChatAccount, account_id) account = await session.get(ExternalChatAccount, account_id)
if ( if (
account is None account is None
@ -1114,9 +1126,10 @@ async def delete_gateway_account(
@router.post("/bindings/{binding_id}/resume") @router.post("/bindings/{binding_id}/resume")
async def resume_external_chat_binding( async def resume_external_chat_binding(
binding_id: int, binding_id: int,
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
) -> dict[str, bool]: ) -> dict[str, bool]:
user = auth.user
binding = await session.get(ExternalChatBinding, binding_id) binding = await session.get(ExternalChatBinding, binding_id)
if binding is None or binding.user_id != user.id: if binding is None or binding.user_id != user.id:
raise HTTPException(status_code=404, detail="Binding not found") raise HTTPException(status_code=404, detail="Binding not found")

View file

@ -10,6 +10,7 @@ from pydantic import BaseModel
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext
from app.config import config from app.config import config
from app.db import ( from app.db import (
ExternalChatAccount, ExternalChatAccount,
@ -20,7 +21,7 @@ from app.db import (
get_async_session, get_async_session,
) )
from app.gateway.whatsapp.adapter_baileys import WhatsAppBaileysAdapter from app.gateway.whatsapp.adapter_baileys import WhatsAppBaileysAdapter
from app.users import current_active_user from app.users import get_auth_context
from app.utils.rbac import check_search_space_access from app.utils.rbac import check_search_space_access
router = APIRouter(prefix="/gateway/whatsapp/baileys", tags=["gateway"]) router = APIRouter(prefix="/gateway/whatsapp/baileys", tags=["gateway"])
@ -60,11 +61,12 @@ async def _get_user_whatsapp_account(
@router.post("/pair") @router.post("/pair")
async def request_pairing_code( async def request_pairing_code(
body: BaileysPairRequest, body: BaileysPairRequest,
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
) -> dict[str, Any]: ) -> dict[str, Any]:
user = auth.user
_ensure_baileys_enabled() _ensure_baileys_enabled()
await check_search_space_access(session, user, body.search_space_id) await check_search_space_access(session, auth, body.search_space_id)
adapter = WhatsAppBaileysAdapter() adapter = WhatsAppBaileysAdapter()
try: try:
pairing = await adapter.request_pairing_code(phone_number=body.phone_number) pairing = await adapter.request_pairing_code(phone_number=body.phone_number)
@ -97,8 +99,9 @@ async def request_pairing_code(
@router.get("/health") @router.get("/health")
async def bridge_health( async def bridge_health(
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
) -> dict[str, Any]: ) -> dict[str, Any]:
user = auth.user
_ensure_baileys_enabled() _ensure_baileys_enabled()
adapter = WhatsAppBaileysAdapter() adapter = WhatsAppBaileysAdapter()
try: try:

View file

@ -16,6 +16,7 @@ from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload from sqlalchemy.orm import selectinload
from app.auth.context import AuthContext
from app.config import config from app.config import config
from app.db import ( from app.db import (
ImageGeneration, ImageGeneration,
@ -46,7 +47,7 @@ from app.services.image_gen_router_service import (
) )
from app.services.model_capabilities import has_capability from app.services.model_capabilities import has_capability
from app.services.model_resolver import to_litellm from app.services.model_resolver import to_litellm
from app.users import current_active_user from app.users import get_auth_context
from app.utils.rbac import check_permission from app.utils.rbac import check_permission
from app.utils.signed_image_urls import verify_image_token from app.utils.signed_image_urls import verify_image_token
@ -231,8 +232,9 @@ async def _execute_image_generation(
async def create_image_generation( async def create_image_generation(
data: ImageGenerationCreate, data: ImageGenerationCreate,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
"""Create and execute an image generation request. """Create and execute an image generation request.
Premium configs are gated by the user's shared premium credit pool. Premium configs are gated by the user's shared premium credit pool.
@ -256,7 +258,7 @@ async def create_image_generation(
try: try:
await check_permission( await check_permission(
session, session,
user, auth,
data.search_space_id, data.search_space_id,
Permission.IMAGE_GENERATIONS_CREATE.value, Permission.IMAGE_GENERATIONS_CREATE.value,
"You don't have permission to create image generations in this search space", "You don't have permission to create image generations in this search space",
@ -351,8 +353,9 @@ async def list_image_generations(
skip: int = 0, skip: int = 0,
limit: int = 50, limit: int = 50,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
"""List image generations.""" """List image generations."""
if skip < 0 or limit < 1: if skip < 0 or limit < 1:
raise HTTPException(status_code=400, detail="Invalid pagination parameters") raise HTTPException(status_code=400, detail="Invalid pagination parameters")
@ -363,7 +366,7 @@ async def list_image_generations(
if search_space_id is not None: if search_space_id is not None:
await check_permission( await check_permission(
session, session,
user, auth,
search_space_id, search_space_id,
Permission.IMAGE_GENERATIONS_READ.value, Permission.IMAGE_GENERATIONS_READ.value,
"You don't have permission to read image generations in this search space", "You don't have permission to read image generations in this search space",
@ -403,8 +406,9 @@ async def list_image_generations(
async def get_image_generation( async def get_image_generation(
image_gen_id: int, image_gen_id: int,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
"""Get a specific image generation by ID.""" """Get a specific image generation by ID."""
try: try:
result = await session.execute( result = await session.execute(
@ -416,7 +420,7 @@ async def get_image_generation(
await check_permission( await check_permission(
session, session,
user, auth,
image_gen.search_space_id, image_gen.search_space_id,
Permission.IMAGE_GENERATIONS_READ.value, Permission.IMAGE_GENERATIONS_READ.value,
"You don't have permission to read image generations in this search space", "You don't have permission to read image generations in this search space",
@ -435,8 +439,9 @@ async def get_image_generation(
async def delete_image_generation( async def delete_image_generation(
image_gen_id: int, image_gen_id: int,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
"""Delete an image generation record.""" """Delete an image generation record."""
try: try:
result = await session.execute( result = await session.execute(
@ -448,7 +453,7 @@ async def delete_image_generation(
await check_permission( await check_permission(
session, session,
user, auth,
db_image_gen.search_space_id, db_image_gen.search_space_id,
Permission.IMAGE_GENERATIONS_DELETE.value, Permission.IMAGE_GENERATIONS_DELETE.value,
"You don't have permission to delete image generations in this search space", "You don't have permission to delete image generations in this search space",

View file

@ -5,6 +5,7 @@ from sqlalchemy import and_, desc
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select from sqlalchemy.future import select
from app.auth.context import AuthContext
from app.db import ( from app.db import (
Log, Log,
LogLevel, LogLevel,
@ -16,7 +17,7 @@ from app.db import (
get_async_session, get_async_session,
) )
from app.schemas import LogCreate, LogRead, LogUpdate from app.schemas import LogCreate, LogRead, LogUpdate
from app.users import current_active_user from app.users import get_auth_context
from app.utils.rbac import check_permission from app.utils.rbac import check_permission
router = APIRouter() router = APIRouter()
@ -26,8 +27,9 @@ router = APIRouter()
async def create_log( async def create_log(
log: LogCreate, log: LogCreate,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
""" """
Create a new log entry. Create a new log entry.
Note: This is typically called internally. Requires LOGS_READ permission (since logs are usually system-generated). Note: This is typically called internally. Requires LOGS_READ permission (since logs are usually system-generated).
@ -36,7 +38,7 @@ async def create_log(
# Check if the user has access to the search space # Check if the user has access to the search space
await check_permission( await check_permission(
session, session,
user, auth,
log.search_space_id, log.search_space_id,
Permission.LOGS_READ.value, Permission.LOGS_READ.value,
"You don't have permission to access logs in this search space", "You don't have permission to access logs in this search space",
@ -67,8 +69,9 @@ async def read_logs(
start_date: datetime | None = None, start_date: datetime | None = None,
end_date: datetime | None = None, end_date: datetime | None = None,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
""" """
Get logs with optional filtering. Get logs with optional filtering.
Requires LOGS_READ permission for the search space(s). Requires LOGS_READ permission for the search space(s).
@ -81,7 +84,7 @@ async def read_logs(
# Check permission for specific search space # Check permission for specific search space
await check_permission( await check_permission(
session, session,
user, auth,
search_space_id, search_space_id,
Permission.LOGS_READ.value, Permission.LOGS_READ.value,
"You don't have permission to read logs in this search space", "You don't have permission to read logs in this search space",
@ -136,8 +139,9 @@ async def read_logs(
async def read_log( async def read_log(
log_id: int, log_id: int,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
""" """
Get a specific log by ID. Get a specific log by ID.
Requires LOGS_READ permission for the search space. Requires LOGS_READ permission for the search space.
@ -152,7 +156,7 @@ async def read_log(
# Check permission for the search space # Check permission for the search space
await check_permission( await check_permission(
session, session,
user, auth,
log.search_space_id, log.search_space_id,
Permission.LOGS_READ.value, Permission.LOGS_READ.value,
"You don't have permission to read logs in this search space", "You don't have permission to read logs in this search space",
@ -172,8 +176,9 @@ async def update_log(
log_id: int, log_id: int,
log_update: LogUpdate, log_update: LogUpdate,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
""" """
Update a log entry. Update a log entry.
Requires LOGS_READ permission (logs are typically updated by system). Requires LOGS_READ permission (logs are typically updated by system).
@ -188,7 +193,7 @@ async def update_log(
# Check permission for the search space # Check permission for the search space
await check_permission( await check_permission(
session, session,
user, auth,
db_log.search_space_id, db_log.search_space_id,
Permission.LOGS_READ.value, Permission.LOGS_READ.value,
"You don't have permission to access logs in this search space", "You don't have permission to access logs in this search space",
@ -215,8 +220,9 @@ async def update_log(
async def delete_log( async def delete_log(
log_id: int, log_id: int,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
""" """
Delete a log entry. Delete a log entry.
Requires LOGS_DELETE permission for the search space. Requires LOGS_DELETE permission for the search space.
@ -231,7 +237,7 @@ async def delete_log(
# Check permission for the search space # Check permission for the search space
await check_permission( await check_permission(
session, session,
user, auth,
db_log.search_space_id, db_log.search_space_id,
Permission.LOGS_DELETE.value, Permission.LOGS_DELETE.value,
"You don't have permission to delete logs in this search space", "You don't have permission to delete logs in this search space",
@ -254,8 +260,9 @@ async def get_logs_summary(
search_space_id: int, search_space_id: int,
hours: int = 24, hours: int = 24,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
""" """
Get a summary of logs for a search space in the last X hours. Get a summary of logs for a search space in the last X hours.
Requires LOGS_READ permission for the search space. Requires LOGS_READ permission for the search space.
@ -264,7 +271,7 @@ async def get_logs_summary(
# Check permission # Check permission
await check_permission( await check_permission(
session, session,
user, auth,
search_space_id, search_space_id,
Permission.LOGS_READ.value, Permission.LOGS_READ.value,
"You don't have permission to read logs in this search space", "You don't have permission to read logs in this search space",

View file

@ -10,8 +10,9 @@ from fastapi.responses import Response
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select from sqlalchemy.future import select
from app.auth.context import AuthContext
from app.db import NewChatThread, Permission, User, get_async_session from app.db import NewChatThread, Permission, User, get_async_session
from app.users import current_active_user from app.users import get_auth_context
from app.utils.rbac import check_permission from app.utils.rbac import check_permission
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -47,8 +48,9 @@ async def download_sandbox_file(
thread_id: int, thread_id: int,
path: str = Query(..., description="Absolute path of the file inside the sandbox"), path: str = Query(..., description="Absolute path of the file inside the sandbox"),
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
"""Download a file from the Daytona sandbox associated with a chat thread.""" """Download a file from the Daytona sandbox associated with a chat thread."""
from app.agents.chat.multi_agent_chat.shared.middleware.filesystem.sandbox import ( from app.agents.chat.multi_agent_chat.shared.middleware.filesystem.sandbox import (
@ -68,7 +70,7 @@ async def download_sandbox_file(
await check_permission( await check_permission(
session, session,
user, auth,
thread.search_space_id, thread.search_space_id,
Permission.CHATS_READ.value, Permission.CHATS_READ.value,
"You don't have permission to access files in this thread", "You don't have permission to access files in this thread",

View file

@ -16,6 +16,7 @@ from sqlalchemy import select
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext
from app.db import ( from app.db import (
Permission, Permission,
SearchSpace, SearchSpace,
@ -25,7 +26,7 @@ from app.db import (
get_async_session, get_async_session,
) )
from app.schemas import VideoPresentationRead from app.schemas import VideoPresentationRead
from app.users import current_active_user from app.users import get_auth_context
from app.utils.rbac import check_permission from app.utils.rbac import check_permission
router = APIRouter() router = APIRouter()
@ -37,8 +38,9 @@ async def read_video_presentations(
limit: int = 100, limit: int = 100,
search_space_id: int | None = None, search_space_id: int | None = None,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
""" """
List video presentations the user has access to. List video presentations the user has access to.
Requires VIDEO_PRESENTATIONS_READ permission for the search space(s). Requires VIDEO_PRESENTATIONS_READ permission for the search space(s).
@ -49,7 +51,7 @@ async def read_video_presentations(
if search_space_id is not None: if search_space_id is not None:
await check_permission( await check_permission(
session, session,
user, auth,
search_space_id, search_space_id,
Permission.VIDEO_PRESENTATIONS_READ.value, Permission.VIDEO_PRESENTATIONS_READ.value,
"You don't have permission to read video presentations in this search space", "You don't have permission to read video presentations in this search space",
@ -89,8 +91,9 @@ async def read_video_presentations(
async def read_video_presentation( async def read_video_presentation(
video_presentation_id: int, video_presentation_id: int,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
""" """
Get a specific video presentation by ID. Get a specific video presentation by ID.
Requires authentication with VIDEO_PRESENTATIONS_READ permission. Requires authentication with VIDEO_PRESENTATIONS_READ permission.
@ -112,7 +115,7 @@ async def read_video_presentation(
await check_permission( await check_permission(
session, session,
user, auth,
video_pres.search_space_id, video_pres.search_space_id,
Permission.VIDEO_PRESENTATIONS_READ.value, Permission.VIDEO_PRESENTATIONS_READ.value,
"You don't have permission to read video presentations in this search space", "You don't have permission to read video presentations in this search space",
@ -132,8 +135,9 @@ async def read_video_presentation(
async def delete_video_presentation( async def delete_video_presentation(
video_presentation_id: int, video_presentation_id: int,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
""" """
Delete a video presentation. Delete a video presentation.
Requires VIDEO_PRESENTATIONS_DELETE permission for the search space. Requires VIDEO_PRESENTATIONS_DELETE permission for the search space.
@ -151,7 +155,7 @@ async def delete_video_presentation(
await check_permission( await check_permission(
session, session,
user, auth,
db_video_pres.search_space_id, db_video_pres.search_space_id,
Permission.VIDEO_PRESENTATIONS_DELETE.value, Permission.VIDEO_PRESENTATIONS_DELETE.value,
"You don't have permission to delete video presentations in this search space", "You don't have permission to delete video presentations in this search space",
@ -175,8 +179,9 @@ async def stream_slide_audio(
video_presentation_id: int, video_presentation_id: int,
slide_number: int, slide_number: int,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
""" """
Stream the audio file for a specific slide in a video presentation. Stream the audio file for a specific slide in a video presentation.
The slide_number is 1-based. Audio path is read from the slides JSONB. The slide_number is 1-based. Audio path is read from the slides JSONB.
@ -194,7 +199,7 @@ async def stream_slide_audio(
await check_permission( await check_permission(
session, session,
user, auth,
video_pres.search_space_id, video_pres.search_space_id,
Permission.VIDEO_PRESENTATIONS_READ.value, Permission.VIDEO_PRESENTATIONS_READ.value,
"You don't have permission to access video presentations in this search space", "You don't have permission to access video presentations in this search space",