refactor: streamline auth context usage across chat and automation routes

This commit is contained in:
Anish Sarkar 2026-06-19 21:04:21 +05:30
parent 8af4a3f9d5
commit 6fd3f8570e
7 changed files with 53 additions and 28 deletions

View file

@ -24,7 +24,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from app.auth.context import AuthContext
from app.agents.chat.multi_agent_chat.main_agent.middleware.busy_mutex import (
get_cancel_state,
is_cancel_requested,
@ -37,6 +36,7 @@ from app.agents.chat.multi_agent_chat.shared.filesystem_selection import (
FilesystemSelection,
LocalFilesystemMount,
)
from app.auth.context import AuthContext
from app.config import config
from app.db import (
ChatComment,
@ -1810,6 +1810,7 @@ async def handle_new_chat(
filesystem_selection=filesystem_selection,
request_id=getattr(http_request.state, "request_id", "unknown"),
user_image_data_urls=image_urls,
auth_context=auth,
),
media_type="text/event-stream",
headers={
@ -2306,6 +2307,7 @@ async def regenerate_response(
filesystem_selection=filesystem_selection,
request_id=getattr(http_request.state, "request_id", "unknown"),
user_image_data_urls=regenerate_image_urls or None,
auth_context=auth,
flow="regenerate",
):
yield chunk
@ -2432,6 +2434,7 @@ async def resume_chat(
filesystem_selection=filesystem_selection,
request_id=getattr(http_request.state, "request_id", "unknown"),
disabled_tools=request.disabled_tools,
auth_context=auth,
),
media_type="text/event-stream",
headers={

View file

@ -16,6 +16,7 @@ from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.auth.context import AuthContext
from app.db import (
Document,
DocumentType,
@ -52,7 +53,8 @@ from app.services.obsidian_plugin_indexer import (
upsert_note,
)
from app.tasks.celery_tasks.obsidian_tasks import index_obsidian_attachment_task
from app.users import current_active_user
from app.users import get_auth_context
from app.utils.rbac import check_search_space_access
logger = logging.getLogger(__name__)
@ -174,10 +176,11 @@ async def _finish_obsidian_sync_notification(
async def _resolve_vault_connector(
session: AsyncSession,
*,
user: User,
auth: AuthContext,
vault_id: str,
) -> SearchSourceConnector:
"""Find the OBSIDIAN_CONNECTOR row that owns ``vault_id`` for this user."""
user = auth.user
# ``config`` is core ``JSON`` (not ``JSONB``); ``as_string()`` is the
# cross-dialect equivalent of ``.astext`` and compiles to ``->>``.
stmt = select(SearchSourceConnector).where(
@ -192,6 +195,7 @@ async def _resolve_vault_connector(
connector = (await session.execute(stmt)).scalars().first()
if connector is not None:
await check_search_space_access(session, auth, connector.search_space_id)
return connector
raise HTTPException(
@ -221,10 +225,11 @@ def _queue_obsidian_attachment(
async def _ensure_search_space_access(
session: AsyncSession,
*,
user: User,
auth: AuthContext,
search_space_id: int,
) -> SearchSpace:
"""Owner-only access to the search space (shared spaces are a follow-up)."""
user = auth.user
result = await session.execute(
select(SearchSpace).where(
and_(SearchSpace.id == search_space_id, SearchSpace.user_id == user.id)
@ -239,6 +244,7 @@ async def _ensure_search_space_access(
"message": "You don't own that search space.",
},
)
await check_search_space_access(session, auth, search_space_id)
return space
@ -249,7 +255,7 @@ async def _ensure_search_space_access(
@router.get("/health", response_model=HealthResponse)
async def obsidian_health(
user: User = Depends(current_active_user),
_auth: AuthContext = Depends(get_auth_context),
) -> HealthResponse:
"""Return the API contract handshake; plugin caches it per onload."""
return HealthResponse(
@ -306,7 +312,7 @@ def _display_name(vault_name: str) -> str:
@router.post("/connect", response_model=ConnectResponse)
async def obsidian_connect(
payload: ConnectRequest,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
session: AsyncSession = Depends(get_async_session),
) -> ConnectResponse:
"""Register a vault, refresh an existing one, or adopt another device's row.
@ -321,8 +327,9 @@ async def obsidian_connect(
the partial unique index can never produce two live rows for one vault.
"""
await _ensure_search_space_access(
session, user=user, search_space_id=payload.search_space_id
session, auth=auth, search_space_id=payload.search_space_id
)
user = auth.user
now_iso = datetime.now(UTC).isoformat()
cfg = _build_config(payload, now_iso=now_iso)
@ -445,13 +452,14 @@ async def obsidian_connect(
@router.post("/sync", response_model=SyncAck)
async def obsidian_sync(
payload: SyncBatchRequest,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
session: AsyncSession = Depends(get_async_session),
) -> SyncAck:
"""Batch-upsert notes; returns per-note ack so the plugin can dequeue/retry."""
connector = await _resolve_vault_connector(
session, user=user, vault_id=payload.vault_id
session, auth=auth, vault_id=payload.vault_id
)
user = auth.user
notification = None
try:
notification = await _start_obsidian_sync_notification(
@ -551,12 +559,12 @@ async def obsidian_sync(
@router.post("/rename", response_model=RenameAck)
async def obsidian_rename(
payload: RenameBatchRequest,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
session: AsyncSession = Depends(get_async_session),
) -> RenameAck:
"""Apply a batch of vault rename events."""
connector = await _resolve_vault_connector(
session, user=user, vault_id=payload.vault_id
session, auth=auth, vault_id=payload.vault_id
)
items: list[RenameAckItem] = []
@ -618,12 +626,12 @@ async def obsidian_rename(
@router.delete("/notes", response_model=DeleteAck)
async def obsidian_delete_notes(
payload: DeleteBatchRequest,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
session: AsyncSession = Depends(get_async_session),
) -> DeleteAck:
"""Soft-delete a batch of notes by vault-relative path."""
connector = await _resolve_vault_connector(
session, user=user, vault_id=payload.vault_id
session, auth=auth, vault_id=payload.vault_id
)
deleted = 0
@ -662,18 +670,18 @@ async def obsidian_delete_notes(
@router.get("/manifest", response_model=ManifestResponse)
async def obsidian_manifest(
vault_id: str = Query(..., description="Plugin-side stable vault UUID"),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
session: AsyncSession = Depends(get_async_session),
) -> ManifestResponse:
"""Return ``{path: {hash, mtime}}`` for the plugin's onload reconcile diff."""
connector = await _resolve_vault_connector(session, user=user, vault_id=vault_id)
connector = await _resolve_vault_connector(session, auth=auth, vault_id=vault_id)
return await get_manifest(session, connector=connector, vault_id=vault_id)
@router.get("/stats", response_model=StatsResponse)
async def obsidian_stats(
vault_id: str = Query(..., description="Plugin-side stable vault UUID"),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
session: AsyncSession = Depends(get_async_session),
) -> StatsResponse:
"""Active-note count + last sync time for the web tile.
@ -681,7 +689,7 @@ async def obsidian_stats(
``files_synced`` excludes tombstones so it matches ``/manifest``;
``last_sync_at`` includes them so deletes advance the freshness signal.
"""
connector = await _resolve_vault_connector(session, user=user, vault_id=vault_id)
connector = await _resolve_vault_connector(session, auth=auth, vault_id=vault_id)
is_active = Document.document_metadata["deleted_at"].as_string().is_(None)