refactor: streamline auth context usage across chat and automation routes

This commit is contained in:
Anish Sarkar 2026-06-19 21:04:21 +05:30
parent 8af4a3f9d5
commit 6fd3f8570e
7 changed files with 53 additions and 28 deletions

View file

@ -33,7 +33,7 @@ from app.agents.chat.multi_agent_chat.subagents.shared.hitl.approvals.self_gated
from app.auth.context import AuthContext
from app.automations.schemas.api import AutomationCreate
from app.automations.services.automation import AutomationService
from app.db import User, async_session_maker
from app.db import async_session_maker
from app.utils.content_utils import extract_text_content
from .prompt import build_draft_prompt
@ -58,8 +58,6 @@ def create_create_automation_tool(
``AsyncSession`` is opened per call to avoid stale sessions on
compiled-agent cache hits (same pattern as the Notion / memory tools).
"""
uid = UUID(user_id) if isinstance(user_id, str) else user_id
@tool
async def create_automation(intent: str, runtime: ToolRuntime) -> dict[str, Any]:
"""Draft + save an automation from a natural-language intent.
@ -167,15 +165,17 @@ def create_create_automation_tool(
"issues": _format_validation_issues(exc),
}
if auth_context is None:
logger.error(
"create_automation called without AuthContext; refusing to persist"
)
return {
"status": "error",
"message": "authorization context missing for automation creation",
}
async with async_session_maker() as session:
user = await session.get(User, uid)
if user is None:
return {
"status": "error",
"message": "user not found in this session",
}
auth = auth_context or AuthContext.system(user, source="agent")
service = AutomationService(session=session, auth=auth)
service = AutomationService(session=session, auth=auth_context)
created = await service.create(final_validated)
return {
"status": "saved",

View file

@ -9,6 +9,7 @@ from collections.abc import AsyncIterator
from sqlalchemy import update
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.context import AuthContext
from app.db import ExternalChatBinding, NewChatMessage
from app.gateway.auth_invariant import assert_authorization_invariant
from app.gateway.base.translator import BaseStreamTranslator, GatewayStreamEvent
@ -64,6 +65,7 @@ async def call_agent_for_gateway(
request_id: str | None = None,
) -> None:
user = await assert_authorization_invariant(session, binding)
auth_context = AuthContext.system(user, source="gateway")
thread = await get_or_create_thread_for_binding(session, binding)
await session.commit()
@ -81,6 +83,7 @@ async def call_agent_for_gateway(
current_user_display_name=user.display_name or "A team member",
disabled_tools=sorted(DEFAULT_HITL_TOOL_NAMES),
request_id=request_id or "gateway",
auth_context=auth_context,
)
events = _events_from_sse(stream)
try:

View file

@ -24,7 +24,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from app.auth.context import AuthContext
from app.agents.chat.multi_agent_chat.main_agent.middleware.busy_mutex import (
get_cancel_state,
is_cancel_requested,
@ -37,6 +36,7 @@ from app.agents.chat.multi_agent_chat.shared.filesystem_selection import (
FilesystemSelection,
LocalFilesystemMount,
)
from app.auth.context import AuthContext
from app.config import config
from app.db import (
ChatComment,
@ -1810,6 +1810,7 @@ async def handle_new_chat(
filesystem_selection=filesystem_selection,
request_id=getattr(http_request.state, "request_id", "unknown"),
user_image_data_urls=image_urls,
auth_context=auth,
),
media_type="text/event-stream",
headers={
@ -2306,6 +2307,7 @@ async def regenerate_response(
filesystem_selection=filesystem_selection,
request_id=getattr(http_request.state, "request_id", "unknown"),
user_image_data_urls=regenerate_image_urls or None,
auth_context=auth,
flow="regenerate",
):
yield chunk
@ -2432,6 +2434,7 @@ async def resume_chat(
filesystem_selection=filesystem_selection,
request_id=getattr(http_request.state, "request_id", "unknown"),
disabled_tools=request.disabled_tools,
auth_context=auth,
),
media_type="text/event-stream",
headers={

View file

@ -16,6 +16,7 @@ from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.auth.context import AuthContext
from app.db import (
Document,
DocumentType,
@ -52,7 +53,8 @@ from app.services.obsidian_plugin_indexer import (
upsert_note,
)
from app.tasks.celery_tasks.obsidian_tasks import index_obsidian_attachment_task
from app.users import current_active_user
from app.users import get_auth_context
from app.utils.rbac import check_search_space_access
logger = logging.getLogger(__name__)
@ -174,10 +176,11 @@ async def _finish_obsidian_sync_notification(
async def _resolve_vault_connector(
session: AsyncSession,
*,
user: User,
auth: AuthContext,
vault_id: str,
) -> SearchSourceConnector:
"""Find the OBSIDIAN_CONNECTOR row that owns ``vault_id`` for this user."""
user = auth.user
# ``config`` is core ``JSON`` (not ``JSONB``); ``as_string()`` is the
# cross-dialect equivalent of ``.astext`` and compiles to ``->>``.
stmt = select(SearchSourceConnector).where(
@ -192,6 +195,7 @@ async def _resolve_vault_connector(
connector = (await session.execute(stmt)).scalars().first()
if connector is not None:
await check_search_space_access(session, auth, connector.search_space_id)
return connector
raise HTTPException(
@ -221,10 +225,11 @@ def _queue_obsidian_attachment(
async def _ensure_search_space_access(
session: AsyncSession,
*,
user: User,
auth: AuthContext,
search_space_id: int,
) -> SearchSpace:
"""Owner-only access to the search space (shared spaces are a follow-up)."""
user = auth.user
result = await session.execute(
select(SearchSpace).where(
and_(SearchSpace.id == search_space_id, SearchSpace.user_id == user.id)
@ -239,6 +244,7 @@ async def _ensure_search_space_access(
"message": "You don't own that search space.",
},
)
await check_search_space_access(session, auth, search_space_id)
return space
@ -249,7 +255,7 @@ async def _ensure_search_space_access(
@router.get("/health", response_model=HealthResponse)
async def obsidian_health(
user: User = Depends(current_active_user),
_auth: AuthContext = Depends(get_auth_context),
) -> HealthResponse:
"""Return the API contract handshake; plugin caches it per onload."""
return HealthResponse(
@ -306,7 +312,7 @@ def _display_name(vault_name: str) -> str:
@router.post("/connect", response_model=ConnectResponse)
async def obsidian_connect(
payload: ConnectRequest,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
session: AsyncSession = Depends(get_async_session),
) -> ConnectResponse:
"""Register a vault, refresh an existing one, or adopt another device's row.
@ -321,8 +327,9 @@ async def obsidian_connect(
the partial unique index can never produce two live rows for one vault.
"""
await _ensure_search_space_access(
session, user=user, search_space_id=payload.search_space_id
session, auth=auth, search_space_id=payload.search_space_id
)
user = auth.user
now_iso = datetime.now(UTC).isoformat()
cfg = _build_config(payload, now_iso=now_iso)
@ -445,13 +452,14 @@ async def obsidian_connect(
@router.post("/sync", response_model=SyncAck)
async def obsidian_sync(
payload: SyncBatchRequest,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
session: AsyncSession = Depends(get_async_session),
) -> SyncAck:
"""Batch-upsert notes; returns per-note ack so the plugin can dequeue/retry."""
connector = await _resolve_vault_connector(
session, user=user, vault_id=payload.vault_id
session, auth=auth, vault_id=payload.vault_id
)
user = auth.user
notification = None
try:
notification = await _start_obsidian_sync_notification(
@ -551,12 +559,12 @@ async def obsidian_sync(
@router.post("/rename", response_model=RenameAck)
async def obsidian_rename(
payload: RenameBatchRequest,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
session: AsyncSession = Depends(get_async_session),
) -> RenameAck:
"""Apply a batch of vault rename events."""
connector = await _resolve_vault_connector(
session, user=user, vault_id=payload.vault_id
session, auth=auth, vault_id=payload.vault_id
)
items: list[RenameAckItem] = []
@ -618,12 +626,12 @@ async def obsidian_rename(
@router.delete("/notes", response_model=DeleteAck)
async def obsidian_delete_notes(
payload: DeleteBatchRequest,
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
session: AsyncSession = Depends(get_async_session),
) -> DeleteAck:
"""Soft-delete a batch of notes by vault-relative path."""
connector = await _resolve_vault_connector(
session, user=user, vault_id=payload.vault_id
session, auth=auth, vault_id=payload.vault_id
)
deleted = 0
@ -662,18 +670,18 @@ async def obsidian_delete_notes(
@router.get("/manifest", response_model=ManifestResponse)
async def obsidian_manifest(
vault_id: str = Query(..., description="Plugin-side stable vault UUID"),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
session: AsyncSession = Depends(get_async_session),
) -> ManifestResponse:
"""Return ``{path: {hash, mtime}}`` for the plugin's onload reconcile diff."""
connector = await _resolve_vault_connector(session, user=user, vault_id=vault_id)
connector = await _resolve_vault_connector(session, auth=auth, vault_id=vault_id)
return await get_manifest(session, connector=connector, vault_id=vault_id)
@router.get("/stats", response_model=StatsResponse)
async def obsidian_stats(
vault_id: str = Query(..., description="Plugin-side stable vault UUID"),
user: User = Depends(current_active_user),
auth: AuthContext = Depends(get_auth_context),
session: AsyncSession = Depends(get_async_session),
) -> StatsResponse:
"""Active-note count + last sync time for the web tile.
@ -681,7 +689,7 @@ async def obsidian_stats(
``files_synced`` excludes tombstones so it matches ``/manifest``;
``last_sync_at`` includes them so deletes advance the freshness signal.
"""
connector = await _resolve_vault_connector(session, user=user, vault_id=vault_id)
connector = await _resolve_vault_connector(session, auth=auth, vault_id=vault_id)
is_active = Document.document_metadata["deleted_at"].as_string().is_(None)

View file

@ -13,6 +13,7 @@ from app.agents.chat.multi_agent_chat.shared.filesystem_selection import (
FilesystemSelection,
)
from app.agents.chat.runtime.llm_config import AgentConfig
from app.auth.context import AuthContext
from app.db import ChatVisibility
from app.services.connector_service import ConnectorService
@ -33,6 +34,7 @@ async def build_main_agent_for_thread(
filesystem_selection: FilesystemSelection | None,
disabled_tools: list[str] | None = None,
mentioned_document_ids: list[int] | None = None,
auth_context: AuthContext | None = None,
) -> Any:
return await agent_factory(
llm=llm,
@ -48,4 +50,5 @@ async def build_main_agent_for_thread(
filesystem_selection=filesystem_selection,
disabled_tools=disabled_tools,
mentioned_document_ids=mentioned_document_ids,
auth_context=auth_context,
)

View file

@ -35,6 +35,7 @@ from app.agents.chat.multi_agent_chat.shared.filesystem_selection import (
FilesystemMode,
FilesystemSelection,
)
from app.auth.context import AuthContext
from app.db import ChatVisibility, async_session_maker
from app.observability import otel as ot
from app.services.new_streaming_service import VercelStreamingService
@ -136,6 +137,7 @@ async def stream_new_chat(
filesystem_selection: FilesystemSelection | None = None,
request_id: str | None = None,
user_image_data_urls: list[str] | None = None,
auth_context: AuthContext | None = None,
flow: Literal["new", "regenerate"] = "new",
) -> AsyncGenerator[str, None]:
"""Stream a new chat turn using the SurfSense deep agent.
@ -412,6 +414,7 @@ async def stream_new_chat(
filesystem_selection=filesystem_selection,
disabled_tools=disabled_tools,
mentioned_document_ids=mentioned_document_ids,
auth_context=auth_context,
)
_perf_log.info(
"[stream_new_chat] Agent created in %.3fs", time.perf_counter() - _t0
@ -664,6 +667,7 @@ async def stream_new_chat(
filesystem_selection=filesystem_selection,
disabled_tools=disabled_tools,
mentioned_document_ids=mentioned_document_ids,
auth_context=auth_context,
)
_perf_log.info(
"[stream_new_chat] Runtime rate-limit recovery repinned "

View file

@ -29,6 +29,7 @@ from app.agents.chat.multi_agent_chat.shared.filesystem_selection import (
FilesystemMode,
FilesystemSelection,
)
from app.auth.context import AuthContext
from app.db import ChatVisibility, async_session_maker
from app.observability import otel as ot
from app.services.chat_session_state_service import set_ai_responding
@ -102,6 +103,7 @@ async def stream_resume_chat(
filesystem_selection: FilesystemSelection | None = None,
request_id: str | None = None,
disabled_tools: list[str] | None = None,
auth_context: AuthContext | None = None,
) -> AsyncGenerator[str, None]:
"""Resume a paused HITL turn with the user's decisions.
@ -346,6 +348,7 @@ async def stream_resume_chat(
thread_visibility=visibility,
filesystem_selection=filesystem_selection,
disabled_tools=disabled_tools,
auth_context=auth_context,
)
_perf_log.info(
"[stream_resume] Agent created in %.3fs", time.perf_counter() - _t0
@ -481,6 +484,7 @@ async def stream_resume_chat(
thread_visibility=visibility,
filesystem_selection=filesystem_selection,
disabled_tools=disabled_tools,
auth_context=auth_context,
)
_perf_log.info(
"[stream_resume] Runtime rate-limit recovery repinned "