mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-07-02 22:01:05 +02:00
refactor: streamline auth context usage across chat and automation routes
This commit is contained in:
parent
8af4a3f9d5
commit
6fd3f8570e
7 changed files with 53 additions and 28 deletions
|
|
@ -24,7 +24,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.agents.chat.multi_agent_chat.main_agent.middleware.busy_mutex import (
|
||||
get_cancel_state,
|
||||
is_cancel_requested,
|
||||
|
|
@ -37,6 +36,7 @@ from app.agents.chat.multi_agent_chat.shared.filesystem_selection import (
|
|||
FilesystemSelection,
|
||||
LocalFilesystemMount,
|
||||
)
|
||||
from app.auth.context import AuthContext
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
ChatComment,
|
||||
|
|
@ -1810,6 +1810,7 @@ async def handle_new_chat(
|
|||
filesystem_selection=filesystem_selection,
|
||||
request_id=getattr(http_request.state, "request_id", "unknown"),
|
||||
user_image_data_urls=image_urls,
|
||||
auth_context=auth,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
|
|
@ -2306,6 +2307,7 @@ async def regenerate_response(
|
|||
filesystem_selection=filesystem_selection,
|
||||
request_id=getattr(http_request.state, "request_id", "unknown"),
|
||||
user_image_data_urls=regenerate_image_urls or None,
|
||||
auth_context=auth,
|
||||
flow="regenerate",
|
||||
):
|
||||
yield chunk
|
||||
|
|
@ -2432,6 +2434,7 @@ async def resume_chat(
|
|||
filesystem_selection=filesystem_selection,
|
||||
request_id=getattr(http_request.state, "request_id", "unknown"),
|
||||
disabled_tools=request.disabled_tools,
|
||||
auth_context=auth,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from sqlalchemy.dialects.postgresql import insert as pg_insert
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import (
|
||||
Document,
|
||||
DocumentType,
|
||||
|
|
@ -52,7 +53,8 @@ from app.services.obsidian_plugin_indexer import (
|
|||
upsert_note,
|
||||
)
|
||||
from app.tasks.celery_tasks.obsidian_tasks import index_obsidian_attachment_task
|
||||
from app.users import current_active_user
|
||||
from app.users import get_auth_context
|
||||
from app.utils.rbac import check_search_space_access
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -174,10 +176,11 @@ async def _finish_obsidian_sync_notification(
|
|||
async def _resolve_vault_connector(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
user: User,
|
||||
auth: AuthContext,
|
||||
vault_id: str,
|
||||
) -> SearchSourceConnector:
|
||||
"""Find the OBSIDIAN_CONNECTOR row that owns ``vault_id`` for this user."""
|
||||
user = auth.user
|
||||
# ``config`` is core ``JSON`` (not ``JSONB``); ``as_string()`` is the
|
||||
# cross-dialect equivalent of ``.astext`` and compiles to ``->>``.
|
||||
stmt = select(SearchSourceConnector).where(
|
||||
|
|
@ -192,6 +195,7 @@ async def _resolve_vault_connector(
|
|||
|
||||
connector = (await session.execute(stmt)).scalars().first()
|
||||
if connector is not None:
|
||||
await check_search_space_access(session, auth, connector.search_space_id)
|
||||
return connector
|
||||
|
||||
raise HTTPException(
|
||||
|
|
@ -221,10 +225,11 @@ def _queue_obsidian_attachment(
|
|||
async def _ensure_search_space_access(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
user: User,
|
||||
auth: AuthContext,
|
||||
search_space_id: int,
|
||||
) -> SearchSpace:
|
||||
"""Owner-only access to the search space (shared spaces are a follow-up)."""
|
||||
user = auth.user
|
||||
result = await session.execute(
|
||||
select(SearchSpace).where(
|
||||
and_(SearchSpace.id == search_space_id, SearchSpace.user_id == user.id)
|
||||
|
|
@ -239,6 +244,7 @@ async def _ensure_search_space_access(
|
|||
"message": "You don't own that search space.",
|
||||
},
|
||||
)
|
||||
await check_search_space_access(session, auth, search_space_id)
|
||||
return space
|
||||
|
||||
|
||||
|
|
@ -249,7 +255,7 @@ async def _ensure_search_space_access(
|
|||
|
||||
@router.get("/health", response_model=HealthResponse)
|
||||
async def obsidian_health(
|
||||
user: User = Depends(current_active_user),
|
||||
_auth: AuthContext = Depends(get_auth_context),
|
||||
) -> HealthResponse:
|
||||
"""Return the API contract handshake; plugin caches it per onload."""
|
||||
return HealthResponse(
|
||||
|
|
@ -306,7 +312,7 @@ def _display_name(vault_name: str) -> str:
|
|||
@router.post("/connect", response_model=ConnectResponse)
|
||||
async def obsidian_connect(
|
||||
payload: ConnectRequest,
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
) -> ConnectResponse:
|
||||
"""Register a vault, refresh an existing one, or adopt another device's row.
|
||||
|
|
@ -321,8 +327,9 @@ async def obsidian_connect(
|
|||
the partial unique index can never produce two live rows for one vault.
|
||||
"""
|
||||
await _ensure_search_space_access(
|
||||
session, user=user, search_space_id=payload.search_space_id
|
||||
session, auth=auth, search_space_id=payload.search_space_id
|
||||
)
|
||||
user = auth.user
|
||||
|
||||
now_iso = datetime.now(UTC).isoformat()
|
||||
cfg = _build_config(payload, now_iso=now_iso)
|
||||
|
|
@ -445,13 +452,14 @@ async def obsidian_connect(
|
|||
@router.post("/sync", response_model=SyncAck)
|
||||
async def obsidian_sync(
|
||||
payload: SyncBatchRequest,
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
) -> SyncAck:
|
||||
"""Batch-upsert notes; returns per-note ack so the plugin can dequeue/retry."""
|
||||
connector = await _resolve_vault_connector(
|
||||
session, user=user, vault_id=payload.vault_id
|
||||
session, auth=auth, vault_id=payload.vault_id
|
||||
)
|
||||
user = auth.user
|
||||
notification = None
|
||||
try:
|
||||
notification = await _start_obsidian_sync_notification(
|
||||
|
|
@ -551,12 +559,12 @@ async def obsidian_sync(
|
|||
@router.post("/rename", response_model=RenameAck)
|
||||
async def obsidian_rename(
|
||||
payload: RenameBatchRequest,
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
) -> RenameAck:
|
||||
"""Apply a batch of vault rename events."""
|
||||
connector = await _resolve_vault_connector(
|
||||
session, user=user, vault_id=payload.vault_id
|
||||
session, auth=auth, vault_id=payload.vault_id
|
||||
)
|
||||
|
||||
items: list[RenameAckItem] = []
|
||||
|
|
@ -618,12 +626,12 @@ async def obsidian_rename(
|
|||
@router.delete("/notes", response_model=DeleteAck)
|
||||
async def obsidian_delete_notes(
|
||||
payload: DeleteBatchRequest,
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
) -> DeleteAck:
|
||||
"""Soft-delete a batch of notes by vault-relative path."""
|
||||
connector = await _resolve_vault_connector(
|
||||
session, user=user, vault_id=payload.vault_id
|
||||
session, auth=auth, vault_id=payload.vault_id
|
||||
)
|
||||
|
||||
deleted = 0
|
||||
|
|
@ -662,18 +670,18 @@ async def obsidian_delete_notes(
|
|||
@router.get("/manifest", response_model=ManifestResponse)
|
||||
async def obsidian_manifest(
|
||||
vault_id: str = Query(..., description="Plugin-side stable vault UUID"),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
) -> ManifestResponse:
|
||||
"""Return ``{path: {hash, mtime}}`` for the plugin's onload reconcile diff."""
|
||||
connector = await _resolve_vault_connector(session, user=user, vault_id=vault_id)
|
||||
connector = await _resolve_vault_connector(session, auth=auth, vault_id=vault_id)
|
||||
return await get_manifest(session, connector=connector, vault_id=vault_id)
|
||||
|
||||
|
||||
@router.get("/stats", response_model=StatsResponse)
|
||||
async def obsidian_stats(
|
||||
vault_id: str = Query(..., description="Plugin-side stable vault UUID"),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
) -> StatsResponse:
|
||||
"""Active-note count + last sync time for the web tile.
|
||||
|
|
@ -681,7 +689,7 @@ async def obsidian_stats(
|
|||
``files_synced`` excludes tombstones so it matches ``/manifest``;
|
||||
``last_sync_at`` includes them so deletes advance the freshness signal.
|
||||
"""
|
||||
connector = await _resolve_vault_connector(session, user=user, vault_id=vault_id)
|
||||
connector = await _resolve_vault_connector(session, auth=auth, vault_id=vault_id)
|
||||
|
||||
is_active = Document.document_metadata["deleted_at"].as_string().is_(None)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue