refactor: streamline auth context usage across chat and automation routes

This commit is contained in:
Anish Sarkar 2026-06-19 21:04:21 +05:30
parent 8af4a3f9d5
commit 6fd3f8570e
7 changed files with 53 additions and 28 deletions

View file

@ -33,7 +33,7 @@ from app.agents.chat.multi_agent_chat.subagents.shared.hitl.approvals.self_gated
from app.auth.context import AuthContext from app.auth.context import AuthContext
from app.automations.schemas.api import AutomationCreate from app.automations.schemas.api import AutomationCreate
from app.automations.services.automation import AutomationService from app.automations.services.automation import AutomationService
from app.db import User, async_session_maker from app.db import async_session_maker
from app.utils.content_utils import extract_text_content from app.utils.content_utils import extract_text_content
from .prompt import build_draft_prompt from .prompt import build_draft_prompt
@ -58,8 +58,6 @@ def create_create_automation_tool(
``AsyncSession`` is opened per call to avoid stale sessions on ``AsyncSession`` is opened per call to avoid stale sessions on
compiled-agent cache hits (same pattern as the Notion / memory tools). compiled-agent cache hits (same pattern as the Notion / memory tools).
""" """
uid = UUID(user_id) if isinstance(user_id, str) else user_id
@tool @tool
async def create_automation(intent: str, runtime: ToolRuntime) -> dict[str, Any]: async def create_automation(intent: str, runtime: ToolRuntime) -> dict[str, Any]:
"""Draft + save an automation from a natural-language intent. """Draft + save an automation from a natural-language intent.
@ -167,15 +165,17 @@ def create_create_automation_tool(
"issues": _format_validation_issues(exc), "issues": _format_validation_issues(exc),
} }
if auth_context is None:
logger.error(
"create_automation called without AuthContext; refusing to persist"
)
return {
"status": "error",
"message": "authorization context missing for automation creation",
}
async with async_session_maker() as session: async with async_session_maker() as session:
user = await session.get(User, uid) service = AutomationService(session=session, auth=auth_context)
if user is None:
return {
"status": "error",
"message": "user not found in this session",
}
auth = auth_context or AuthContext.system(user, source="agent")
service = AutomationService(session=session, auth=auth)
created = await service.create(final_validated) created = await service.create(final_validated)
return { return {
"status": "saved", "status": "saved",

View file

@ -9,6 +9,7 @@ from collections.abc import AsyncIterator
from sqlalchemy import update from sqlalchemy import update
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext
from app.db import ExternalChatBinding, NewChatMessage from app.db import ExternalChatBinding, NewChatMessage
from app.gateway.auth_invariant import assert_authorization_invariant from app.gateway.auth_invariant import assert_authorization_invariant
from app.gateway.base.translator import BaseStreamTranslator, GatewayStreamEvent from app.gateway.base.translator import BaseStreamTranslator, GatewayStreamEvent
@ -64,6 +65,7 @@ async def call_agent_for_gateway(
request_id: str | None = None, request_id: str | None = None,
) -> None: ) -> None:
user = await assert_authorization_invariant(session, binding) user = await assert_authorization_invariant(session, binding)
auth_context = AuthContext.system(user, source="gateway")
thread = await get_or_create_thread_for_binding(session, binding) thread = await get_or_create_thread_for_binding(session, binding)
await session.commit() await session.commit()
@ -81,6 +83,7 @@ async def call_agent_for_gateway(
current_user_display_name=user.display_name or "A team member", current_user_display_name=user.display_name or "A team member",
disabled_tools=sorted(DEFAULT_HITL_TOOL_NAMES), disabled_tools=sorted(DEFAULT_HITL_TOOL_NAMES),
request_id=request_id or "gateway", request_id=request_id or "gateway",
auth_context=auth_context,
) )
events = _events_from_sse(stream) events = _events_from_sse(stream)
try: try:

View file

@ -24,7 +24,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select from sqlalchemy.future import select
from sqlalchemy.orm import selectinload from sqlalchemy.orm import selectinload
from app.auth.context import AuthContext
from app.agents.chat.multi_agent_chat.main_agent.middleware.busy_mutex import ( from app.agents.chat.multi_agent_chat.main_agent.middleware.busy_mutex import (
get_cancel_state, get_cancel_state,
is_cancel_requested, is_cancel_requested,
@ -37,6 +36,7 @@ from app.agents.chat.multi_agent_chat.shared.filesystem_selection import (
FilesystemSelection, FilesystemSelection,
LocalFilesystemMount, LocalFilesystemMount,
) )
from app.auth.context import AuthContext
from app.config import config from app.config import config
from app.db import ( from app.db import (
ChatComment, ChatComment,
@ -1810,6 +1810,7 @@ async def handle_new_chat(
filesystem_selection=filesystem_selection, filesystem_selection=filesystem_selection,
request_id=getattr(http_request.state, "request_id", "unknown"), request_id=getattr(http_request.state, "request_id", "unknown"),
user_image_data_urls=image_urls, user_image_data_urls=image_urls,
auth_context=auth,
), ),
media_type="text/event-stream", media_type="text/event-stream",
headers={ headers={
@ -2306,6 +2307,7 @@ async def regenerate_response(
filesystem_selection=filesystem_selection, filesystem_selection=filesystem_selection,
request_id=getattr(http_request.state, "request_id", "unknown"), request_id=getattr(http_request.state, "request_id", "unknown"),
user_image_data_urls=regenerate_image_urls or None, user_image_data_urls=regenerate_image_urls or None,
auth_context=auth,
flow="regenerate", flow="regenerate",
): ):
yield chunk yield chunk
@ -2432,6 +2434,7 @@ async def resume_chat(
filesystem_selection=filesystem_selection, filesystem_selection=filesystem_selection,
request_id=getattr(http_request.state, "request_id", "unknown"), request_id=getattr(http_request.state, "request_id", "unknown"),
disabled_tools=request.disabled_tools, disabled_tools=request.disabled_tools,
auth_context=auth,
), ),
media_type="text/event-stream", media_type="text/event-stream",
headers={ headers={

View file

@ -16,6 +16,7 @@ from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select from sqlalchemy.future import select
from app.auth.context import AuthContext
from app.db import ( from app.db import (
Document, Document,
DocumentType, DocumentType,
@ -52,7 +53,8 @@ from app.services.obsidian_plugin_indexer import (
upsert_note, upsert_note,
) )
from app.tasks.celery_tasks.obsidian_tasks import index_obsidian_attachment_task from app.tasks.celery_tasks.obsidian_tasks import index_obsidian_attachment_task
from app.users import current_active_user from app.users import get_auth_context
from app.utils.rbac import check_search_space_access
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -174,10 +176,11 @@ async def _finish_obsidian_sync_notification(
async def _resolve_vault_connector( async def _resolve_vault_connector(
session: AsyncSession, session: AsyncSession,
*, *,
user: User, auth: AuthContext,
vault_id: str, vault_id: str,
) -> SearchSourceConnector: ) -> SearchSourceConnector:
"""Find the OBSIDIAN_CONNECTOR row that owns ``vault_id`` for this user.""" """Find the OBSIDIAN_CONNECTOR row that owns ``vault_id`` for this user."""
user = auth.user
# ``config`` is core ``JSON`` (not ``JSONB``); ``as_string()`` is the # ``config`` is core ``JSON`` (not ``JSONB``); ``as_string()`` is the
# cross-dialect equivalent of ``.astext`` and compiles to ``->>``. # cross-dialect equivalent of ``.astext`` and compiles to ``->>``.
stmt = select(SearchSourceConnector).where( stmt = select(SearchSourceConnector).where(
@ -192,6 +195,7 @@ async def _resolve_vault_connector(
connector = (await session.execute(stmt)).scalars().first() connector = (await session.execute(stmt)).scalars().first()
if connector is not None: if connector is not None:
await check_search_space_access(session, auth, connector.search_space_id)
return connector return connector
raise HTTPException( raise HTTPException(
@ -221,10 +225,11 @@ def _queue_obsidian_attachment(
async def _ensure_search_space_access( async def _ensure_search_space_access(
session: AsyncSession, session: AsyncSession,
*, *,
user: User, auth: AuthContext,
search_space_id: int, search_space_id: int,
) -> SearchSpace: ) -> SearchSpace:
"""Owner-only access to the search space (shared spaces are a follow-up).""" """Owner-only access to the search space (shared spaces are a follow-up)."""
user = auth.user
result = await session.execute( result = await session.execute(
select(SearchSpace).where( select(SearchSpace).where(
and_(SearchSpace.id == search_space_id, SearchSpace.user_id == user.id) and_(SearchSpace.id == search_space_id, SearchSpace.user_id == user.id)
@ -239,6 +244,7 @@ async def _ensure_search_space_access(
"message": "You don't own that search space.", "message": "You don't own that search space.",
}, },
) )
await check_search_space_access(session, auth, search_space_id)
return space return space
@ -249,7 +255,7 @@ async def _ensure_search_space_access(
@router.get("/health", response_model=HealthResponse) @router.get("/health", response_model=HealthResponse)
async def obsidian_health( async def obsidian_health(
user: User = Depends(current_active_user), _auth: AuthContext = Depends(get_auth_context),
) -> HealthResponse: ) -> HealthResponse:
"""Return the API contract handshake; plugin caches it per onload.""" """Return the API contract handshake; plugin caches it per onload."""
return HealthResponse( return HealthResponse(
@ -306,7 +312,7 @@ def _display_name(vault_name: str) -> str:
@router.post("/connect", response_model=ConnectResponse) @router.post("/connect", response_model=ConnectResponse)
async def obsidian_connect( async def obsidian_connect(
payload: ConnectRequest, payload: ConnectRequest,
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
) -> ConnectResponse: ) -> ConnectResponse:
"""Register a vault, refresh an existing one, or adopt another device's row. """Register a vault, refresh an existing one, or adopt another device's row.
@ -321,8 +327,9 @@ async def obsidian_connect(
the partial unique index can never produce two live rows for one vault. the partial unique index can never produce two live rows for one vault.
""" """
await _ensure_search_space_access( await _ensure_search_space_access(
session, user=user, search_space_id=payload.search_space_id session, auth=auth, search_space_id=payload.search_space_id
) )
user = auth.user
now_iso = datetime.now(UTC).isoformat() now_iso = datetime.now(UTC).isoformat()
cfg = _build_config(payload, now_iso=now_iso) cfg = _build_config(payload, now_iso=now_iso)
@ -445,13 +452,14 @@ async def obsidian_connect(
@router.post("/sync", response_model=SyncAck) @router.post("/sync", response_model=SyncAck)
async def obsidian_sync( async def obsidian_sync(
payload: SyncBatchRequest, payload: SyncBatchRequest,
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
) -> SyncAck: ) -> SyncAck:
"""Batch-upsert notes; returns per-note ack so the plugin can dequeue/retry.""" """Batch-upsert notes; returns per-note ack so the plugin can dequeue/retry."""
connector = await _resolve_vault_connector( connector = await _resolve_vault_connector(
session, user=user, vault_id=payload.vault_id session, auth=auth, vault_id=payload.vault_id
) )
user = auth.user
notification = None notification = None
try: try:
notification = await _start_obsidian_sync_notification( notification = await _start_obsidian_sync_notification(
@ -551,12 +559,12 @@ async def obsidian_sync(
@router.post("/rename", response_model=RenameAck) @router.post("/rename", response_model=RenameAck)
async def obsidian_rename( async def obsidian_rename(
payload: RenameBatchRequest, payload: RenameBatchRequest,
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
) -> RenameAck: ) -> RenameAck:
"""Apply a batch of vault rename events.""" """Apply a batch of vault rename events."""
connector = await _resolve_vault_connector( connector = await _resolve_vault_connector(
session, user=user, vault_id=payload.vault_id session, auth=auth, vault_id=payload.vault_id
) )
items: list[RenameAckItem] = [] items: list[RenameAckItem] = []
@ -618,12 +626,12 @@ async def obsidian_rename(
@router.delete("/notes", response_model=DeleteAck) @router.delete("/notes", response_model=DeleteAck)
async def obsidian_delete_notes( async def obsidian_delete_notes(
payload: DeleteBatchRequest, payload: DeleteBatchRequest,
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
) -> DeleteAck: ) -> DeleteAck:
"""Soft-delete a batch of notes by vault-relative path.""" """Soft-delete a batch of notes by vault-relative path."""
connector = await _resolve_vault_connector( connector = await _resolve_vault_connector(
session, user=user, vault_id=payload.vault_id session, auth=auth, vault_id=payload.vault_id
) )
deleted = 0 deleted = 0
@ -662,18 +670,18 @@ async def obsidian_delete_notes(
@router.get("/manifest", response_model=ManifestResponse) @router.get("/manifest", response_model=ManifestResponse)
async def obsidian_manifest( async def obsidian_manifest(
vault_id: str = Query(..., description="Plugin-side stable vault UUID"), vault_id: str = Query(..., description="Plugin-side stable vault UUID"),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
) -> ManifestResponse: ) -> ManifestResponse:
"""Return ``{path: {hash, mtime}}`` for the plugin's onload reconcile diff.""" """Return ``{path: {hash, mtime}}`` for the plugin's onload reconcile diff."""
connector = await _resolve_vault_connector(session, user=user, vault_id=vault_id) connector = await _resolve_vault_connector(session, auth=auth, vault_id=vault_id)
return await get_manifest(session, connector=connector, vault_id=vault_id) return await get_manifest(session, connector=connector, vault_id=vault_id)
@router.get("/stats", response_model=StatsResponse) @router.get("/stats", response_model=StatsResponse)
async def obsidian_stats( async def obsidian_stats(
vault_id: str = Query(..., description="Plugin-side stable vault UUID"), vault_id: str = Query(..., description="Plugin-side stable vault UUID"),
user: User = Depends(current_active_user), auth: AuthContext = Depends(get_auth_context),
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
) -> StatsResponse: ) -> StatsResponse:
"""Active-note count + last sync time for the web tile. """Active-note count + last sync time for the web tile.
@ -681,7 +689,7 @@ async def obsidian_stats(
``files_synced`` excludes tombstones so it matches ``/manifest``; ``files_synced`` excludes tombstones so it matches ``/manifest``;
``last_sync_at`` includes them so deletes advance the freshness signal. ``last_sync_at`` includes them so deletes advance the freshness signal.
""" """
connector = await _resolve_vault_connector(session, user=user, vault_id=vault_id) connector = await _resolve_vault_connector(session, auth=auth, vault_id=vault_id)
is_active = Document.document_metadata["deleted_at"].as_string().is_(None) is_active = Document.document_metadata["deleted_at"].as_string().is_(None)

View file

@ -13,6 +13,7 @@ from app.agents.chat.multi_agent_chat.shared.filesystem_selection import (
FilesystemSelection, FilesystemSelection,
) )
from app.agents.chat.runtime.llm_config import AgentConfig from app.agents.chat.runtime.llm_config import AgentConfig
from app.auth.context import AuthContext
from app.db import ChatVisibility from app.db import ChatVisibility
from app.services.connector_service import ConnectorService from app.services.connector_service import ConnectorService
@ -33,6 +34,7 @@ async def build_main_agent_for_thread(
filesystem_selection: FilesystemSelection | None, filesystem_selection: FilesystemSelection | None,
disabled_tools: list[str] | None = None, disabled_tools: list[str] | None = None,
mentioned_document_ids: list[int] | None = None, mentioned_document_ids: list[int] | None = None,
auth_context: AuthContext | None = None,
) -> Any: ) -> Any:
return await agent_factory( return await agent_factory(
llm=llm, llm=llm,
@ -48,4 +50,5 @@ async def build_main_agent_for_thread(
filesystem_selection=filesystem_selection, filesystem_selection=filesystem_selection,
disabled_tools=disabled_tools, disabled_tools=disabled_tools,
mentioned_document_ids=mentioned_document_ids, mentioned_document_ids=mentioned_document_ids,
auth_context=auth_context,
) )

View file

@ -35,6 +35,7 @@ from app.agents.chat.multi_agent_chat.shared.filesystem_selection import (
FilesystemMode, FilesystemMode,
FilesystemSelection, FilesystemSelection,
) )
from app.auth.context import AuthContext
from app.db import ChatVisibility, async_session_maker from app.db import ChatVisibility, async_session_maker
from app.observability import otel as ot from app.observability import otel as ot
from app.services.new_streaming_service import VercelStreamingService from app.services.new_streaming_service import VercelStreamingService
@ -136,6 +137,7 @@ async def stream_new_chat(
filesystem_selection: FilesystemSelection | None = None, filesystem_selection: FilesystemSelection | None = None,
request_id: str | None = None, request_id: str | None = None,
user_image_data_urls: list[str] | None = None, user_image_data_urls: list[str] | None = None,
auth_context: AuthContext | None = None,
flow: Literal["new", "regenerate"] = "new", flow: Literal["new", "regenerate"] = "new",
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
"""Stream a new chat turn using the SurfSense deep agent. """Stream a new chat turn using the SurfSense deep agent.
@ -412,6 +414,7 @@ async def stream_new_chat(
filesystem_selection=filesystem_selection, filesystem_selection=filesystem_selection,
disabled_tools=disabled_tools, disabled_tools=disabled_tools,
mentioned_document_ids=mentioned_document_ids, mentioned_document_ids=mentioned_document_ids,
auth_context=auth_context,
) )
_perf_log.info( _perf_log.info(
"[stream_new_chat] Agent created in %.3fs", time.perf_counter() - _t0 "[stream_new_chat] Agent created in %.3fs", time.perf_counter() - _t0
@ -664,6 +667,7 @@ async def stream_new_chat(
filesystem_selection=filesystem_selection, filesystem_selection=filesystem_selection,
disabled_tools=disabled_tools, disabled_tools=disabled_tools,
mentioned_document_ids=mentioned_document_ids, mentioned_document_ids=mentioned_document_ids,
auth_context=auth_context,
) )
_perf_log.info( _perf_log.info(
"[stream_new_chat] Runtime rate-limit recovery repinned " "[stream_new_chat] Runtime rate-limit recovery repinned "

View file

@ -29,6 +29,7 @@ from app.agents.chat.multi_agent_chat.shared.filesystem_selection import (
FilesystemMode, FilesystemMode,
FilesystemSelection, FilesystemSelection,
) )
from app.auth.context import AuthContext
from app.db import ChatVisibility, async_session_maker from app.db import ChatVisibility, async_session_maker
from app.observability import otel as ot from app.observability import otel as ot
from app.services.chat_session_state_service import set_ai_responding from app.services.chat_session_state_service import set_ai_responding
@ -102,6 +103,7 @@ async def stream_resume_chat(
filesystem_selection: FilesystemSelection | None = None, filesystem_selection: FilesystemSelection | None = None,
request_id: str | None = None, request_id: str | None = None,
disabled_tools: list[str] | None = None, disabled_tools: list[str] | None = None,
auth_context: AuthContext | None = None,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
"""Resume a paused HITL turn with the user's decisions. """Resume a paused HITL turn with the user's decisions.
@ -346,6 +348,7 @@ async def stream_resume_chat(
thread_visibility=visibility, thread_visibility=visibility,
filesystem_selection=filesystem_selection, filesystem_selection=filesystem_selection,
disabled_tools=disabled_tools, disabled_tools=disabled_tools,
auth_context=auth_context,
) )
_perf_log.info( _perf_log.info(
"[stream_resume] Agent created in %.3fs", time.perf_counter() - _t0 "[stream_resume] Agent created in %.3fs", time.perf_counter() - _t0
@ -481,6 +484,7 @@ async def stream_resume_chat(
thread_visibility=visibility, thread_visibility=visibility,
filesystem_selection=filesystem_selection, filesystem_selection=filesystem_selection,
disabled_tools=disabled_tools, disabled_tools=disabled_tools,
auth_context=auth_context,
) )
_perf_log.info( _perf_log.info(
"[stream_resume] Runtime rate-limit recovery repinned " "[stream_resume] Runtime rate-limit recovery repinned "