From 6fd3f8570e05909be754510800bc4e9dfd123365 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Fri, 19 Jun 2026 21:04:21 +0530 Subject: [PATCH] refactor: streamline auth context usage across chat and automation routes --- .../main_agent/tools/automation/create.py | 22 +++++----- surfsense_backend/app/gateway/agent_invoke.py | 3 ++ .../app/routes/new_chat_routes.py | 5 ++- .../app/routes/obsidian_plugin_routes.py | 40 +++++++++++-------- .../app/tasks/chat/streaming/agent/builder.py | 3 ++ .../streaming/flows/new_chat/orchestrator.py | 4 ++ .../flows/resume_chat/orchestrator.py | 4 ++ 7 files changed, 53 insertions(+), 28 deletions(-) diff --git a/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/tools/automation/create.py b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/tools/automation/create.py index c14413cf4..fe42410ed 100644 --- a/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/tools/automation/create.py +++ b/surfsense_backend/app/agents/chat/multi_agent_chat/main_agent/tools/automation/create.py @@ -33,7 +33,7 @@ from app.agents.chat.multi_agent_chat.subagents.shared.hitl.approvals.self_gated from app.auth.context import AuthContext from app.automations.schemas.api import AutomationCreate from app.automations.services.automation import AutomationService -from app.db import User, async_session_maker +from app.db import async_session_maker from app.utils.content_utils import extract_text_content from .prompt import build_draft_prompt @@ -58,8 +58,6 @@ def create_create_automation_tool( ``AsyncSession`` is opened per call to avoid stale sessions on compiled-agent cache hits (same pattern as the Notion / memory tools). """ - uid = UUID(user_id) if isinstance(user_id, str) else user_id - @tool async def create_automation(intent: str, runtime: ToolRuntime) -> dict[str, Any]: """Draft + save an automation from a natural-language intent. @@ -167,15 +165,17 @@ def create_create_automation_tool( "issues": _format_validation_issues(exc), } + if auth_context is None: + logger.error( + "create_automation called without AuthContext; refusing to persist" + ) + return { + "status": "error", + "message": "authorization context missing for automation creation", + } + async with async_session_maker() as session: - user = await session.get(User, uid) - if user is None: - return { - "status": "error", - "message": "user not found in this session", - } - auth = auth_context or AuthContext.system(user, source="agent") - service = AutomationService(session=session, auth=auth) + service = AutomationService(session=session, auth=auth_context) created = await service.create(final_validated) return { "status": "saved", diff --git a/surfsense_backend/app/gateway/agent_invoke.py b/surfsense_backend/app/gateway/agent_invoke.py index 8701ccc55..e03ea8c8b 100644 --- a/surfsense_backend/app/gateway/agent_invoke.py +++ b/surfsense_backend/app/gateway/agent_invoke.py @@ -9,6 +9,7 @@ from collections.abc import AsyncIterator from sqlalchemy import update from sqlalchemy.ext.asyncio import AsyncSession +from app.auth.context import AuthContext from app.db import ExternalChatBinding, NewChatMessage from app.gateway.auth_invariant import assert_authorization_invariant from app.gateway.base.translator import BaseStreamTranslator, GatewayStreamEvent @@ -64,6 +65,7 @@ async def call_agent_for_gateway( request_id: str | None = None, ) -> None: user = await assert_authorization_invariant(session, binding) + auth_context = AuthContext.system(user, source="gateway") thread = await get_or_create_thread_for_binding(session, binding) await session.commit() @@ -81,6 +83,7 @@ async def call_agent_for_gateway( current_user_display_name=user.display_name or "A team member", disabled_tools=sorted(DEFAULT_HITL_TOOL_NAMES), request_id=request_id or "gateway", + auth_context=auth_context, ) events = _events_from_sse(stream) try: diff --git a/surfsense_backend/app/routes/new_chat_routes.py b/surfsense_backend/app/routes/new_chat_routes.py index d76211dfc..1ca598fe3 100644 --- a/surfsense_backend/app/routes/new_chat_routes.py +++ b/surfsense_backend/app/routes/new_chat_routes.py @@ -24,7 +24,6 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select from sqlalchemy.orm import selectinload -from app.auth.context import AuthContext from app.agents.chat.multi_agent_chat.main_agent.middleware.busy_mutex import ( get_cancel_state, is_cancel_requested, @@ -37,6 +36,7 @@ from app.agents.chat.multi_agent_chat.shared.filesystem_selection import ( FilesystemSelection, LocalFilesystemMount, ) +from app.auth.context import AuthContext from app.config import config from app.db import ( ChatComment, @@ -1810,6 +1810,7 @@ async def handle_new_chat( filesystem_selection=filesystem_selection, request_id=getattr(http_request.state, "request_id", "unknown"), user_image_data_urls=image_urls, + auth_context=auth, ), media_type="text/event-stream", headers={ @@ -2306,6 +2307,7 @@ async def regenerate_response( filesystem_selection=filesystem_selection, request_id=getattr(http_request.state, "request_id", "unknown"), user_image_data_urls=regenerate_image_urls or None, + auth_context=auth, flow="regenerate", ): yield chunk @@ -2432,6 +2434,7 @@ async def resume_chat( filesystem_selection=filesystem_selection, request_id=getattr(http_request.state, "request_id", "unknown"), disabled_tools=request.disabled_tools, + auth_context=auth, ), media_type="text/event-stream", headers={ diff --git a/surfsense_backend/app/routes/obsidian_plugin_routes.py b/surfsense_backend/app/routes/obsidian_plugin_routes.py index bd54a4788..42ac110d3 100644 --- a/surfsense_backend/app/routes/obsidian_plugin_routes.py +++ b/surfsense_backend/app/routes/obsidian_plugin_routes.py @@ -16,6 +16,7 @@ from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select +from app.auth.context import AuthContext from app.db import ( Document, DocumentType, @@ -52,7 +53,8 @@ from app.services.obsidian_plugin_indexer import ( upsert_note, ) from app.tasks.celery_tasks.obsidian_tasks import index_obsidian_attachment_task -from app.users import current_active_user +from app.users import get_auth_context +from app.utils.rbac import check_search_space_access logger = logging.getLogger(__name__) @@ -174,10 +176,11 @@ async def _finish_obsidian_sync_notification( async def _resolve_vault_connector( session: AsyncSession, *, - user: User, + auth: AuthContext, vault_id: str, ) -> SearchSourceConnector: """Find the OBSIDIAN_CONNECTOR row that owns ``vault_id`` for this user.""" + user = auth.user # ``config`` is core ``JSON`` (not ``JSONB``); ``as_string()`` is the # cross-dialect equivalent of ``.astext`` and compiles to ``->>``. stmt = select(SearchSourceConnector).where( @@ -192,6 +195,7 @@ async def _resolve_vault_connector( connector = (await session.execute(stmt)).scalars().first() if connector is not None: + await check_search_space_access(session, auth, connector.search_space_id) return connector raise HTTPException( @@ -221,10 +225,11 @@ def _queue_obsidian_attachment( async def _ensure_search_space_access( session: AsyncSession, *, - user: User, + auth: AuthContext, search_space_id: int, ) -> SearchSpace: """Owner-only access to the search space (shared spaces are a follow-up).""" + user = auth.user result = await session.execute( select(SearchSpace).where( and_(SearchSpace.id == search_space_id, SearchSpace.user_id == user.id) @@ -239,6 +244,7 @@ async def _ensure_search_space_access( "message": "You don't own that search space.", }, ) + await check_search_space_access(session, auth, search_space_id) return space @@ -249,7 +255,7 @@ async def _ensure_search_space_access( @router.get("/health", response_model=HealthResponse) async def obsidian_health( - user: User = Depends(current_active_user), + _auth: AuthContext = Depends(get_auth_context), ) -> HealthResponse: """Return the API contract handshake; plugin caches it per onload.""" return HealthResponse( @@ -306,7 +312,7 @@ def _display_name(vault_name: str) -> str: @router.post("/connect", response_model=ConnectResponse) async def obsidian_connect( payload: ConnectRequest, - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), session: AsyncSession = Depends(get_async_session), ) -> ConnectResponse: """Register a vault, refresh an existing one, or adopt another device's row. @@ -321,8 +327,9 @@ async def obsidian_connect( the partial unique index can never produce two live rows for one vault. """ await _ensure_search_space_access( - session, user=user, search_space_id=payload.search_space_id + session, auth=auth, search_space_id=payload.search_space_id ) + user = auth.user now_iso = datetime.now(UTC).isoformat() cfg = _build_config(payload, now_iso=now_iso) @@ -445,13 +452,14 @@ async def obsidian_connect( @router.post("/sync", response_model=SyncAck) async def obsidian_sync( payload: SyncBatchRequest, - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), session: AsyncSession = Depends(get_async_session), ) -> SyncAck: """Batch-upsert notes; returns per-note ack so the plugin can dequeue/retry.""" connector = await _resolve_vault_connector( - session, user=user, vault_id=payload.vault_id + session, auth=auth, vault_id=payload.vault_id ) + user = auth.user notification = None try: notification = await _start_obsidian_sync_notification( @@ -551,12 +559,12 @@ async def obsidian_sync( @router.post("/rename", response_model=RenameAck) async def obsidian_rename( payload: RenameBatchRequest, - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), session: AsyncSession = Depends(get_async_session), ) -> RenameAck: """Apply a batch of vault rename events.""" connector = await _resolve_vault_connector( - session, user=user, vault_id=payload.vault_id + session, auth=auth, vault_id=payload.vault_id ) items: list[RenameAckItem] = [] @@ -618,12 +626,12 @@ async def obsidian_rename( @router.delete("/notes", response_model=DeleteAck) async def obsidian_delete_notes( payload: DeleteBatchRequest, - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), session: AsyncSession = Depends(get_async_session), ) -> DeleteAck: """Soft-delete a batch of notes by vault-relative path.""" connector = await _resolve_vault_connector( - session, user=user, vault_id=payload.vault_id + session, auth=auth, vault_id=payload.vault_id ) deleted = 0 @@ -662,18 +670,18 @@ async def obsidian_delete_notes( @router.get("/manifest", response_model=ManifestResponse) async def obsidian_manifest( vault_id: str = Query(..., description="Plugin-side stable vault UUID"), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), session: AsyncSession = Depends(get_async_session), ) -> ManifestResponse: """Return ``{path: {hash, mtime}}`` for the plugin's onload reconcile diff.""" - connector = await _resolve_vault_connector(session, user=user, vault_id=vault_id) + connector = await _resolve_vault_connector(session, auth=auth, vault_id=vault_id) return await get_manifest(session, connector=connector, vault_id=vault_id) @router.get("/stats", response_model=StatsResponse) async def obsidian_stats( vault_id: str = Query(..., description="Plugin-side stable vault UUID"), - user: User = Depends(current_active_user), + auth: AuthContext = Depends(get_auth_context), session: AsyncSession = Depends(get_async_session), ) -> StatsResponse: """Active-note count + last sync time for the web tile. @@ -681,7 +689,7 @@ async def obsidian_stats( ``files_synced`` excludes tombstones so it matches ``/manifest``; ``last_sync_at`` includes them so deletes advance the freshness signal. """ - connector = await _resolve_vault_connector(session, user=user, vault_id=vault_id) + connector = await _resolve_vault_connector(session, auth=auth, vault_id=vault_id) is_active = Document.document_metadata["deleted_at"].as_string().is_(None) diff --git a/surfsense_backend/app/tasks/chat/streaming/agent/builder.py b/surfsense_backend/app/tasks/chat/streaming/agent/builder.py index dcbd37521..9d7d1b0c5 100644 --- a/surfsense_backend/app/tasks/chat/streaming/agent/builder.py +++ b/surfsense_backend/app/tasks/chat/streaming/agent/builder.py @@ -13,6 +13,7 @@ from app.agents.chat.multi_agent_chat.shared.filesystem_selection import ( FilesystemSelection, ) from app.agents.chat.runtime.llm_config import AgentConfig +from app.auth.context import AuthContext from app.db import ChatVisibility from app.services.connector_service import ConnectorService @@ -33,6 +34,7 @@ async def build_main_agent_for_thread( filesystem_selection: FilesystemSelection | None, disabled_tools: list[str] | None = None, mentioned_document_ids: list[int] | None = None, + auth_context: AuthContext | None = None, ) -> Any: return await agent_factory( llm=llm, @@ -48,4 +50,5 @@ async def build_main_agent_for_thread( filesystem_selection=filesystem_selection, disabled_tools=disabled_tools, mentioned_document_ids=mentioned_document_ids, + auth_context=auth_context, ) diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/orchestrator.py b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/orchestrator.py index 1e6097e53..69343ffa4 100644 --- a/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/orchestrator.py +++ b/surfsense_backend/app/tasks/chat/streaming/flows/new_chat/orchestrator.py @@ -35,6 +35,7 @@ from app.agents.chat.multi_agent_chat.shared.filesystem_selection import ( FilesystemMode, FilesystemSelection, ) +from app.auth.context import AuthContext from app.db import ChatVisibility, async_session_maker from app.observability import otel as ot from app.services.new_streaming_service import VercelStreamingService @@ -136,6 +137,7 @@ async def stream_new_chat( filesystem_selection: FilesystemSelection | None = None, request_id: str | None = None, user_image_data_urls: list[str] | None = None, + auth_context: AuthContext | None = None, flow: Literal["new", "regenerate"] = "new", ) -> AsyncGenerator[str, None]: """Stream a new chat turn using the SurfSense deep agent. @@ -412,6 +414,7 @@ async def stream_new_chat( filesystem_selection=filesystem_selection, disabled_tools=disabled_tools, mentioned_document_ids=mentioned_document_ids, + auth_context=auth_context, ) _perf_log.info( "[stream_new_chat] Agent created in %.3fs", time.perf_counter() - _t0 @@ -664,6 +667,7 @@ async def stream_new_chat( filesystem_selection=filesystem_selection, disabled_tools=disabled_tools, mentioned_document_ids=mentioned_document_ids, + auth_context=auth_context, ) _perf_log.info( "[stream_new_chat] Runtime rate-limit recovery repinned " diff --git a/surfsense_backend/app/tasks/chat/streaming/flows/resume_chat/orchestrator.py b/surfsense_backend/app/tasks/chat/streaming/flows/resume_chat/orchestrator.py index e1552e79e..33fcee3da 100644 --- a/surfsense_backend/app/tasks/chat/streaming/flows/resume_chat/orchestrator.py +++ b/surfsense_backend/app/tasks/chat/streaming/flows/resume_chat/orchestrator.py @@ -29,6 +29,7 @@ from app.agents.chat.multi_agent_chat.shared.filesystem_selection import ( FilesystemMode, FilesystemSelection, ) +from app.auth.context import AuthContext from app.db import ChatVisibility, async_session_maker from app.observability import otel as ot from app.services.chat_session_state_service import set_ai_responding @@ -102,6 +103,7 @@ async def stream_resume_chat( filesystem_selection: FilesystemSelection | None = None, request_id: str | None = None, disabled_tools: list[str] | None = None, + auth_context: AuthContext | None = None, ) -> AsyncGenerator[str, None]: """Resume a paused HITL turn with the user's decisions. @@ -346,6 +348,7 @@ async def stream_resume_chat( thread_visibility=visibility, filesystem_selection=filesystem_selection, disabled_tools=disabled_tools, + auth_context=auth_context, ) _perf_log.info( "[stream_resume] Agent created in %.3fs", time.perf_counter() - _t0 @@ -481,6 +484,7 @@ async def stream_resume_chat( thread_visibility=visibility, filesystem_selection=filesystem_selection, disabled_tools=disabled_tools, + auth_context=auth_context, ) _perf_log.info( "[stream_resume] Runtime rate-limit recovery repinned "