feat: enforce API access for knowledge resources

This commit is contained in:
Anish Sarkar 2026-06-19 20:27:47 +05:30
parent 7e8d26fa81
commit 493e8d5a64
8 changed files with 206 additions and 130 deletions

View file

@ -9,6 +9,7 @@ from fastapi.responses import StreamingResponse
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext
from app.db import Document, Permission, User, get_async_session from app.db import Document, Permission, User, get_async_session
from app.file_storage.persistence.enums import DocumentFileKind from app.file_storage.persistence.enums import DocumentFileKind
from app.file_storage.schemas import DocumentFileRead from app.file_storage.schemas import DocumentFileRead
@ -17,7 +18,7 @@ from app.file_storage.service import (
list_document_files, list_document_files,
open_document_file_stream, open_document_file_stream,
) )
from app.users import current_active_user from app.users import get_auth_context
from app.utils.rbac import check_permission from app.utils.rbac import check_permission
router = APIRouter() router = APIRouter()
@ -35,7 +36,7 @@ async def _load_readable_document(
await check_permission( await check_permission(
session, session,
user, auth,
document.search_space_id, document.search_space_id,
Permission.DOCUMENTS_READ.value, Permission.DOCUMENTS_READ.value,
"You don't have permission to read documents in this search space", "You don't have permission to read documents in this search space",
@ -57,8 +58,9 @@ def _content_disposition(filename: str) -> str:
async def read_document_files( async def read_document_files(
document_id: int, document_id: int,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
) -> list[DocumentFileRead]: ) -> list[DocumentFileRead]:
user = auth.user
"""Return metadata for every stored file of a document (gates the UI).""" """Return metadata for every stored file of a document (gates the UI)."""
await _load_readable_document(document_id=document_id, session=session, user=user) await _load_readable_document(document_id=document_id, session=session, user=user)
records = await list_document_files(session, document_id=document_id) records = await list_document_files(session, document_id=document_id)
@ -69,8 +71,9 @@ async def read_document_files(
async def download_original_document_file( async def download_original_document_file(
document_id: int, document_id: int,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
) -> StreamingResponse: ) -> StreamingResponse:
user = auth.user
"""Stream the document's original uploaded file.""" """Stream the document's original uploaded file."""
await _load_readable_document(document_id=document_id, session=session, user=user) await _load_readable_document(document_id=document_id, session=session, user=user)

View file

@ -7,6 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select from sqlalchemy.future import select
from sqlalchemy.orm import selectinload from sqlalchemy.orm import selectinload
from app.auth.context import AuthContext
from app.agents.chat.runtime.path_resolver import virtual_path_to_doc from app.agents.chat.runtime.path_resolver import virtual_path_to_doc
from app.db import ( from app.db import (
Chunk, Chunk,
@ -35,7 +36,7 @@ from app.schemas import (
PaginatedResponse, PaginatedResponse,
) )
from app.services.task_dispatcher import TaskDispatcher, get_task_dispatcher from app.services.task_dispatcher import TaskDispatcher, get_task_dispatcher
from app.users import current_active_user from app.users import get_auth_context
from app.utils.rbac import check_permission from app.utils.rbac import check_permission
try: try:
@ -60,8 +61,9 @@ MAX_FILE_SIZE_BYTES = 500 * 1024 * 1024 # 500 MB per file
async def create_documents( async def create_documents(
request: DocumentsCreate, request: DocumentsCreate,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
""" """
Create new documents. Create new documents.
Requires DOCUMENTS_CREATE permission. Requires DOCUMENTS_CREATE permission.
@ -70,7 +72,7 @@ async def create_documents(
# Check permission # Check permission
await check_permission( await check_permission(
session, session,
user, auth,
request.search_space_id, request.search_space_id,
Permission.DOCUMENTS_CREATE.value, Permission.DOCUMENTS_CREATE.value,
"You don't have permission to create documents in this search space", "You don't have permission to create documents in this search space",
@ -128,9 +130,10 @@ async def create_documents_file_upload(
use_vision_llm: bool = Form(False), use_vision_llm: bool = Form(False),
processing_mode: str = Form("basic"), processing_mode: str = Form("basic"),
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
dispatcher: TaskDispatcher = Depends(get_task_dispatcher), dispatcher: TaskDispatcher = Depends(get_task_dispatcher),
): ):
user = auth.user
""" """
Upload files as documents with real-time status tracking. Upload files as documents with real-time status tracking.
@ -159,7 +162,7 @@ async def create_documents_file_upload(
try: try:
await check_permission( await check_permission(
session, session,
user, auth,
search_space_id, search_space_id,
Permission.DOCUMENTS_CREATE.value, Permission.DOCUMENTS_CREATE.value,
"You don't have permission to create documents in this search space", "You don't have permission to create documents in this search space",
@ -340,8 +343,9 @@ async def read_documents(
sort_by: str = "created_at", sort_by: str = "created_at",
sort_order: str = "desc", sort_order: str = "desc",
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
""" """
List documents the user has access to, with optional filtering and pagination. List documents the user has access to, with optional filtering and pagination.
Requires DOCUMENTS_READ permission for the search space(s). Requires DOCUMENTS_READ permission for the search space(s).
@ -369,7 +373,7 @@ async def read_documents(
if search_space_id is not None: if search_space_id is not None:
await check_permission( await check_permission(
session, session,
user, auth,
search_space_id, search_space_id,
Permission.DOCUMENTS_READ.value, Permission.DOCUMENTS_READ.value,
"You don't have permission to read documents in this search space", "You don't have permission to read documents in this search space",
@ -519,8 +523,9 @@ async def search_documents(
search_space_id: int | None = None, search_space_id: int | None = None,
document_types: str | None = None, document_types: str | None = None,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
""" """
Search documents by title substring, optionally filtered by search_space_id and document_types. Search documents by title substring, optionally filtered by search_space_id and document_types.
Requires DOCUMENTS_READ permission for the search space(s). Requires DOCUMENTS_READ permission for the search space(s).
@ -549,7 +554,7 @@ async def search_documents(
if search_space_id is not None: if search_space_id is not None:
await check_permission( await check_permission(
session, session,
user, auth,
search_space_id, search_space_id,
Permission.DOCUMENTS_READ.value, Permission.DOCUMENTS_READ.value,
"You don't have permission to read documents in this search space", "You don't have permission to read documents in this search space",
@ -677,8 +682,9 @@ async def search_document_titles(
page: int = 0, page: int = 0,
page_size: int = 20, page_size: int = 20,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
""" """
Lightweight document title search optimized for mention picker (@mentions). Lightweight document title search optimized for mention picker (@mentions).
@ -703,7 +709,7 @@ async def search_document_titles(
# Check permission for the search space # Check permission for the search space
await check_permission( await check_permission(
session, session,
user, auth,
search_space_id, search_space_id,
Permission.DOCUMENTS_READ.value, Permission.DOCUMENTS_READ.value,
"You don't have permission to read documents in this search space", "You don't have permission to read documents in this search space",
@ -781,8 +787,9 @@ async def get_document_by_virtual_path(
search_space_id: int, search_space_id: int,
virtual_path: str, virtual_path: str,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
"""Resolve a knowledge-base document by its agent-facing virtual path. """Resolve a knowledge-base document by its agent-facing virtual path.
The agent renders every document under ``/documents/...`` with a The agent renders every document under ``/documents/...`` with a
@ -804,7 +811,7 @@ async def get_document_by_virtual_path(
try: try:
await check_permission( await check_permission(
session, session,
user, auth,
search_space_id, search_space_id,
Permission.DOCUMENTS_READ.value, Permission.DOCUMENTS_READ.value,
"You don't have permission to read documents in this search space", "You don't have permission to read documents in this search space",
@ -838,8 +845,9 @@ async def get_documents_status(
search_space_id: int, search_space_id: int,
document_ids: str, document_ids: str,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
""" """
Batch status endpoint for documents in a search space. Batch status endpoint for documents in a search space.
@ -849,7 +857,7 @@ async def get_documents_status(
try: try:
await check_permission( await check_permission(
session, session,
user, auth,
search_space_id, search_space_id,
Permission.DOCUMENTS_READ.value, Permission.DOCUMENTS_READ.value,
"You don't have permission to read documents in this search space", "You don't have permission to read documents in this search space",
@ -905,8 +913,9 @@ async def get_documents_status(
async def get_document_type_counts( async def get_document_type_counts(
search_space_id: int | None = None, search_space_id: int | None = None,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
""" """
Get counts of documents by type for search spaces the user has access to. Get counts of documents by type for search spaces the user has access to.
Requires DOCUMENTS_READ permission for the search space(s). Requires DOCUMENTS_READ permission for the search space(s).
@ -926,7 +935,7 @@ async def get_document_type_counts(
# Check permission for specific search space # Check permission for specific search space
await check_permission( await check_permission(
session, session,
user, auth,
search_space_id, search_space_id,
Permission.DOCUMENTS_READ.value, Permission.DOCUMENTS_READ.value,
"You don't have permission to read documents in this search space", "You don't have permission to read documents in this search space",
@ -965,8 +974,9 @@ async def get_document_by_chunk_id(
5, ge=0, description="Number of chunks before/after the cited chunk to include" 5, ge=0, description="Number of chunks before/after the cited chunk to include"
), ),
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
""" """
Retrieves a document based on a chunk ID, including a window of chunks around the cited one. Retrieves a document based on a chunk ID, including a window of chunks around the cited one.
Uses SQL-level pagination to avoid loading all chunks into memory. Uses SQL-level pagination to avoid loading all chunks into memory.
@ -995,7 +1005,7 @@ async def get_document_by_chunk_id(
await check_permission( await check_permission(
session, session,
user, auth,
document.search_space_id, document.search_space_id,
Permission.DOCUMENTS_READ.value, Permission.DOCUMENTS_READ.value,
"You don't have permission to read documents in this search space", "You don't have permission to read documents in this search space",
@ -1060,12 +1070,13 @@ async def get_document_by_chunk_id(
async def get_watched_folders( async def get_watched_folders(
search_space_id: int, search_space_id: int,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
"""Return root folders that are marked as watched (metadata->>'watched' = 'true').""" """Return root folders that are marked as watched (metadata->>'watched' = 'true')."""
await check_permission( await check_permission(
session, session,
user, auth,
search_space_id, search_space_id,
Permission.DOCUMENTS_READ.value, Permission.DOCUMENTS_READ.value,
"You don't have permission to read documents in this search space", "You don't have permission to read documents in this search space",
@ -1101,8 +1112,9 @@ async def get_document_chunks_paginated(
None, ge=0, description="Direct offset; overrides page * page_size" None, ge=0, description="Direct offset; overrides page * page_size"
), ),
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
""" """
Paginated chunk loading for a document. Paginated chunk loading for a document.
Supports both page-based and offset-based access. Supports both page-based and offset-based access.
@ -1120,7 +1132,7 @@ async def get_document_chunks_paginated(
await check_permission( await check_permission(
session, session,
user, auth,
document.search_space_id, document.search_space_id,
Permission.DOCUMENTS_READ.value, Permission.DOCUMENTS_READ.value,
"You don't have permission to read documents in this search space", "You don't have permission to read documents in this search space",
@ -1162,8 +1174,9 @@ async def get_document_chunks_paginated(
async def read_document( async def read_document(
document_id: int, document_id: int,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
""" """
Get a specific document by ID. Get a specific document by ID.
Requires DOCUMENTS_READ permission for the search space. Requires DOCUMENTS_READ permission for the search space.
@ -1182,7 +1195,7 @@ async def read_document(
# Check permission for the search space # Check permission for the search space
await check_permission( await check_permission(
session, session,
user, auth,
document.search_space_id, document.search_space_id,
Permission.DOCUMENTS_READ.value, Permission.DOCUMENTS_READ.value,
"You don't have permission to read documents in this search space", "You don't have permission to read documents in this search space",
@ -1216,8 +1229,9 @@ async def update_document(
document_id: int, document_id: int,
document_update: DocumentUpdate, document_update: DocumentUpdate,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
""" """
Update a document. Update a document.
Requires DOCUMENTS_UPDATE permission for the search space. Requires DOCUMENTS_UPDATE permission for the search space.
@ -1236,7 +1250,7 @@ async def update_document(
# Check permission for the search space # Check permission for the search space
await check_permission( await check_permission(
session, session,
user, auth,
db_document.search_space_id, db_document.search_space_id,
Permission.DOCUMENTS_UPDATE.value, Permission.DOCUMENTS_UPDATE.value,
"You don't have permission to update documents in this search space", "You don't have permission to update documents in this search space",
@ -1275,8 +1289,9 @@ async def update_document(
async def delete_document( async def delete_document(
document_id: int, document_id: int,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
""" """
Delete a document. Delete a document.
Requires DOCUMENTS_DELETE permission for the search space. Requires DOCUMENTS_DELETE permission for the search space.
@ -1311,7 +1326,7 @@ async def delete_document(
# Check permission for the search space # Check permission for the search space
await check_permission( await check_permission(
session, session,
user, auth,
document.search_space_id, document.search_space_id,
Permission.DOCUMENTS_DELETE.value, Permission.DOCUMENTS_DELETE.value,
"You don't have permission to delete documents in this search space", "You don't have permission to delete documents in this search space",
@ -1355,8 +1370,9 @@ async def delete_document(
async def list_document_versions( async def list_document_versions(
document_id: int, document_id: int,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
"""List all versions for a document, ordered by version_number descending.""" """List all versions for a document, ordered by version_number descending."""
document = ( document = (
await session.execute(select(Document).where(Document.id == document_id)) await session.execute(select(Document).where(Document.id == document_id))
@ -1396,8 +1412,9 @@ async def get_document_version(
document_id: int, document_id: int,
version_number: int, version_number: int,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
"""Get full version content including source_markdown.""" """Get full version content including source_markdown."""
document = ( document = (
await session.execute(select(Document).where(Document.id == document_id)) await session.execute(select(Document).where(Document.id == document_id))
@ -1434,8 +1451,9 @@ async def restore_document_version(
document_id: int, document_id: int,
version_number: int, version_number: int,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
"""Restore a previous version: snapshot current state, then overwrite document content.""" """Restore a previous version: snapshot current state, then overwrite document content."""
document = ( document = (
await session.execute(select(Document).where(Document.id == document_id)) await session.execute(select(Document).where(Document.id == document_id))
@ -1517,8 +1535,9 @@ class FolderSyncFinalizeRequest(PydanticBaseModel):
async def folder_mtime_check( async def folder_mtime_check(
request: FolderMtimeCheckRequest, request: FolderMtimeCheckRequest,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
"""Pre-upload optimization: check which files need uploading based on mtime. """Pre-upload optimization: check which files need uploading based on mtime.
Returns the subset of relative paths where the file is new or has a Returns the subset of relative paths where the file is new or has a
@ -1528,7 +1547,7 @@ async def folder_mtime_check(
await check_permission( await check_permission(
session, session,
user, auth,
request.search_space_id, request.search_space_id,
Permission.DOCUMENTS_CREATE.value, Permission.DOCUMENTS_CREATE.value,
"You don't have permission to create documents in this search space", "You don't have permission to create documents in this search space",
@ -1587,8 +1606,9 @@ async def folder_upload(
use_vision_llm: bool = Form(False), use_vision_llm: bool = Form(False),
processing_mode: str = Form("basic"), processing_mode: str = Form("basic"),
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
"""Upload files from the desktop app for folder indexing. """Upload files from the desktop app for folder indexing.
Files are written to temp storage and dispatched to a Celery task. Files are written to temp storage and dispatched to a Celery task.
@ -1603,7 +1623,7 @@ async def folder_upload(
await check_permission( await check_permission(
session, session,
user, auth,
search_space_id, search_space_id,
Permission.DOCUMENTS_CREATE.value, Permission.DOCUMENTS_CREATE.value,
"You don't have permission to create documents in this search space", "You don't have permission to create documents in this search space",
@ -1733,8 +1753,9 @@ async def folder_upload(
async def folder_unlink( async def folder_unlink(
request: FolderUnlinkRequest, request: FolderUnlinkRequest,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
"""Handle file deletion events from the desktop watcher. """Handle file deletion events from the desktop watcher.
For each relative path, find the matching document and delete it. For each relative path, find the matching document and delete it.
@ -1746,7 +1767,7 @@ async def folder_unlink(
await check_permission( await check_permission(
session, session,
user, auth,
request.search_space_id, request.search_space_id,
Permission.DOCUMENTS_DELETE.value, Permission.DOCUMENTS_DELETE.value,
"You don't have permission to delete documents in this search space", "You don't have permission to delete documents in this search space",
@ -1787,8 +1808,9 @@ async def folder_unlink(
async def folder_sync_finalize( async def folder_sync_finalize(
request: FolderSyncFinalizeRequest, request: FolderSyncFinalizeRequest,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
"""Finalize a full folder scan by deleting orphaned documents. """Finalize a full folder scan by deleting orphaned documents.
The client sends the complete list of relative paths currently in the The client sends the complete list of relative paths currently in the
@ -1803,7 +1825,7 @@ async def folder_sync_finalize(
await check_permission( await check_permission(
session, session,
user, auth,
request.search_space_id, request.search_space_id,
Permission.DOCUMENTS_DELETE.value, Permission.DOCUMENTS_DELETE.value,
"You don't have permission to delete documents in this search space", "You don't have permission to delete documents in this search space",

View file

@ -18,6 +18,7 @@ from fastapi.responses import StreamingResponse
from sqlalchemy import func, select from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext
from app.db import Chunk, Document, DocumentType, Permission, User, get_async_session from app.db import Chunk, Document, DocumentType, Permission, User, get_async_session
from app.routes.reports_routes import ( from app.routes.reports_routes import (
_FILE_EXTENSIONS, _FILE_EXTENSIONS,
@ -31,7 +32,7 @@ from app.templates.export_helpers import (
get_reference_docx_path, get_reference_docx_path,
get_typst_template_path, get_typst_template_path,
) )
from app.users import current_active_user from app.users import get_auth_context
from app.utils.rbac import check_permission from app.utils.rbac import check_permission
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -47,8 +48,9 @@ async def get_editor_content(
search_space_id: int, search_space_id: int,
document_id: int, document_id: int,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
""" """
Get document content for editing. Get document content for editing.
@ -60,7 +62,7 @@ async def get_editor_content(
# Check RBAC permission # Check RBAC permission
await check_permission( await check_permission(
session, session,
user, auth,
search_space_id, search_space_id,
Permission.DOCUMENTS_READ.value, Permission.DOCUMENTS_READ.value,
"You don't have permission to read documents in this search space", "You don't have permission to read documents in this search space",
@ -178,15 +180,16 @@ async def download_document_markdown(
search_space_id: int, search_space_id: int,
document_id: int, document_id: int,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
""" """
Download the full document content as a .md file. Download the full document content as a .md file.
Reconstructs markdown from source_markdown or chunks. Reconstructs markdown from source_markdown or chunks.
""" """
await check_permission( await check_permission(
session, session,
user, auth,
search_space_id, search_space_id,
Permission.DOCUMENTS_READ.value, Permission.DOCUMENTS_READ.value,
"You don't have permission to read documents in this search space", "You don't have permission to read documents in this search space",
@ -244,8 +247,9 @@ async def save_document(
document_id: int, document_id: int,
data: dict[str, Any], data: dict[str, Any],
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
""" """
Save document markdown and trigger reindexing. Save document markdown and trigger reindexing.
Called when user clicks 'Save & Exit'. Called when user clicks 'Save & Exit'.
@ -259,7 +263,7 @@ async def save_document(
# Check RBAC permission # Check RBAC permission
await check_permission( await check_permission(
session, session,
user, auth,
search_space_id, search_space_id,
Permission.DOCUMENTS_UPDATE.value, Permission.DOCUMENTS_UPDATE.value,
"You don't have permission to update documents in this search space", "You don't have permission to update documents in this search space",
@ -331,12 +335,13 @@ async def export_document(
description="Export format: pdf, docx, html, latex, epub, odt, or plain", description="Export format: pdf, docx, html, latex, epub, odt, or plain",
), ),
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
"""Export a document in the requested format (reuses the report export pipeline).""" """Export a document in the requested format (reuses the report export pipeline)."""
await check_permission( await check_permission(
session, session,
user, auth,
search_space_id, search_space_id,
Permission.DOCUMENTS_READ.value, Permission.DOCUMENTS_READ.value,
"You don't have permission to read documents in this search space", "You don't have permission to read documents in this search space",

View file

@ -5,6 +5,7 @@ from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select from sqlalchemy.future import select
from app.auth.context import AuthContext
from app.db import Document, Folder, Permission, User, get_async_session from app.db import Document, Folder, Permission, User, get_async_session
from app.schemas import ( from app.schemas import (
BulkDocumentMove, BulkDocumentMove,
@ -23,7 +24,7 @@ from app.services.folder_service import (
get_subtree_max_depth, get_subtree_max_depth,
validate_folder_depth, validate_folder_depth,
) )
from app.users import current_active_user from app.users import get_auth_context
from app.utils.rbac import check_permission from app.utils.rbac import check_permission
router = APIRouter() router = APIRouter()
@ -33,13 +34,14 @@ router = APIRouter()
async def create_folder( async def create_folder(
request: FolderCreate, request: FolderCreate,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
"""Create a new folder. Requires DOCUMENTS_CREATE permission.""" """Create a new folder. Requires DOCUMENTS_CREATE permission."""
try: try:
await check_permission( await check_permission(
session, session,
user, auth,
request.search_space_id, request.search_space_id,
Permission.DOCUMENTS_CREATE.value, Permission.DOCUMENTS_CREATE.value,
"You don't have permission to create folders in this search space", "You don't have permission to create folders in this search space",
@ -91,13 +93,14 @@ async def create_folder(
async def list_folders( async def list_folders(
search_space_id: int, search_space_id: int,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
"""List all folders in a search space (flat). Requires DOCUMENTS_READ permission.""" """List all folders in a search space (flat). Requires DOCUMENTS_READ permission."""
try: try:
await check_permission( await check_permission(
session, session,
user, auth,
search_space_id, search_space_id,
Permission.DOCUMENTS_READ.value, Permission.DOCUMENTS_READ.value,
"You don't have permission to read folders in this search space", "You don't have permission to read folders in this search space",
@ -122,8 +125,9 @@ async def list_folders(
async def get_folder( async def get_folder(
folder_id: int, folder_id: int,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
"""Get a single folder. Requires DOCUMENTS_READ permission.""" """Get a single folder. Requires DOCUMENTS_READ permission."""
try: try:
folder = await session.get(Folder, folder_id) folder = await session.get(Folder, folder_id)
@ -132,7 +136,7 @@ async def get_folder(
await check_permission( await check_permission(
session, session,
user, auth,
folder.search_space_id, folder.search_space_id,
Permission.DOCUMENTS_READ.value, Permission.DOCUMENTS_READ.value,
"You don't have permission to read folders in this search space", "You don't have permission to read folders in this search space",
@ -152,8 +156,9 @@ async def get_folder(
async def get_folder_breadcrumb( async def get_folder_breadcrumb(
folder_id: int, folder_id: int,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
"""Get ancestor chain for breadcrumb display. Requires DOCUMENTS_READ permission.""" """Get ancestor chain for breadcrumb display. Requires DOCUMENTS_READ permission."""
try: try:
folder = await session.get(Folder, folder_id) folder = await session.get(Folder, folder_id)
@ -162,7 +167,7 @@ async def get_folder_breadcrumb(
await check_permission( await check_permission(
session, session,
user, auth,
folder.search_space_id, folder.search_space_id,
Permission.DOCUMENTS_READ.value, Permission.DOCUMENTS_READ.value,
"You don't have permission to read folders in this search space", "You don't have permission to read folders in this search space",
@ -196,8 +201,9 @@ async def get_folder_breadcrumb(
async def stop_watching_folder( async def stop_watching_folder(
folder_id: int, folder_id: int,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
"""Clear the watched flag from a folder's metadata.""" """Clear the watched flag from a folder's metadata."""
folder = await session.get(Folder, folder_id) folder = await session.get(Folder, folder_id)
if not folder: if not folder:
@ -205,7 +211,7 @@ async def stop_watching_folder(
await check_permission( await check_permission(
session, session,
user, auth,
folder.search_space_id, folder.search_space_id,
Permission.DOCUMENTS_UPDATE.value, Permission.DOCUMENTS_UPDATE.value,
"You don't have permission to update folders in this search space", "You don't have permission to update folders in this search space",
@ -224,8 +230,9 @@ async def update_folder(
folder_id: int, folder_id: int,
request: FolderUpdate, request: FolderUpdate,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
"""Rename a folder. Requires DOCUMENTS_UPDATE permission.""" """Rename a folder. Requires DOCUMENTS_UPDATE permission."""
try: try:
folder = await session.get(Folder, folder_id) folder = await session.get(Folder, folder_id)
@ -234,7 +241,7 @@ async def update_folder(
await check_permission( await check_permission(
session, session,
user, auth,
folder.search_space_id, folder.search_space_id,
Permission.DOCUMENTS_UPDATE.value, Permission.DOCUMENTS_UPDATE.value,
"You don't have permission to update folders in this search space", "You don't have permission to update folders in this search space",
@ -264,8 +271,9 @@ async def move_folder(
folder_id: int, folder_id: int,
request: FolderMove, request: FolderMove,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
"""Move a folder to a new parent. Requires DOCUMENTS_UPDATE permission.""" """Move a folder to a new parent. Requires DOCUMENTS_UPDATE permission."""
try: try:
folder = await session.get(Folder, folder_id) folder = await session.get(Folder, folder_id)
@ -274,7 +282,7 @@ async def move_folder(
await check_permission( await check_permission(
session, session,
user, auth,
folder.search_space_id, folder.search_space_id,
Permission.DOCUMENTS_UPDATE.value, Permission.DOCUMENTS_UPDATE.value,
"You don't have permission to move folders in this search space", "You don't have permission to move folders in this search space",
@ -324,8 +332,9 @@ async def reorder_folder(
folder_id: int, folder_id: int,
request: FolderReorder, request: FolderReorder,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
"""Reorder a folder among its siblings via fractional indexing. Requires DOCUMENTS_UPDATE.""" """Reorder a folder among its siblings via fractional indexing. Requires DOCUMENTS_UPDATE."""
try: try:
folder = await session.get(Folder, folder_id) folder = await session.get(Folder, folder_id)
@ -334,7 +343,7 @@ async def reorder_folder(
await check_permission( await check_permission(
session, session,
user, auth,
folder.search_space_id, folder.search_space_id,
Permission.DOCUMENTS_UPDATE.value, Permission.DOCUMENTS_UPDATE.value,
"You don't have permission to reorder folders in this search space", "You don't have permission to reorder folders in this search space",
@ -365,8 +374,9 @@ async def reorder_folder(
async def delete_folder( async def delete_folder(
folder_id: int, folder_id: int,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
"""Mark documents for deletion and dispatch Celery to delete docs first, then folders.""" """Mark documents for deletion and dispatch Celery to delete docs first, then folders."""
try: try:
folder = await session.get(Folder, folder_id) folder = await session.get(Folder, folder_id)
@ -375,7 +385,7 @@ async def delete_folder(
await check_permission( await check_permission(
session, session,
user, auth,
folder.search_space_id, folder.search_space_id,
Permission.DOCUMENTS_DELETE.value, Permission.DOCUMENTS_DELETE.value,
"You don't have permission to delete folders in this search space", "You don't have permission to delete folders in this search space",
@ -439,8 +449,9 @@ async def move_document(
document_id: int, document_id: int,
request: DocumentMove, request: DocumentMove,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
"""Move a document to a folder (or root). Requires DOCUMENTS_UPDATE permission.""" """Move a document to a folder (or root). Requires DOCUMENTS_UPDATE permission."""
try: try:
result = await session.execute( result = await session.execute(
@ -452,7 +463,7 @@ async def move_document(
await check_permission( await check_permission(
session, session,
user, auth,
document.search_space_id, document.search_space_id,
Permission.DOCUMENTS_UPDATE.value, Permission.DOCUMENTS_UPDATE.value,
"You don't have permission to move documents in this search space", "You don't have permission to move documents in this search space",
@ -485,8 +496,9 @@ async def move_document(
async def bulk_move_documents( async def bulk_move_documents(
request: BulkDocumentMove, request: BulkDocumentMove,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
"""Move multiple documents to a folder (or root). Requires DOCUMENTS_UPDATE permission.""" """Move multiple documents to a folder (or root). Requires DOCUMENTS_UPDATE permission."""
try: try:
if not request.document_ids: if not request.document_ids:
@ -504,7 +516,7 @@ async def bulk_move_documents(
for ss_id in search_space_ids: for ss_id in search_space_ids:
await check_permission( await check_permission(
session, session,
user, auth,
ss_id, ss_id,
Permission.DOCUMENTS_UPDATE.value, Permission.DOCUMENTS_UPDATE.value,
"You don't have permission to move documents in this search space", "You don't have permission to move documents in this search space",

View file

@ -9,9 +9,10 @@ from pydantic import BaseModel
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext
from app.db import Document, DocumentType, Permission, User, get_async_session from app.db import Document, DocumentType, Permission, User, get_async_session
from app.schemas import DocumentRead, PaginatedResponse from app.schemas import DocumentRead, PaginatedResponse
from app.users import current_active_user from app.users import get_auth_context
from app.utils.rbac import check_permission from app.utils.rbac import check_permission
router = APIRouter() router = APIRouter()
@ -27,8 +28,9 @@ async def create_note(
search_space_id: int, search_space_id: int,
request: CreateNoteRequest, request: CreateNoteRequest,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
""" """
Create a new note document. Create a new note document.
@ -37,7 +39,7 @@ async def create_note(
# Check RBAC permission # Check RBAC permission
await check_permission( await check_permission(
session, session,
user, auth,
search_space_id, search_space_id,
Permission.DOCUMENTS_CREATE.value, Permission.DOCUMENTS_CREATE.value,
"You don't have permission to create notes in this search space", "You don't have permission to create notes in this search space",
@ -98,8 +100,9 @@ async def list_notes(
page: int | None = None, page: int | None = None,
page_size: int = 50, page_size: int = 50,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
""" """
List all notes in a search space. List all notes in a search space.
@ -108,7 +111,7 @@ async def list_notes(
# Check RBAC permission # Check RBAC permission
await check_permission( await check_permission(
session, session,
user, auth,
search_space_id, search_space_id,
Permission.DOCUMENTS_READ.value, Permission.DOCUMENTS_READ.value,
"You don't have permission to read notes in this search space", "You don't have permission to read notes in this search space",
@ -191,8 +194,9 @@ async def delete_note(
search_space_id: int, search_space_id: int,
note_id: int, note_id: int,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
""" """
Delete a note. Delete a note.
@ -201,7 +205,7 @@ async def delete_note(
# Check RBAC permission # Check RBAC permission
await check_permission( await check_permission(
session, session,
user, auth,
search_space_id, search_space_id,
Permission.DOCUMENTS_DELETE.value, Permission.DOCUMENTS_DELETE.value,
"You don't have permission to delete notes in this search space", "You don't have permission to delete notes in this search space",

View file

@ -28,6 +28,7 @@ from sqlalchemy import select
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext
from app.db import ( from app.db import (
Report, Report,
SearchSpace, SearchSpace,
@ -42,7 +43,7 @@ from app.templates.export_helpers import (
get_reference_docx_path, get_reference_docx_path,
get_typst_template_path, get_typst_template_path,
) )
from app.users import current_active_user from app.users import get_auth_context
from app.utils.rbac import check_search_space_access from app.utils.rbac import check_search_space_access
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -158,8 +159,9 @@ def _normalize_latex_delimiters(text: str) -> str:
async def _get_report_with_access( async def _get_report_with_access(
report_id: int, report_id: int,
session: AsyncSession, session: AsyncSession,
user: User, auth: AuthContext,
) -> Report: ) -> Report:
user = auth.user
"""Fetch a report and verify the user belongs to its search space. """Fetch a report and verify the user belongs to its search space.
Raises HTTPException(404) if not found, HTTPException(403) if no access. Raises HTTPException(404) if not found, HTTPException(403) if no access.
@ -172,7 +174,7 @@ async def _get_report_with_access(
# Lightweight membership check - no granular RBAC, just "is the user a # Lightweight membership check - no granular RBAC, just "is the user a
# member of the search space this report belongs to?" # member of the search space this report belongs to?"
await check_search_space_access(session, user, report.search_space_id) await check_search_space_access(session, auth, report.search_space_id)
return report return report
@ -206,8 +208,9 @@ async def read_reports(
limit: int = Query(default=100, ge=1, le=MAX_REPORT_LIST_LIMIT), limit: int = Query(default=100, ge=1, le=MAX_REPORT_LIST_LIMIT),
search_space_id: int | None = None, search_space_id: int | None = None,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
""" """
List reports the user has access to. List reports the user has access to.
Filters by search space membership. Filters by search space membership.
@ -215,7 +218,7 @@ async def read_reports(
try: try:
if search_space_id is not None: if search_space_id is not None:
# Verify the caller is a member of the requested search space # Verify the caller is a member of the requested search space
await check_search_space_access(session, user, search_space_id) await check_search_space_access(session, auth, search_space_id)
result = await session.execute( result = await session.execute(
select(Report) select(Report)
@ -247,8 +250,9 @@ async def read_reports(
async def read_report( async def read_report(
report_id: int, report_id: int,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
""" """
Get a specific report by ID (metadata only, no content). Get a specific report by ID (metadata only, no content).
""" """
@ -266,8 +270,9 @@ async def read_report(
async def read_report_content( async def read_report_content(
report_id: int, report_id: int,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
""" """
Get full Markdown content of a report, including version siblings. Get full Markdown content of a report, including version siblings.
""" """
@ -298,8 +303,9 @@ async def update_report_content(
report_id: int, report_id: int,
body: ReportContentUpdate, body: ReportContentUpdate,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
""" """
Update the Markdown content of a report. Update the Markdown content of a report.
@ -339,8 +345,9 @@ async def update_report_content(
async def preview_report_pdf( async def preview_report_pdf(
report_id: int, report_id: int,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
""" """
Return a compiled PDF preview for Typst-based reports (resumes). Return a compiled PDF preview for Typst-based reports (resumes).
@ -394,8 +401,9 @@ async def export_report(
description="Export format: pdf, docx, html, latex, epub, odt, or plain", description="Export format: pdf, docx, html, latex, epub, odt, or plain",
), ),
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
""" """
Export a report in the requested format. Export a report in the requested format.
""" """
@ -568,8 +576,9 @@ async def export_report(
async def delete_report( async def delete_report(
report_id: int, report_id: int,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
""" """
Delete a report. Delete a report.
""" """

View file

@ -33,6 +33,7 @@ from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select from sqlalchemy.future import select
from app.auth.context import AuthContext
from app.config import config from app.config import config
from app.connectors.github_connector import GitHubConnector from app.connectors.github_connector import GitHubConnector
from app.db import ( from app.db import (
@ -56,7 +57,7 @@ from app.schemas import (
SearchSourceConnectorUpdate, SearchSourceConnectorUpdate,
) )
from app.services.composio_service import ComposioService, get_composio_service from app.services.composio_service import ComposioService, get_composio_service
from app.users import current_active_user from app.users import get_auth_context
# NOTE: connector indexer functions are imported lazily inside each # NOTE: connector indexer functions are imported lazily inside each
# ``run_*_indexing`` helper to break a circular import cycle: # ``run_*_indexing`` helper to break a circular import cycle:
@ -143,8 +144,9 @@ class GitHubPATRequest(BaseModel):
@router.post("/github/repositories", response_model=list[dict[str, Any]]) @router.post("/github/repositories", response_model=list[dict[str, Any]])
async def list_github_repositories( async def list_github_repositories(
pat_request: GitHubPATRequest, pat_request: GitHubPATRequest,
user: User = Depends(current_active_user), # Ensure the user is logged in auth: AuthContext = Depends(get_auth_context), # Ensure the user is logged in
): ):
user = auth.user
""" """
Fetches a list of repositories accessible by the provided GitHub PAT. Fetches a list of repositories accessible by the provided GitHub PAT.
The PAT is used for this request only and is not stored. The PAT is used for this request only and is not stored.
@ -173,8 +175,9 @@ async def create_search_source_connector(
..., description="ID of the search space to associate the connector with" ..., description="ID of the search space to associate the connector with"
), ),
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
""" """
Create a new search source connector. Create a new search source connector.
Requires CONNECTORS_CREATE permission. Requires CONNECTORS_CREATE permission.
@ -186,7 +189,7 @@ async def create_search_source_connector(
# Check if user has permission to create connectors # Check if user has permission to create connectors
await check_permission( await check_permission(
session, session,
user, auth,
search_space_id, search_space_id,
Permission.CONNECTORS_CREATE.value, Permission.CONNECTORS_CREATE.value,
"You don't have permission to create connectors in this search space", "You don't have permission to create connectors in this search space",
@ -281,8 +284,9 @@ async def read_search_source_connectors(
limit: int = 100, limit: int = 100,
search_space_id: int | None = None, search_space_id: int | None = None,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
""" """
List all search source connectors for a search space. List all search source connectors for a search space.
Requires CONNECTORS_READ permission. Requires CONNECTORS_READ permission.
@ -297,7 +301,7 @@ async def read_search_source_connectors(
# Check if user has permission to read connectors # Check if user has permission to read connectors
await check_permission( await check_permission(
session, session,
user, auth,
search_space_id, search_space_id,
Permission.CONNECTORS_READ.value, Permission.CONNECTORS_READ.value,
"You don't have permission to view connectors in this search space", "You don't have permission to view connectors in this search space",
@ -324,8 +328,9 @@ async def read_search_source_connectors(
async def read_search_source_connector( async def read_search_source_connector(
connector_id: int, connector_id: int,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
""" """
Get a specific search source connector by ID. Get a specific search source connector by ID.
Requires CONNECTORS_READ permission. Requires CONNECTORS_READ permission.
@ -345,7 +350,7 @@ async def read_search_source_connector(
# Check permission # Check permission
await check_permission( await check_permission(
session, session,
user, auth,
connector.search_space_id, connector.search_space_id,
Permission.CONNECTORS_READ.value, Permission.CONNECTORS_READ.value,
"You don't have permission to view this connector", "You don't have permission to view this connector",
@ -367,8 +372,9 @@ async def update_search_source_connector(
connector_id: int, connector_id: int,
connector_update: SearchSourceConnectorUpdate, connector_update: SearchSourceConnectorUpdate,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
""" """
Update a search source connector. Update a search source connector.
Requires CONNECTORS_UPDATE permission. Requires CONNECTORS_UPDATE permission.
@ -386,7 +392,7 @@ async def update_search_source_connector(
# Check permission # Check permission
await check_permission( await check_permission(
session, session,
user, auth,
db_connector.search_space_id, db_connector.search_space_id,
Permission.CONNECTORS_UPDATE.value, Permission.CONNECTORS_UPDATE.value,
"You don't have permission to update this connector", "You don't have permission to update this connector",
@ -557,8 +563,9 @@ async def update_search_source_connector(
async def delete_search_source_connector( async def delete_search_source_connector(
connector_id: int, connector_id: int,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
""" """
Delete a search source connector and all its associated documents. Delete a search source connector and all its associated documents.
@ -588,7 +595,7 @@ async def delete_search_source_connector(
# Check permission # Check permission
await check_permission( await check_permission(
session, session,
user, auth,
db_connector.search_space_id, db_connector.search_space_id,
Permission.CONNECTORS_DELETE.value, Permission.CONNECTORS_DELETE.value,
"You don't have permission to delete this connector", "You don't have permission to delete this connector",
@ -725,8 +732,9 @@ async def index_connector_content(
description="[Google Drive only] Structured request with folders and files to index", description="[Google Drive only] Structured request with folders and files to index",
), ),
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
""" """
Index content from a KB connector to a search space. Index content from a KB connector to a search space.
@ -760,7 +768,7 @@ async def index_connector_content(
# the read/update/delete handlers — not the client-supplied query param. # the read/update/delete handlers — not the client-supplied query param.
await check_permission( await check_permission(
session, session,
user, auth,
connector.search_space_id, connector.search_space_id,
Permission.CONNECTORS_UPDATE.value, Permission.CONNECTORS_UPDATE.value,
"You don't have permission to index content in this search space", "You don't have permission to index content in this search space",
@ -2645,8 +2653,9 @@ async def create_mcp_connector(
connector_data: MCPConnectorCreate, connector_data: MCPConnectorCreate,
search_space_id: int = Query(..., description="Search space ID"), search_space_id: int = Query(..., description="Search space ID"),
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
""" """
Create a new MCP (Model Context Protocol) connector. Create a new MCP (Model Context Protocol) connector.
@ -2669,7 +2678,7 @@ async def create_mcp_connector(
# Check user has permission to create connectors # Check user has permission to create connectors
await check_permission( await check_permission(
session, session,
user, auth,
search_space_id, search_space_id,
Permission.CONNECTORS_CREATE.value, Permission.CONNECTORS_CREATE.value,
"You don't have permission to create connectors in this search space", "You don't have permission to create connectors in this search space",
@ -2724,8 +2733,9 @@ async def create_mcp_connector(
async def list_mcp_connectors( async def list_mcp_connectors(
search_space_id: int = Query(..., description="Search space ID"), search_space_id: int = Query(..., description="Search space ID"),
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
""" """
List all MCP connectors for a search space. List all MCP connectors for a search space.
@ -2741,7 +2751,7 @@ async def list_mcp_connectors(
# Check user has permission to read connectors # Check user has permission to read connectors
await check_permission( await check_permission(
session, session,
user, auth,
search_space_id, search_space_id,
Permission.CONNECTORS_READ.value, Permission.CONNECTORS_READ.value,
"You don't have permission to view connectors in this search space", "You don't have permission to view connectors in this search space",
@ -2775,8 +2785,9 @@ async def list_mcp_connectors(
async def get_mcp_connector( async def get_mcp_connector(
connector_id: int, connector_id: int,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
""" """
Get a specific MCP connector by ID. Get a specific MCP connector by ID.
@ -2805,7 +2816,7 @@ async def get_mcp_connector(
# Check user has permission to read connectors # Check user has permission to read connectors
await check_permission( await check_permission(
session, session,
user, auth,
connector.search_space_id, connector.search_space_id,
Permission.CONNECTORS_READ.value, Permission.CONNECTORS_READ.value,
"You don't have permission to view this connector", "You don't have permission to view this connector",
@ -2828,8 +2839,9 @@ async def update_mcp_connector(
connector_id: int, connector_id: int,
connector_update: MCPConnectorUpdate, connector_update: MCPConnectorUpdate,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
""" """
Update an MCP connector. Update an MCP connector.
@ -2859,7 +2871,7 @@ async def update_mcp_connector(
# Check user has permission to update connectors # Check user has permission to update connectors
await check_permission( await check_permission(
session, session,
user, auth,
connector.search_space_id, connector.search_space_id,
Permission.CONNECTORS_UPDATE.value, Permission.CONNECTORS_UPDATE.value,
"You don't have permission to update this connector", "You don't have permission to update this connector",
@ -2904,8 +2916,9 @@ async def update_mcp_connector(
async def delete_mcp_connector( async def delete_mcp_connector(
connector_id: int, connector_id: int,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
""" """
Delete an MCP connector. Delete an MCP connector.
@ -2931,7 +2944,7 @@ async def delete_mcp_connector(
# Check user has permission to delete connectors # Check user has permission to delete connectors
await check_permission( await check_permission(
session, session,
user, auth,
connector.search_space_id, connector.search_space_id,
Permission.CONNECTORS_DELETE.value, Permission.CONNECTORS_DELETE.value,
"You don't have permission to delete this connector", "You don't have permission to delete this connector",
@ -2962,8 +2975,9 @@ async def delete_mcp_connector(
@router.post("/connectors/mcp/test") @router.post("/connectors/mcp/test")
async def test_mcp_server_connection( async def test_mcp_server_connection(
server_config: dict = Body(...), server_config: dict = Body(...),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
""" """
Test connection to an MCP server and fetch available tools. Test connection to an MCP server and fetch available tools.
@ -3042,8 +3056,9 @@ DRIVE_CONNECTOR_TYPES = {
async def get_drive_picker_token( async def get_drive_picker_token(
connector_id: int, connector_id: int,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
"""Return an OAuth access token + client ID for the Google Picker API.""" """Return an OAuth access token + client ID for the Google Picker API."""
result = await session.execute( result = await session.execute(
select(SearchSourceConnector).filter(SearchSourceConnector.id == connector_id) select(SearchSourceConnector).filter(SearchSourceConnector.id == connector_id)
@ -3054,7 +3069,7 @@ async def get_drive_picker_token(
await check_permission( await check_permission(
session, session,
user, auth,
connector.search_space_id, connector.search_space_id,
Permission.CONNECTORS_READ.value, Permission.CONNECTORS_READ.value,
"You don't have permission to access this connector", "You don't have permission to access this connector",
@ -3164,8 +3179,9 @@ async def trust_mcp_tool(
connector_id: int, connector_id: int,
body: MCPTrustToolRequest, body: MCPTrustToolRequest,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
"""Add a tool to the MCP connector's trusted (always-allow) list. """Add a tool to the MCP connector's trusted (always-allow) list.
Once trusted, the tool executes without HITL approval on subsequent Once trusted, the tool executes without HITL approval on subsequent
@ -3209,8 +3225,9 @@ async def untrust_mcp_tool(
connector_id: int, connector_id: int,
body: MCPTrustToolRequest, body: MCPTrustToolRequest,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
user = auth.user
"""Remove a tool from the MCP connector's trusted list. """Remove a tool from the MCP connector's trusted list.
The tool will require HITL approval again on subsequent calls. The tool will require HITL approval again on subsequent calls.

View file

@ -6,6 +6,7 @@ from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext
from app.db import User, get_async_session from app.db import User, get_async_session
from app.services.memory import ( from app.services.memory import (
MemoryRead, MemoryRead,
@ -15,7 +16,7 @@ from app.services.memory import (
reset_memory, reset_memory,
save_memory, save_memory,
) )
from app.users import current_active_user from app.users import get_auth_context
from app.utils.rbac import check_search_space_access from app.utils.rbac import check_search_space_access
router = APIRouter() router = APIRouter()
@ -29,9 +30,10 @@ class TeamMemoryUpdate(BaseModel):
async def get_team_memory( async def get_team_memory(
search_space_id: int, search_space_id: int,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
await check_search_space_access(session, user, search_space_id) user = auth.user
await check_search_space_access(session, auth, search_space_id)
memory_md = await read_memory( memory_md = await read_memory(
scope=MemoryScope.TEAM, scope=MemoryScope.TEAM,
target_id=search_space_id, target_id=search_space_id,
@ -45,9 +47,10 @@ async def update_team_memory(
search_space_id: int, search_space_id: int,
body: TeamMemoryUpdate, body: TeamMemoryUpdate,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
await check_search_space_access(session, user, search_space_id) user = auth.user
await check_search_space_access(session, auth, search_space_id)
result = await save_memory( result = await save_memory(
scope=MemoryScope.TEAM, scope=MemoryScope.TEAM,
target_id=search_space_id, target_id=search_space_id,
@ -63,9 +66,10 @@ async def update_team_memory(
async def reset_team_memory( async def reset_team_memory(
search_space_id: int, search_space_id: int,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
): ):
await check_search_space_access(session, user, search_space_id) user = auth.user
await check_search_space_access(session, auth, search_space_id)
result = await reset_memory( result = await reset_memory(
scope=MemoryScope.TEAM, scope=MemoryScope.TEAM,
target_id=search_space_id, target_id=search_space_id,